# NumPyro: a lightweight probabilistic programming library providing a NumPy
# backend for Pyro, powered by JAX for automatic differentiation and JIT
# compilation.
#
# NumPyro provides comprehensive diagnostic utilities for assessing MCMC
# convergence, computing effective sample sizes, and summarizing posterior
# distributions. These tools are essential for validating the quality of
# Bayesian inference results and ensuring reliable posterior estimates.
#
# Functions for assessing MCMC chain convergence and mixing.
def gelman_rubin(x: NDArray) -> NDArray:
    """
    Compute the Gelman-Rubin convergence diagnostic (R-hat statistic).

    Assesses convergence by comparing within-chain and between-chain
    variances. Values close to 1.0 indicate convergence; values > 1.1
    suggest lack of convergence.

    Args:
        x: MCMC samples with shape (num_chains, num_samples, ...) or
            (num_chains, num_samples).

    Returns:
        R-hat statistic for each parameter. Values near 1.0 indicate
        convergence.

    Usage:
        # Get samples from MCMC
        mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000, num_chains=4)
        mcmc.run(rng_key, data)
        samples = mcmc.get_samples(group_by_chain=True)
        # Compute R-hat for each parameter
        rhat = numpyro.diagnostics.gelman_rubin(samples['theta'])
        print(f"R-hat for theta: {rhat}")
        # Check convergence (should be < 1.1)
        converged = jnp.all(rhat < 1.1)
    """
def split_gelman_rubin(x: NDArray) -> NDArray:
    """
    Compute the split Gelman-Rubin diagnostic (split R-hat).

    A more robust version of R-hat that splits each chain in half to
    increase the number of chains used in the comparison, giving a better
    convergence assessment.

    Args:
        x: MCMC samples with shape (num_chains, num_samples, ...).

    Returns:
        Split R-hat statistic for each parameter.

    Usage:
        # More robust convergence assessment
        split_rhat = numpyro.diagnostics.split_gelman_rubin(samples['theta'])
        print(f"Split R-hat: {split_rhat}")
        # This is generally more reliable than regular R-hat
        converged = jnp.all(split_rhat < 1.1)
    """
def effective_sample_size(x: NDArray) -> NDArray:
    """
    Compute the effective sample size (ESS) for MCMC chains.

    ESS estimates the number of independent samples that would provide the
    same statistical power as the correlated MCMC samples.

    Args:
        x: MCMC samples with shape (num_chains, num_samples, ...).

    Returns:
        Effective sample size for each parameter.

    Usage:
        # Assess sampling efficiency
        ess = numpyro.diagnostics.effective_sample_size(samples['theta'])
        print(f"Effective sample size: {ess}")
        # Rule of thumb: ESS should be > 100 for reliable estimates
        # ESS > 400 is generally considered good
        total_samples = samples['theta'].shape[0] * samples['theta'].shape[1]
        efficiency = ess / total_samples
        print(f"Sampling efficiency: {efficiency:.2%}")
    """


# Functions for analyzing temporal correlations in MCMC samples.
def autocorrelation(x: NDArray) -> NDArray:
    """
    Compute the autocorrelation function for MCMC chains.

    Measures how correlated a time series is with lagged versions of
    itself; useful for understanding the temporal structure of MCMC
    samples.

    Args:
        x: MCMC samples with shape (num_samples,) or
            (num_samples, num_features).

    Returns:
        Autocorrelation function values for different lags.

    Usage:
        # Analyze autocorrelation structure
        # First flatten chains if multiple chains
        flat_samples = samples['theta'].reshape(-1)  # (total_samples,)
        autocorr = numpyro.diagnostics.autocorrelation(flat_samples)
        # Plot autocorrelation to assess mixing
        import matplotlib.pyplot as plt
        plt.plot(autocorr[:100])  # First 100 lags
        plt.xlabel('Lag')
        plt.ylabel('Autocorrelation')
        plt.title('MCMC Autocorrelation')
    """
def autocovariance(x: NDArray) -> NDArray:
    """
    Compute the autocovariance function for MCMC chains.

    Similar to autocorrelation but without normalization, preserving the
    actual variance scale of the correlations.

    Args:
        x: MCMC samples with shape (num_samples,) or
            (num_samples, num_features).

    Returns:
        Autocovariance function values for different lags.

    Usage:
        # Compute autocovariance for variance analysis
        flat_samples = samples['theta'].reshape(-1)
        autocov = numpyro.diagnostics.autocovariance(flat_samples)
        # First value (lag 0) is the variance
        variance = autocov[0]
        print(f"Sample variance: {variance}")
    """


# Functions for summarizing posterior distributions.
def hpdi(x: NDArray, prob: float = 0.9, axis: int = 0) -> NDArray:
    """
    Compute the Highest Posterior Density Interval (HPDI).

    The HPDI is the shortest interval that contains the specified
    probability mass; it is more informative than an equal-tailed interval
    for skewed distributions.

    Args:
        x: Posterior samples.
        prob: Probability mass to include in the interval (default: 0.9).
        axis: Axis along which to compute intervals (default: 0).

    Returns:
        Array with shape (..., 2) containing lower and upper bounds.

    Usage:
        # 90% highest posterior density interval
        hpdi_90 = numpyro.diagnostics.hpdi(samples['theta'], prob=0.9)
        print(f"90% HPDI: [{hpdi_90[0]:.3f}, {hpdi_90[1]:.3f}]")
        # 95% HPDI for comparison
        hpdi_95 = numpyro.diagnostics.hpdi(samples['theta'], prob=0.95)
        print(f"95% HPDI: [{hpdi_95[0]:.3f}, {hpdi_95[1]:.3f}]")
        # For multivariate parameters
        multivar_hpdi = numpyro.diagnostics.hpdi(samples['weights'], prob=0.9)
        # Shape: (num_parameters, 2)
    """
def print_summary(samples: dict, prob: float = 0.9, group_by_chain: bool = True) -> None:
    """
    Print comprehensive summary statistics for posterior samples.

    Provides mean, standard deviation, HPDI, effective sample size, and
    R-hat for all parameters in a formatted table.

    Args:
        samples: Dictionary of posterior samples from MCMC.
        prob: Probability for HPDI computation (default: 0.9).
        group_by_chain: Whether samples are grouped by chain.

    Usage:
        # Get samples and print summary
        mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000, num_chains=4)
        mcmc.run(rng_key, data)
        samples = mcmc.get_samples(group_by_chain=True)
        # Print comprehensive summary
        numpyro.diagnostics.print_summary(samples, prob=0.95)
        # Output format:
        #        mean   std  median      90.0%     n_eff  r_hat
        # theta  1.23  0.45    1.20  [0.56, 1.91]  892.5  1.002
        # sigma  2.34  0.12    2.33  [2.14, 2.56] 1205.2  1.001
    """
def summary(samples: dict, prob: float = 0.9, group_by_chain: bool = True) -> dict:
    """
    Compute summary statistics for posterior samples without printing.

    Args:
        samples: Dictionary of posterior samples.
        prob: Probability for HPDI computation.
        group_by_chain: Whether samples are grouped by chain.

    Returns:
        Dictionary containing summary statistics for each parameter.

    Usage:
        # Get summary as dictionary for further processing
        summary_stats = numpyro.diagnostics.summary(samples, prob=0.95)
        for param_name, stats in summary_stats.items():
            print(f"{param_name}:")
            print(f" Mean: {stats['mean']:.3f}")
            print(f" Std: {stats['std']:.3f}")
            print(f" R-hat: {stats['r_hat']:.3f}")
            print(f" ESS: {stats['n_eff']:.1f}")
    """


# Functions for diagnosing model-specific issues.
def split_by_chain(x: NDArray) -> NDArray:
    """
    Split samples by chain for chain-specific analysis.

    Args:
        x: Samples with shape (num_chains, num_samples, ...).

    Returns:
        List of arrays, one per chain.
        # NOTE(review): the annotation says NDArray but this description
        # says a list of per-chain arrays — confirm the intended return type.

    Usage:
        # Analyze chains separately
        chain_samples = numpyro.diagnostics.split_by_chain(samples['theta'])
        for i, chain in enumerate(chain_samples):
            mean_i = jnp.mean(chain)
            print(f"Chain {i} mean: {mean_i:.3f}")
    """
def potential_scale_reduction(x: NDArray, split_chains: bool = True) -> NDArray:
    """
    Compute the potential scale reduction factor (PSRF).

    Also known as R-hat: measures the ratio of the average within-chain
    variance of samples to the variance of the pooled samples across
    chains.

    Args:
        x: MCMC samples with shape (num_chains, num_samples, ...).
        split_chains: Whether to split chains for more robust estimates.

    Returns:
        PSRF values for each parameter.

    Usage:
        # Alternative interface to gelman_rubin
        psrf = numpyro.diagnostics.potential_scale_reduction(samples['theta'])
        print(f"PSRF: {psrf}")
    """
def rank_plot_data(samples: dict, param_names: Optional[list] = None) -> dict:
    """
    Prepare data for rank plots (for external plotting).

    Rank plots help visualize chain mixing by showing the distribution of
    ranks of samples from different chains.

    Args:
        samples: Dictionary of MCMC samples.
        param_names: List of parameter names to include.

    Returns:
        Dictionary with rank data for plotting.

    Usage:
        # Prepare data for rank plots
        rank_data = numpyro.diagnostics.rank_plot_data(samples, ['theta', 'sigma'])
        # Use with external plotting library
        import matplotlib.pyplot as plt
        for param, ranks in rank_data.items():
            plt.figure()
            for chain_ranks in ranks:
                plt.hist(chain_ranks, alpha=0.5, bins=50)
            plt.title(f"Rank plot for {param}")
    """


# Helper functions for diagnostic computations.
def within_chain_variance(x: NDArray) -> NDArray:
    """
    Compute the within-chain variance used in the R-hat calculation.

    Args:
        x: MCMC samples with shape (num_chains, num_samples, ...).

    Returns:
        Within-chain variance for each parameter.
    """
def between_chain_variance(x: NDArray) -> NDArray:
    """
    Compute the between-chain variance used in the R-hat calculation.

    Args:
        x: MCMC samples with shape (num_chains, num_samples, ...).

    Returns:
        Between-chain variance for each parameter.
    """
def integrated_autocorr_time(x: NDArray, c: float = 5.0,
                             tol: float = 50.0, quiet: bool = False) -> float:
    """
    Compute the integrated autocorrelation time.

    Estimates the correlation time by integrating the autocorrelation
    function until the estimate becomes unreliable.

    Args:
        x: Time series data.
        c: Window size multiplier for automatic windowing.
        tol: Tolerance for unreliable estimates.
        quiet: Whether to suppress warnings.

    Returns:
        Integrated autocorrelation time.

    Usage:
        # Estimate correlation time
        flat_samples = samples['theta'].reshape(-1)
        tau = numpyro.diagnostics.integrated_autocorr_time(flat_samples)
        print(f"Autocorrelation time: {tau:.2f}")
        # Rule of thumb: need at least 50*tau samples for reliable estimates
        min_samples = 50 * tau
        actual_samples = len(flat_samples)
        print(f"Recommended samples: {min_samples:.0f}, Actual: {actual_samples}")
    """
def compute_chain_statistics(x: NDArray) -> dict:
    """
    Compute comprehensive statistics for individual chains.

    Args:
        x: MCMC samples with shape (num_chains, num_samples, ...).

    Returns:
        Dictionary with statistics for each chain.

    Usage:
        # Analyze individual chain performance
        chain_stats = numpyro.diagnostics.compute_chain_statistics(samples['theta'])
        for chain_id, stats in chain_stats.items():
            print(f"Chain {chain_id}:")
            print(f" Mean: {stats['mean']:.3f}")
            print(f" Variance: {stats['var']:.3f}")
            print(f" ESS: {stats['ess']:.1f}")
    """


import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
import numpyro.diagnostics as diagnostics
import jax.numpy as jnp
from jax import random
# Comprehensive diagnostic workflow
def diagnostic_workflow_example():
    """Run a complete MCMC diagnostic workflow on a simple regression model.

    Fits y = alpha + beta * x + noise with NUTS using 4 chains, then reports
    summary statistics, convergence (R-hat / split R-hat), sampling
    efficiency (ESS), posterior intervals (HPDI), and autocorrelation
    for every sampled parameter.

    Returns:
        Dictionary of posterior samples grouped by chain.
    """
    # Define a simple Bayesian linear-regression model.
    def model(x, y=None):
        alpha = numpyro.sample("alpha", dist.Normal(0, 1))
        beta = numpyro.sample("beta", dist.Normal(0, 1))
        sigma = numpyro.sample("sigma", dist.Exponential(1))
        mu = alpha + beta * x
        with numpyro.plate("data", len(x)):
            numpyro.sample("y", dist.Normal(mu, sigma), obs=y)

    # Generate synthetic data from known ground-truth parameters.
    key = random.PRNGKey(0)
    n_data = 100
    x = jnp.linspace(0, 1, n_data)
    true_alpha, true_beta, true_sigma = 1.0, 2.0, 0.1
    y = true_alpha + true_beta * x + true_sigma * random.normal(key, (n_data,))

    # Run MCMC with multiple chains (R-hat needs more than one chain).
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4)
    mcmc.run(random.PRNGKey(1), x, y)

    # Get samples grouped by chain for the chain-aware diagnostics.
    samples = mcmc.get_samples(group_by_chain=True)
    print("=== MCMC Diagnostic Report ===")

    # 1. Print comprehensive summary.
    # (The original used "\\n" in these strings, printing a literal
    # backslash-n; fixed to real newlines.)
    print("\n1. Summary Statistics:")
    diagnostics.print_summary(samples, prob=0.95)

    # 2. Check convergence with R-hat.
    print("\n2. Convergence Diagnostics:")
    for param_name, param_samples in samples.items():
        rhat = diagnostics.gelman_rubin(param_samples)
        split_rhat = diagnostics.split_gelman_rubin(param_samples)
        print(f"{param_name}:")
        print(f" R-hat: {rhat:.4f}")
        print(f" Split R-hat: {split_rhat:.4f}")
        print(f" Converged (R-hat < 1.1): {rhat < 1.1}")

    # 3. Assess sampling efficiency.
    print("\n3. Sampling Efficiency:")
    total_samples = samples['alpha'].shape[0] * samples['alpha'].shape[1]
    for param_name, param_samples in samples.items():
        ess = diagnostics.effective_sample_size(param_samples)
        efficiency = ess / total_samples
        print(f"{param_name}:")
        print(f" ESS: {ess:.1f}")
        print(f" Efficiency: {efficiency:.2%}")
        print(f" Good ESS (>400): {ess > 400}")

    # 4. Posterior intervals (flatten chains first).
    print("\n4. Posterior Intervals:")
    flat_samples = {k: v.reshape(-1) for k, v in samples.items()}
    for param_name, param_samples in flat_samples.items():
        hpdi_90 = diagnostics.hpdi(param_samples, prob=0.9)
        hpdi_95 = diagnostics.hpdi(param_samples, prob=0.95)
        print(f"{param_name}:")
        print(f" 90% HPDI: [{hpdi_90[0]:.3f}, {hpdi_90[1]:.3f}]")
        print(f" 95% HPDI: [{hpdi_95[0]:.3f}, {hpdi_95[1]:.3f}]")

    # 5. Autocorrelation analysis.
    print("\n5. Autocorrelation Analysis:")
    for param_name, param_samples in flat_samples.items():
        tau = diagnostics.integrated_autocorr_time(param_samples, quiet=True)
        print(f"{param_name}:")
        print(f" Autocorr time: {tau:.2f}")
        print(f" Recommended min samples: {50 * tau:.0f}")
        print(f" Actual samples: {len(param_samples)}")

    return samples
# Chain-specific diagnostics
def chain_analysis_example(samples=None):
    """Analyze individual MCMC chains separately.

    Args:
        samples: Dictionary of MCMC samples grouped by chain, e.g. the
            result of ``mcmc.get_samples(group_by_chain=True)``. Must
            contain an 'alpha' entry of shape (num_chains, num_samples).

    Raises:
        ValueError: If no samples dictionary is provided.
    """
    # The original referenced an undefined module-level `samples` variable
    # (it only existed as a comment), which raised NameError when run;
    # require it explicitly instead, in a backward-compatible way.
    if samples is None:
        raise ValueError(
            "samples is required; pass mcmc.get_samples(group_by_chain=True)"
        )

    param_samples = samples['alpha']  # Shape: (num_chains, num_samples)
    print("=== Individual Chain Analysis ===")

    # Split by chain and analyze each chain separately.
    chain_samples = diagnostics.split_by_chain(param_samples)
    for i, chain in enumerate(chain_samples):
        mean_i = jnp.mean(chain)
        std_i = jnp.std(chain)
        autocorr_i = diagnostics.autocorrelation(chain)
        print(f"\nChain {i}:")
        print(f" Mean: {mean_i:.4f}")
        print(f" Std: {std_i:.4f}")
        print(f" First 5 autocorr values: {autocorr_i[:5]}")

    # Compare within- vs between-chain variance; for well-mixed chains
    # the ratio should be close to 1.
    within_var = diagnostics.within_chain_variance(param_samples)
    between_var = diagnostics.between_chain_variance(param_samples)
    print(f"\nVariance Analysis:")
    print(f" Within-chain variance: {within_var:.6f}")
    print(f" Between-chain variance: {between_var:.6f}")
    print(f" Ratio (should be ~1): {between_var / within_var:.4f}")
# Diagnostic-driven sampling strategy
def adaptive_sampling_example():
    """Example of using diagnostics to determine sampling requirements.

    Repeatedly runs MCMC on a deliberately challenging posterior, checking
    R-hat and ESS after each run, and increases the per-chain sample count
    until the convergence criteria are met (or max_iterations runs out).

    Returns:
        Dictionary of posterior samples (grouped by chain) from the last run.
    """
    def model():
        # Deliberately create a challenging posterior
        x = numpyro.sample("x", dist.Normal(0, 1))
        y = numpyro.sample("y", dist.Normal(x**2, 0.1))  # Non-linear relationship

    # Start with a small number of samples and grow as needed.
    initial_samples = 500
    target_ess = 400
    max_iterations = 5

    samples = None  # defensive: defined even if the loop body never runs
    for iteration in range(max_iterations):
        print(f"\n--- Iteration {iteration + 1} ---")

        # Run MCMC
        mcmc = MCMC(NUTS(model),
                    num_warmup=initial_samples,
                    num_samples=initial_samples,
                    num_chains=4)
        mcmc.run(random.PRNGKey(iteration))
        samples = mcmc.get_samples(group_by_chain=True)

        # Check diagnostics
        rhat = diagnostics.gelman_rubin(samples['x'])
        ess = diagnostics.effective_sample_size(samples['x'])
        print(f"Current samples per chain: {initial_samples}")
        print(f"R-hat: {rhat:.4f}")
        print(f"ESS: {ess:.1f}")

        # Check if we meet the convergence criteria.
        converged = rhat < 1.1
        sufficient_ess = ess > target_ess
        if converged and sufficient_ess:
            print(f"✓ Convergence achieved!")
            break
        elif not converged:
            print(f"✗ Poor convergence (R-hat = {rhat:.4f})")
            initial_samples = int(initial_samples * 1.5)  # Increase samples
        elif not sufficient_ess:
            print(f"✗ Insufficient ESS ({ess:.1f} < {target_ess})")
            initial_samples = int(initial_samples * 1.2)  # Modest increase

    return samples


# Typing imports for the aliases below (Callable added: it is used by the
# function type signatures but was missing from the original import).
from typing import Any, Callable, Dict, List, Optional, Union
from jax import Array
import jax.numpy as jnp

# Type aliases used throughout the diagnostics API.
NDArray = jnp.ndarray  # JAX ndarray type
ArrayLike = Union[Array, NDArray, float, int]  # anything coercible to an array
Samples = Dict[str, NDArray]  # mapping of site name -> posterior samples
class DiagnosticResult:
    """Common base type shared by all diagnostic result objects."""

    pass
class SummaryStats:
    """Summary statistics for a single parameter."""

    mean: float         # posterior mean
    std: float          # posterior standard deviation
    median: float       # posterior median
    mad: float          # median absolute deviation
    hpdi_lower: float   # lower bound of the HPDI
    hpdi_upper: float   # upper bound of the HPDI
    n_eff: float        # effective sample size
    r_hat: float        # R-hat convergence statistic
class ConvergenceDiagnostic:
    """Convergence diagnostic results."""

    r_hat: NDArray                      # R-hat per parameter
    split_r_hat: NDArray                # split R-hat per parameter
    converged: bool                     # overall convergence verdict
    potential_scale_reduction: NDArray  # PSRF per parameter
class EfficiencyDiagnostic:
    """Sampling efficiency diagnostic results."""

    effective_sample_size: NDArray  # ESS per parameter
    autocorrelation_time: NDArray   # integrated autocorrelation time per parameter
    efficiency_ratio: NDArray       # ESS divided by total number of samples
class AutocorrelationResult:
    """Autocorrelation analysis results."""

    autocorr: NDArray       # autocorrelation function values per lag
    autocov: NDArray        # autocovariance function values per lag
    integrated_time: float  # integrated autocorrelation time
class ChainStatistics:
    """Statistics for an individual MCMC chain."""

    chain_id: int                    # index of the chain
    mean: NDArray                    # per-parameter mean within this chain
    variance: NDArray                # per-parameter variance within this chain
    effective_sample_size: NDArray   # ESS within this chain
    autocorrelation_time: float      # integrated autocorrelation time
# Function type signatures for the diagnostics API.
ConvergenceFunction = Callable[[NDArray], NDArray]
SummaryFunction = Callable[[NDArray], Dict[str, Any]]
DiagnosticFunction = Callable[[NDArray], DiagnosticResult]
# Install with Tessl CLI
#   npx tessl i tessl/pypi-numpyro