CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-numpyro

Lightweight probabilistic programming library with NumPy backend for Pyro, powered by JAX for automatic differentiation and JIT compilation.

Pending
Overview
Eval results
Files

docs/diagnostics.md

Diagnostics

NumPyro provides comprehensive diagnostic utilities for assessing MCMC convergence, computing effective sample sizes, and summarizing posterior distributions. These tools are essential for validating the quality of Bayesian inference results and ensuring reliable posterior estimates.

Capabilities

Convergence Diagnostics

Functions for assessing MCMC chain convergence and mixing.

def gelman_rubin(x: NDArray) -> NDArray:
    """
    Compute the Gelman-Rubin convergence diagnostic (R-hat statistic).
    
    Assesses convergence by comparing the between-chain variance with the
    within-chain variance. Values close to 1.0 indicate that the chains agree.
    Values above 1.1 are a common sign of non-convergence; recent guidance
    (Vehtari et al., 2021) recommends the stricter threshold 1.01.
    
    Args:
        x: MCMC samples with chains on axis 0 and draws on axis 1, i.e. shape
          (num_chains, num_samples) or (num_chains, num_samples, *param_shape).
          This is the layout returned by ``mcmc.get_samples(group_by_chain=True)``.
          
    Returns:
        R-hat for each parameter element, with shape ``x.shape[2:]``
        (a scalar for a scalar parameter). Values near 1.0 indicate convergence.
        
    Usage:
        # Get samples from MCMC
        mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000, num_chains=4)
        mcmc.run(rng_key, data)
        samples = mcmc.get_samples(group_by_chain=True)
        
        # Compute R-hat for each parameter
        rhat = numpyro.diagnostics.gelman_rubin(samples['theta'])
        print(f"R-hat for theta: {rhat}")
        
        # Check convergence (should be < 1.1)
        converged = jnp.all(rhat < 1.1)

    NOTE(review): this non-split variant compares whole chains only, so it
    cannot flag a single chain that drifts over time; split_gelman_rubin is
    usually preferred.
    """

def split_gelman_rubin(x: NDArray) -> NDArray:
    """
    Compute the split Gelman-Rubin diagnostic (split R-hat).
    
    Splits each chain into a first and a second half and computes R-hat over
    the resulting 2 * num_chains half-chains. Comparing the halves of a chain
    also detects non-stationarity (a trend within a single chain), which the
    plain R-hat misses.
    
    Args:
        x: MCMC samples with chains on axis 0 and draws on axis 1, i.e. shape
          (num_chains, num_samples, *param_shape)
        
    Returns:
        Split R-hat for each parameter element, with shape ``x.shape[2:]``
        
    Usage:
        # More robust convergence assessment
        split_rhat = numpyro.diagnostics.split_gelman_rubin(samples['theta'])
        print(f"Split R-hat: {split_rhat}")
        
        # This is generally more reliable than regular R-hat
        converged = jnp.all(split_rhat < 1.1)
    """

def effective_sample_size(x: NDArray) -> NDArray:
    """
    Compute the effective sample size (ESS) of MCMC chains.
    
    ESS estimates how many independent draws would give the same Monte Carlo
    precision as the (autocorrelated) MCMC draws. It relates to the integrated
    autocorrelation time tau via ESS = total_draws / tau.
    
    Args:
        x: MCMC samples with chains on axis 0 and draws on axis 1, i.e. shape
          (num_chains, num_samples, *param_shape)
        
    Returns:
        Effective sample size for each parameter element (pooled over chains)
        
    Usage:
        # Assess sampling efficiency
        ess = numpyro.diagnostics.effective_sample_size(samples['theta'])
        print(f"Effective sample size: {ess}")
        
        # Rule of thumb (Vehtari et al., 2021): aim for a total ESS of at
        # least 400, i.e. roughly 100 per chain when running 4 chains.
        total_samples = samples['theta'].shape[0] * samples['theta'].shape[1]
        efficiency = ess / total_samples
        # Efficiency can exceed 100% with NUTS, whose draws may be
        # anti-correlated.
        print(f"Sampling efficiency: {efficiency:.2%}")
    """

Autocorrelation Analysis

Functions for analyzing temporal correlations in MCMC samples.

def autocorrelation(x: NDArray) -> NDArray:
    """
    Compute the autocorrelation function of a single MCMC chain.
    
    Measures how strongly the series is correlated with lagged copies of
    itself. Slowly decaying autocorrelation means poor mixing.
    
    Args:
        x: Draws of ONE chain with shape (num_samples,) or
          (num_samples, num_features); time runs along axis 0.
        
    Returns:
        Autocorrelation values indexed by lag along axis 0 (the lag-0 value
        is 1 by definition).
        
    Usage:
        # Analyze one chain at a time. Flattening several chains into one
        # series creates artificial jumps at the chain boundaries and biases
        # the estimate.
        # samples = mcmc.get_samples(group_by_chain=True)
        chain_0 = samples['theta'][0]  # (num_samples,) draws of the first chain
        autocorr = numpyro.diagnostics.autocorrelation(chain_0)
        
        # Plot autocorrelation to assess mixing
        import matplotlib.pyplot as plt
        plt.plot(autocorr[:100])  # First 100 lags
        plt.xlabel('Lag')
        plt.ylabel('Autocorrelation')
        plt.title('MCMC Autocorrelation')
        plt.show()

    NOTE(review): numpyro's implementation also accepts an ``axis`` argument
    (default 0) selecting the time axis, which allows
    ``autocorrelation(samples['theta'], axis=1)`` for all chains at once;
    confirm against the installed version.
    """

def autocovariance(x: NDArray) -> NDArray:
    """
    Compute the autocovariance function of a single MCMC chain.
    
    Like autocorrelation, but not divided by the lag-0 value, so the result
    stays on the variance scale of ``x``.
    
    Args:
        x: Draws of ONE chain with shape (num_samples,) or
          (num_samples, num_features); time runs along axis 0.
        
    Returns:
        Autocovariance values indexed by lag along axis 0
        
    Usage:
        # Compute autocovariance chain by chain. Flattening several chains
        # would introduce jumps at the chain boundaries.
        # samples = mcmc.get_samples(group_by_chain=True)
        chain_0 = samples['theta'][0]
        autocov = numpyro.diagnostics.autocovariance(chain_0)
        
        # The lag-0 value is the variance of the chain
        # (NOTE(review): presumably normalized by num_samples, not
        # num_samples - 1; confirm)
        variance = autocov[0]
        print(f"Sample variance: {variance}")
    """

Posterior Summary Statistics

Functions for summarizing posterior distributions.

def hpdi(x: NDArray, prob: float = 0.9, axis: int = 0) -> NDArray:
    """
    Compute the Highest Posterior Density Interval (HPDI).
    
    The HPDI is the narrowest interval that contains the requested probability
    mass. For skewed distributions it can be noticeably shorter than the
    equal-tailed interval with the same mass.
    
    Args:
        x: Posterior samples; the draws lie along ``axis``
        prob: Probability mass to include in the interval (default: 0.9)
        axis: Axis holding the draws (default: 0)
        
    Returns:
        Lower and upper bounds stacked along ``axis``. For the default
        ``axis=0`` the result has shape ``(2, *x.shape[1:])``:
        ``result[0]`` is the lower bound and ``result[1]`` the upper bound.
        NOTE(review): this follows numpyro's implementation, which
        concatenates the two bounds along ``axis``; confirm against the
        installed version.
        
    Usage:
        # Use ungrouped draws: with group_by_chain=True, axis 0 would be the
        # chain axis rather than the draw axis.
        samples = mcmc.get_samples()  # arrays of shape (total_draws, ...)

        # 90% highest posterior density interval
        hpdi_90 = numpyro.diagnostics.hpdi(samples['theta'], prob=0.9)
        print(f"90% HPDI: [{hpdi_90[0]:.3f}, {hpdi_90[1]:.3f}]")
        
        # 95% HPDI for comparison
        hpdi_95 = numpyro.diagnostics.hpdi(samples['theta'], prob=0.95)
        print(f"95% HPDI: [{hpdi_95[0]:.3f}, {hpdi_95[1]:.3f}]")
        
        # For a vector parameter with draws of shape (total_draws, num_parameters)
        multivar_hpdi = numpyro.diagnostics.hpdi(samples['weights'], prob=0.9)
        # Shape: (2, num_parameters) -- row 0 holds lower bounds, row 1 upper
    """

def print_summary(samples: dict, prob: float = 0.9, group_by_chain: bool = True) -> None:
    """
    Print summary statistics for posterior samples as a formatted table.
    
    For every parameter the table shows the mean, standard deviation, median,
    the lower and upper bounds of the ``prob`` credible interval (HPDI), the
    effective sample size (n_eff) and R-hat.
    
    Args:
        samples: Dictionary mapping site names to posterior samples from MCMC
        prob: Probability mass of the credible interval (default: 0.9)
        group_by_chain: True if each array has shape
          (num_chains, num_samples, ...); False if it has shape
          (num_samples, ...)
        
    Usage:
        # Get samples and print summary
        mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000, num_chains=4)
        mcmc.run(rng_key, data)
        samples = mcmc.get_samples(group_by_chain=True)
        
        # Print comprehensive summary
        numpyro.diagnostics.print_summary(samples, prob=0.95)
        
        # Output format (illustrative numbers). The two interval columns are
        # labelled with the bounds implied by prob: prob=0.95 gives 2.5% and
        # 97.5%, the default prob=0.9 gives 5.0% and 95.0%.
        #                 mean       std    median      2.5%     97.5%     n_eff     r_hat
        #      theta      1.23      0.45      1.20      0.56      1.91    892.50      1.00
        #      sigma      2.34      0.12      2.33      2.14      2.56   1205.20      1.00
    """

def summary(samples: dict, prob: float = 0.9, group_by_chain: bool = True) -> dict:
    """
    Compute the statistics shown by print_summary and return them instead of printing.
    
    Args:
        samples: Dictionary mapping site names to posterior samples
        prob: Probability mass of the credible interval (HPDI)
        group_by_chain: True if each array has shape
          (num_chains, num_samples, ...); False if it has shape
          (num_samples, ...)
        
    Returns:
        Dictionary mapping each parameter name to a dictionary of statistics,
        including 'mean', 'std', 'n_eff' and 'r_hat'.
        NOTE(review): in numpyro the dictionary also holds 'median' and the
        interval bounds under percentile keys such as '5.0%' and '95.0%'
        (for prob=0.9); confirm against the installed version.
        
    Usage:
        # Get summary as dictionary for further processing
        summary_stats = numpyro.diagnostics.summary(samples, prob=0.95)
        
        # The :.3f formats below assume scalar parameters; vector-valued
        # parameters produce arrays of statistics.
        for param_name, stats in summary_stats.items():
            print(f"{param_name}:")
            print(f"  Mean: {stats['mean']:.3f}")
            print(f"  Std: {stats['std']:.3f}")
            print(f"  R-hat: {stats['r_hat']:.3f}")
            print(f"  ESS: {stats['n_eff']:.1f}")
    """

Model Diagnostics

Functions for per-chain analysis and additional convergence and mixing checks.

def split_by_chain(x: NDArray) -> list:
    """
    Split chain-grouped samples into one array per chain.
    
    Args:
        x: Samples with shape (num_chains, num_samples, ...)
        
    Returns:
        List of arrays, one per chain, each of shape (num_samples, ...)
        
    Usage:
        # Analyze chains separately
        chain_samples = numpyro.diagnostics.split_by_chain(samples['theta'])
        
        for i, chain in enumerate(chain_samples):
            mean_i = jnp.mean(chain)
            print(f"Chain {i} mean: {mean_i:.3f}")

    NOTE(review): I could not find this function in numpyro.diagnostics'
    public API (autocorrelation, autocovariance, effective_sample_size,
    gelman_rubin, hpdi, print_summary, split_gelman_rubin, summary).
    Iterating over axis 0 (``for chain in samples['theta']``) gives the same
    per-chain arrays. Confirm before relying on it.
    """

def potential_scale_reduction(x: NDArray, split_chains: bool = True) -> NDArray:
    """
    Compute the potential scale reduction factor (PSRF).
    
    Also known as R-hat. It is the square root of the ratio of the pooled
    variance estimate (a weighted combination of the within-chain and
    between-chain variances) to the average within-chain variance. It
    therefore approaches 1 from above as the chains converge.
    
    Args:
        x: MCMC samples with shape (num_chains, num_samples, ...)
        split_chains: Whether to split each chain in half first, i.e. compute
          split R-hat (presumably equivalent to split_gelman_rubin when True
          and gelman_rubin when False; confirm)
        
    Returns:
        PSRF values for each parameter
        
    Usage:
        # Alternative interface to gelman_rubin
        psrf = numpyro.diagnostics.potential_scale_reduction(samples['theta'])
        print(f"PSRF: {psrf}")

    NOTE(review): not found in numpyro.diagnostics' public API; use
    gelman_rubin / split_gelman_rubin there. Confirm before relying on it.
    """

def rank_plot_data(samples: dict, param_names: Optional[list] = None) -> dict:
    """
    Prepare data for rank plots, to be drawn with an external plotting library.
    
    Rank plots visualize chain mixing. The draws of all chains are ranked
    jointly, and the ranks falling into each chain are histogrammed. Well-mixed
    chains give roughly uniform histograms.
    
    Args:
        samples: Dictionary of MCMC samples grouped by chain
        param_names: Names of the parameters to include (presumably all
          parameters when None; confirm)
        
    Returns:
        Dictionary mapping parameter name to the per-chain ranks
        
    Usage:
        # Prepare data for rank plots
        rank_data = numpyro.diagnostics.rank_plot_data(samples, ['theta', 'sigma'])
        
        # Use with external plotting library
        import matplotlib.pyplot as plt
        for param, ranks in rank_data.items():
            plt.figure()
            for chain_ranks in ranks:
                plt.hist(chain_ranks, alpha=0.5, bins=50)
            plt.title(f"Rank plot for {param}")

    NOTE(review): not found in numpyro.diagnostics' public API. ArviZ's
    ``az.plot_rank`` produces rank plots from numpyro MCMC output. Confirm
    before relying on this function.
    """

Diagnostic Utilities

Helper functions for diagnostic computations.

def within_chain_variance(x: NDArray) -> NDArray:
    """
    Compute the within-chain variance W used by R-hat.
    
    W is the mean over chains of each chain's sample variance:
    W = (1 / m) * sum_j s_j**2 for m chains.
    
    Args:
        x: MCMC samples with shape (num_chains, num_samples, ...)
        
    Returns:
        Within-chain variance for each parameter

    NOTE(review): not found in numpyro.diagnostics' public API; equivalent to
    ``jnp.mean(jnp.var(x, axis=1, ddof=1), axis=0)``. Confirm before relying on it.
    """

def between_chain_variance(x: NDArray) -> NDArray:
    """
    Compute the between-chain variance B used by R-hat.
    
    B = n / (m - 1) * sum_j (mean_j - grand_mean)**2 for m chains of n draws,
    i.e. n times the sample variance of the chain means.
    
    Args:
        x: MCMC samples with shape (num_chains, num_samples, ...)
        
    Returns:
        Between-chain variance for each parameter

    NOTE(review): not found in numpyro.diagnostics' public API; equivalent to
    ``x.shape[1] * jnp.var(jnp.mean(x, axis=1), axis=0, ddof=1)``. Confirm
    before relying on it.
    """

def integrated_autocorr_time(x: NDArray, c: float = 5.0,
                            tol: float = 50.0, quiet: bool = False) -> float:
    """
    Compute the integrated autocorrelation time tau.
    
    Estimates tau = 1 + 2 * sum_k rho_k by summing the autocorrelation
    function up to an automatically chosen window, beyond which the estimates
    become too noisy to trust.
    
    Args:
        x: Time series of ONE chain, time along axis 0
        c: Window-size multiplier for the automatic windowing (the window is
          the smallest M with M >= c * tau(M))
        tol: Minimum chain length, in multiples of tau, for the estimate to be
          trusted (hence the 50 * tau rule in the usage below)
        quiet: If True, warn instead of failing when the chain is shorter
          than tol * tau
        
    Returns:
        Integrated autocorrelation time
        
    Usage:
        # Estimate correlation time on a single chain; concatenating chains
        # would add spurious jumps at the chain boundaries.
        # samples = mcmc.get_samples(group_by_chain=True)
        chain_0 = samples['theta'][0]
        tau = numpyro.diagnostics.integrated_autocorr_time(chain_0)
        print(f"Autocorrelation time: {tau:.2f}")
        
        # Rule of thumb: need at least 50*tau samples for reliable estimates
        min_samples = 50 * tau
        actual_samples = len(chain_0)
        print(f"Recommended samples: {min_samples:.0f}, Actual: {actual_samples}")

    NOTE(review): the signature (c=5, tol=50, quiet=False) mirrors emcee's
    ``emcee.autocorr.integrated_time``. I could not find this function in
    numpyro.diagnostics; there, tau can be obtained as
    total_draws / effective_sample_size(x). Confirm before relying on it.
    """

def compute_chain_statistics(x: NDArray) -> dict:
    """
    Compute summary statistics for each chain separately.
    
    Args:
        x: MCMC samples with shape (num_chains, num_samples, ...)
        
    Returns:
        Dictionary mapping chain id to a dictionary of statistics with keys
        'mean', 'var' and 'ess' (as used in the example below)
        
    Usage:
        # Analyze individual chain performance
        chain_stats = numpyro.diagnostics.compute_chain_statistics(samples['theta'])
        
        for chain_id, stats in chain_stats.items():
            print(f"Chain {chain_id}:")
            print(f"  Mean: {stats['mean']:.3f}")
            print(f"  Variance: {stats['var']:.3f}")
            print(f"  ESS: {stats['ess']:.1f}")

    NOTE(review): the keys above ('var', 'ess') differ from the field names
    of ChainStatistics ('variance', 'effective_sample_size'). This function
    is also not in numpyro.diagnostics' public API. Confirm before relying
    on it.
    """

Usage Examples

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
import numpyro.diagnostics as diagnostics
import jax.numpy as jnp
from jax import random

# Comprehensive diagnostic workflow
def diagnostic_workflow_example():
    """Fit a small linear regression with NUTS and print a diagnostic report.

    The report covers summary statistics, R-hat, effective sample size, HPD
    intervals and autocorrelation for every latent site. It uses only
    functions exported by ``numpyro.diagnostics``.

    Returns:
        Dict mapping each site name to its posterior samples grouped by
        chain, i.e. arrays of shape (num_chains, num_samples, ...).
    """
    # Define a simple linear regression model
    def model(x, y=None):
        alpha = numpyro.sample("alpha", dist.Normal(0, 1))
        beta = numpyro.sample("beta", dist.Normal(0, 1))
        sigma = numpyro.sample("sigma", dist.Exponential(1))

        mu = alpha + beta * x
        with numpyro.plate("data", len(x)):
            numpyro.sample("y", dist.Normal(mu, sigma), obs=y)

    # Generate synthetic data
    key = random.PRNGKey(0)
    n_data = 100
    x = jnp.linspace(0, 1, n_data)
    true_alpha, true_beta, true_sigma = 1.0, 2.0, 0.1
    y = true_alpha + true_beta * x + true_sigma * random.normal(key, (n_data,))

    # Run MCMC with multiple chains
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4)
    mcmc.run(random.PRNGKey(1), x, y)

    # Get samples grouped by chain for diagnostics: (num_chains, num_samples, ...)
    samples = mcmc.get_samples(group_by_chain=True)
    # Every site shares the same chain/draw axes, so any site gives the count.
    num_chains, num_draws = next(iter(samples.values())).shape[:2]
    total_samples = num_chains * num_draws

    print("=== MCMC Diagnostic Report ===")

    # 1. Print comprehensive summary
    print("\n1. Summary Statistics:")
    diagnostics.print_summary(samples, prob=0.95)

    # 2. Check convergence with R-hat
    print("\n2. Convergence Diagnostics:")
    for param_name, param_samples in samples.items():
        rhat = diagnostics.gelman_rubin(param_samples)
        split_rhat = diagnostics.split_gelman_rubin(param_samples)

        print(f"{param_name}:")
        print(f"  R-hat: {rhat:.4f}")
        print(f"  Split R-hat: {split_rhat:.4f}")
        print(f"  Converged (R-hat < 1.1): {rhat < 1.1}")

    # 3. Assess sampling efficiency
    print("\n3. Sampling Efficiency:")
    for param_name, param_samples in samples.items():
        ess = diagnostics.effective_sample_size(param_samples)
        efficiency = ess / total_samples

        print(f"{param_name}:")
        print(f"  ESS: {ess:.1f}")
        print(f"  Efficiency: {efficiency:.2%}")
        print(f"  Good ESS (>400): {ess > 400}")

    # 4. Posterior intervals. hpdi reduces over axis 0, so the chain and draw
    # axes are merged first; any parameter dimensions are kept.
    print("\n4. Posterior Intervals:")
    flat_samples = {k: v.reshape((-1,) + v.shape[2:]) for k, v in samples.items()}

    for param_name, param_samples in flat_samples.items():
        hpdi_90 = diagnostics.hpdi(param_samples, prob=0.9)
        hpdi_95 = diagnostics.hpdi(param_samples, prob=0.95)

        print(f"{param_name}:")
        print(f"  90% HPDI: [{hpdi_90[0]:.3f}, {hpdi_90[1]:.3f}]")
        print(f"  95% HPDI: [{hpdi_95[0]:.3f}, {hpdi_95[1]:.3f}]")

    # 5. Autocorrelation analysis
    print("\n5. Autocorrelation Analysis:")
    for param_name, param_samples in samples.items():
        # Compute the autocorrelation chain by chain: concatenating chains
        # would introduce spurious jumps at the chain boundaries.
        lag1 = jnp.mean(
            jnp.stack([diagnostics.autocorrelation(chain)[1] for chain in param_samples])
        )
        # numpyro.diagnostics has no integrated-autocorrelation-time helper;
        # tau follows from the effective sample size via tau = N / ESS.
        tau = total_samples / diagnostics.effective_sample_size(param_samples)

        print(f"{param_name}:")
        print(f"  Mean lag-1 autocorrelation: {lag1:.3f}")
        print(f"  Autocorr time: {tau:.2f}")
        print(f"  Recommended min samples: {50 * tau:.0f}")
        print(f"  Actual samples: {total_samples}")

    return samples

# Chain-specific diagnostics
def chain_analysis_example(samples=None):
    """Compare the individual chains of the "alpha" parameter.

    Prints per-chain mean, standard deviation and the first autocorrelation
    lags. It then prints the Gelman-Rubin within-chain (W) and between-chain
    (B) variances.

    Args:
        samples: Posterior samples grouped by chain (site name -> array of
            shape (num_chains, num_samples)), e.g. the dict returned by
            ``diagnostic_workflow_example()``. When omitted, that example is
            run to produce them.
    """
    if samples is None:
        samples = diagnostic_workflow_example()

    # Analyze individual chains
    param_samples = samples["alpha"]  # Shape: (num_chains, num_samples)
    num_chains, num_draws = param_samples.shape[:2]

    print("=== Individual Chain Analysis ===")

    # Chains live on axis 0, so iterating the array yields one chain at a
    # time (numpyro.diagnostics has no split_by_chain helper).
    for i, chain in enumerate(param_samples):
        mean_i = jnp.mean(chain)
        std_i = jnp.std(chain)
        autocorr_i = diagnostics.autocorrelation(chain)

        print(f"\nChain {i}:")
        print(f"  Mean: {mean_i:.4f}")
        print(f"  Std: {std_i:.4f}")
        print(f"  First 5 autocorr values: {autocorr_i[:5]}")

    # The between-chain variance needs at least two chain means.
    if num_chains < 2:
        print("\nVariance Analysis: needs at least 2 chains")
        return

    # Compare within vs between chain variance using the Gelman-Rubin
    # definitions (numpyro.diagnostics does not export these helpers):
    #   W = mean of the per-chain sample variances
    #   B = num_draws * sample variance of the chain means
    within_var = jnp.mean(jnp.var(param_samples, axis=1, ddof=1))
    between_var = num_draws * jnp.var(jnp.mean(param_samples, axis=1), ddof=1)

    print("\nVariance Analysis:")
    print(f"  Within-chain variance: {within_var:.6f}")
    print(f"  Between-chain variance: {between_var:.6f}")
    # B/W is close to 1 for converged, weakly autocorrelated chains.
    print(f"  Ratio (should be ~1): {between_var / within_var:.4f}")

# Diagnostic-driven sampling strategy
def adaptive_sampling_example():
    """Rerun MCMC with more draws until the R-hat and ESS targets are met.

    Every latent site is checked, using the worst (largest) R-hat and the
    smallest ESS over all of them. The loop gives up after ``max_iterations``
    attempts and reports that the targets were not met.

    Returns:
        The samples (grouped by chain) from the last MCMC run.
    """

    def model():
        # Deliberately create a challenging posterior: y depends
        # non-linearly on x with a small scale.
        x = numpyro.sample("x", dist.Normal(0, 1))
        numpyro.sample("y", dist.Normal(x**2, 0.1))

    num_samples = 500  # draws per chain; also used as the warmup length
    target_ess = 400
    rhat_threshold = 1.1
    max_iterations = 5
    samples = None

    for iteration in range(max_iterations):
        print(f"\n--- Iteration {iteration + 1} ---")

        # Run MCMC
        mcmc = MCMC(NUTS(model),
                    num_warmup=num_samples,
                    num_samples=num_samples,
                    num_chains=4)
        mcmc.run(random.PRNGKey(iteration))

        samples = mcmc.get_samples(group_by_chain=True)

        # Check diagnostics on the worst site/element, not just "x".
        rhat = max(float(jnp.max(diagnostics.gelman_rubin(v))) for v in samples.values())
        ess = min(float(jnp.min(diagnostics.effective_sample_size(v))) for v in samples.values())

        print(f"Current samples per chain: {num_samples}")
        print(f"R-hat: {rhat:.4f}")
        print(f"ESS: {ess:.1f}")

        # Check if we meet convergence criteria
        converged = rhat < rhat_threshold
        sufficient_ess = ess > target_ess

        if converged and sufficient_ess:
            print("✓ Convergence achieved!")
            break
        if not converged:
            print(f"✗ Poor convergence (R-hat = {rhat:.4f})")
            num_samples = int(num_samples * 1.5)  # Increase samples
        else:
            # Converged but ESS too low.
            print(f"✗ Insufficient ESS ({ess:.1f} < {target_ess})")
            num_samples = int(num_samples * 1.2)  # Modest increase
    else:
        # The loop ran out of attempts without hitting `break`.
        print(f"✗ Targets not met after {max_iterations} iterations")

    return samples

Types

# Callable is required by the function type aliases defined further below.
from typing import Any, Callable, Dict, List, Optional, Union
from jax import Array
import jax.numpy as jnp

# Array type used throughout these signatures (in recent JAX releases
# jnp.ndarray is an alias of jax.Array).
NDArray = jnp.ndarray
# Anything accepted where an array is expected, including Python scalars.
ArrayLike = Union[Array, NDArray, float, int]
# Mapping from sample-site name to its draws.
Samples = Dict[str, NDArray]

class DiagnosticResult:
    """Base class for diagnostic results.

    NOTE(review): none of the result classes in this listing inherit from it,
    and none of the documented functions return it; it is only referenced by
    the DiagnosticFunction alias.
    """
    pass

class SummaryStats:
    """Summary statistics for one parameter, mirroring a print_summary row.

    Fields are declared as annotations only (no defaults and no generated
    constructor).
    """
    mean: float  # posterior mean
    std: float  # posterior standard deviation
    median: float  # posterior median
    # Median absolute deviation. NOTE(review): print_summary's documented
    # output has no MAD column; confirm where this is populated.
    mad: float
    hpdi_lower: float  # lower bound of the HPDI at the requested prob
    hpdi_upper: float  # upper bound of the HPDI at the requested prob
    n_eff: float  # Effective sample size
    r_hat: float  # R-hat statistic

class ConvergenceDiagnostic:
    """Convergence diagnostic results for one parameter (annotations only)."""
    r_hat: NDArray  # as returned by gelman_rubin
    split_r_hat: NDArray  # as returned by split_gelman_rubin
    # Overall verdict; presumably all(r_hat < threshold) -- confirm the rule.
    converged: bool
    potential_scale_reduction: NDArray  # as returned by potential_scale_reduction

class EfficiencyDiagnostic:
    """Sampling efficiency diagnostic results for one parameter (annotations only)."""
    effective_sample_size: NDArray  # as returned by effective_sample_size
    autocorrelation_time: NDArray  # integrated autocorrelation time tau
    # Presumably effective_sample_size / total draws, as in the examples.
    efficiency_ratio: NDArray
    
class AutocorrelationResult:
    """Autocorrelation analysis results for one chain (annotations only)."""
    autocorr: NDArray  # as returned by autocorrelation, indexed by lag
    autocov: NDArray  # as returned by autocovariance, indexed by lag
    integrated_time: float  # integrated autocorrelation time tau
    
class ChainStatistics:
    """Statistics for one MCMC chain (annotations only).

    NOTE(review): compute_chain_statistics' documented result uses the keys
    'var' and 'ess' instead of 'variance' and 'effective_sample_size';
    align the two names.
    """
    chain_id: int  # index of the chain along axis 0 of the grouped samples
    mean: NDArray  # per-parameter mean of this chain
    variance: NDArray  # per-parameter variance of this chain
    effective_sample_size: NDArray  # per-parameter ESS of this chain
    autocorrelation_time: float  # integrated autocorrelation time tau

# Function type signatures (requires Callable from typing).
# Diagnostics such as gelman_rubin / effective_sample_size: samples -> per-parameter values.
ConvergenceFunction = Callable[[NDArray], NDArray]
# NOTE(review): summary() as documented takes a dict of samples
# (site name -> array) rather than a single NDArray; confirm the intended
# input type.
SummaryFunction = Callable[[NDArray], Dict[str, Any]]
# NOTE(review): none of the documented functions return DiagnosticResult.
DiagnosticFunction = Callable[[NDArray], DiagnosticResult]

Install with Tessl CLI

npx tessl i tessl/pypi-numpyro

docs

diagnostics.md

distributions.md

handlers.md

index.md

inference.md

optimization.md

primitives.md

utilities.md

tile.json