CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-pymc

Probabilistic Programming in Python: Bayesian Modeling and Probabilistic Machine Learning with PyTensor

Pending
Overview
Eval results
Files

docs/stats.md

PyMC Statistics and Diagnostics

PyMC provides comprehensive statistical functions and convergence diagnostics for Bayesian analysis, primarily through integration with ArviZ. The library offers tools for model validation, convergence assessment, and posterior analysis.

Convergence Diagnostics

PyMC exposes key diagnostic functions from ArviZ for assessing MCMC convergence:

R-hat Statistic

# API signature stub — PyMC re-exports this diagnostic from ArviZ (no body here).
def rhat(data, var_names=None, method='rank', dask_kwargs=None):
    """
    Compute R-hat convergence diagnostic (Gelman-Rubin statistic).

    Values near 1.0 indicate the chains have mixed; a common rule of
    thumb is to investigate anything above 1.01.

    Parameters:
    - data: InferenceData object or trace
    - var_names (list, optional): Variables to analyze
    - method (str): Method for computation ('rank', 'split', 'folded')
    - dask_kwargs (dict, optional): Dask computation options

    Returns:
    - rhat_values: R-hat statistics for each variable
    """

import pymc as pm

# Compute R-hat for all variables
with pm.Model() as model:
    # Model definition and sampling...
    trace = pm.sample()  # returns an ArviZ InferenceData object by default

# One diagnostic entry per model variable.
rhat_stats = pm.rhat(trace)
print("R-hat diagnostics:")
# NOTE(review): the :.4f format assumes each entry is scalar-valued;
# vector-valued variables would produce arrays here — confirm with the model.
for var, rhat_val in rhat_stats.items():
    print(f"  {var}: {rhat_val:.4f}")

# R-hat for specific variables only
rhat_subset = pm.rhat(trace, var_names=['alpha', 'beta'])

# Check convergence (R-hat should be < 1.01)
converged = all(rhat_val < 1.01 for rhat_val in rhat_stats.values())

Effective Sample Size

# API signature stub — PyMC re-exports this diagnostic from ArviZ (no body here).
def effective_sample_size(data, var_names=None, method='bulk', 
                         relative=False, dask_kwargs=None):
    """
    Compute effective sample size (ESS).

    ESS estimates how many independent draws the (autocorrelated) MCMC
    samples are worth.

    Parameters:
    - data: InferenceData object or trace
    - var_names (list, optional): Variables to analyze
    - method (str): ESS method ('bulk', 'tail', 'quantile', 'mean', 'sd')
    - relative (bool): Return relative ESS (ESS/N)
    - dask_kwargs (dict, optional): Dask computation options

    Returns:
    - ess_values: Effective sample size for each variable
    """

# Bulk ESS (measures efficiency in central posterior)
bulk_ess = pm.ess(trace, method='bulk')

# Tail ESS (measures efficiency in posterior tails)  
tail_ess = pm.ess(trace, method='tail')

# Relative ESS (as fraction of total samples)
rel_ess = pm.ess(trace, relative=True)

print("Effective Sample Size (bulk):")
# NOTE(review): as above, the :.0f format assumes scalar variables.
for var, ess_val in bulk_ess.items():
    print(f"  {var}: {ess_val:.0f}")

# Check adequacy (ESS should be > 400 for reliable inference)
adequate_ess = all(ess_val > 400 for ess_val in bulk_ess.values())

Monte Carlo Standard Error

# API signature stub — PyMC re-exports this diagnostic from ArviZ (no body here).
def mcse(data, var_names=None, method='mean', dask_kwargs=None):
    """
    Compute Monte Carlo standard error.

    MCSE quantifies the sampling-induced uncertainty in a posterior
    summary statistic.

    Parameters:
    - data: InferenceData object or trace
    - var_names (list, optional): Variables to analyze  
    - method (str): Statistic to compute MCSE for ('mean', 'sd', 'quantile')
    - dask_kwargs (dict, optional): Dask computation options

    Returns:
    - mcse_values: Monte Carlo standard errors
    """

# MCSE for posterior means
mcse_mean = pm.mcse(trace, method='mean')

# MCSE for posterior standard deviations
mcse_sd = pm.mcse(trace, method='sd')

# MCSE for quantiles
# NOTE(review): ArviZ's quantile method also takes a `prob` argument
# selecting which quantile — confirm the default used here.
mcse_quantile = pm.mcse(trace, method='quantile')

print("Monte Carlo Standard Error (mean):")
for var, mcse_val in mcse_mean.items():
    print(f"  {var}: {mcse_val:.6f}")

Model Comparison

Leave-One-Out Cross-Validation

# API signature stub — PyMC re-exports this from ArviZ (no body here).
def loo(data, var_name=None, reff=None, scale=None, pointwise=False, 
        dask_kwargs=None):
    """
    Compute leave-one-out (LOO) cross-validation using Pareto smoothed importance sampling.

    Parameters:
    - data: InferenceData object with log_likelihood group
    - var_name (str, optional): Variable name for likelihood
    - reff (array, optional): Relative effective sample size
    - scale (str): Scale for IC ('log', 'negative_log', 'deviance')
    - pointwise (bool): Return pointwise LOO values
    - dask_kwargs (dict, optional): Dask computation options

    Returns:
    - loo_result: LOO-CV results with ELPD, SE, and diagnostics
    """

# Compute log-likelihood for LOO
# (pm.loo requires a log_likelihood group in the InferenceData).
with pm.Model() as model:
    # Model definition...
    trace = pm.sample()
    
    # Compute log-likelihood
    log_likelihood = pm.compute_log_likelihood(trace, model=model)

# LOO cross-validation
loo_result = pm.loo(trace)
print(f"LOO ELPD: {loo_result.elpd_loo:.2f} ± {loo_result.se:.2f}")
# NOTE(review): the `.loo` attribute is not present in all ArviZ
# versions (newer ELPDData exposes `elpd_loo` only) — confirm.
print(f"LOO IC: {loo_result.loo:.2f}")
print(f"p_loo (effective parameters): {loo_result.p_loo:.2f}")

# Check Pareto k diagnostic
# (k > 0.7 means the importance-sampling estimate is unreliable for
# that observation).
high_k = (loo_result.pareto_k > 0.7).sum()
if high_k > 0:
    print(f"Warning: {high_k} observations have high Pareto k values")

Watanabe-Akaike Information Criterion

# API signature stub — PyMC re-exports this from ArviZ (no body here).
def waic(data, var_name=None, scale=None, pointwise=False, dask_kwargs=None):
    """
    Compute Watanabe-Akaike Information Criterion (WAIC).

    Parameters:
    - data: InferenceData object with log_likelihood group
    - var_name (str, optional): Variable name for likelihood
    - scale (str): Scale for IC ('log', 'negative_log', 'deviance')  
    - pointwise (bool): Return pointwise WAIC values
    - dask_kwargs (dict, optional): Dask computation options

    Returns:
    - waic_result: WAIC results with ELPD, SE, and effective parameters
    """

# WAIC computation
waic_result = pm.waic(trace)
print(f"WAIC ELPD: {waic_result.elpd_waic:.2f} ± {waic_result.se:.2f}")
# NOTE(review): the `.waic` and `.waic_i` attribute names depend on the
# ArviZ version — confirm against the installed release.
print(f"WAIC: {waic_result.waic:.2f}")
print(f"p_waic (effective parameters): {waic_result.p_waic:.2f}")

# Pointwise WAIC for outlier detection: flag observations more than two
# standard deviations above the mean pointwise value.
waic_pointwise = pm.waic(trace, pointwise=True)
outlier_threshold = waic_pointwise.waic_i.mean() + 2 * waic_pointwise.waic_i.std()
outliers = waic_pointwise.waic_i > outlier_threshold
print(f"Potential outliers: {outliers.sum()} observations")

Model Comparison Framework

# API signature stub — PyMC re-exports this from ArviZ (no body here).
def compare(compare_dict, ic=None, method='stacking', b_samples=1000,
            alpha=0.05, seed=None, scale=None):
    """
    Compare models using information criteria.

    Parameters:
    - compare_dict (dict): Dictionary of {model_name: InferenceData}
    - ic (str): Information criterion ('loo', 'waic')
    - method (str): Comparison method ('stacking', 'BB-pseudo-BMA', 'pseudo-BMA')
    - b_samples (int): Bootstrap samples for SE estimation
    - alpha (float): Significance level for intervals
    - seed (int): Random seed
    - scale (str): Scale for IC reporting

    Returns:
    - comparison_df: DataFrame with model comparison results, sorted best-first
    """

# Compare multiple models
# (linear_trace / quadratic_trace / cubic_trace are InferenceData objects
# produced by sampling the respective models elsewhere.)
models = {
    'linear': linear_trace,
    'quadratic': quadratic_trace, 
    'cubic': cubic_trace
}

comparison = pm.compare(models, ic='loo')
print("Model Comparison (LOO):")
print(comparison)

# Model weights from stacking
print("\nModel weights:")
for model, weight in zip(comparison.index, comparison.weight):
    print(f"  {model}: {weight:.3f}")

# Automatically select best model
# (pm.compare returns the DataFrame sorted by rank, best model first).
best_model = comparison.index[0]  # First row is best
print(f"\nBest model: {best_model}")

Log-Likelihood and Prior Computation

Log-Likelihood Calculation

# API signature stub — documents PyMC's public helper (no body here).
def compute_log_likelihood(idata=None, *, model=None, var_names=None, 
                          extend_inferencedata=True, progressbar=True):
    """
    Compute pointwise log-likelihood values.

    Parameters:
    - idata: InferenceData object with posterior samples
    - model: PyMC model (default: current context)
    - var_names (list, optional): Observed variables to compute likelihood for
    - extend_inferencedata (bool): Add results to InferenceData in place
    - progressbar (bool): Show progress bar

    Returns:
    - log_likelihood: Log-likelihood values for each observation and posterior sample
    """

with pm.Model() as model:
    # Model with likelihood...
    trace = pm.sample()

# Compute log-likelihood
# (with extend_inferencedata=True, the default, results are also written
# into `trace` itself as a log_likelihood group).
log_lik = pm.compute_log_likelihood(trace, model=model)

# Access log-likelihood values
ll_values = trace.log_likelihood  # Added to InferenceData
print(f"Log-likelihood shape: {ll_values['y_obs'].shape}")  # (chains, draws, observations)

# Total log-likelihood per sample
# NOTE(review): 'y_obs_dim_0' is xarray's auto-generated dimension name;
# models that declare coords will use those names instead — confirm.
total_ll = ll_values['y_obs'].sum(dim='y_obs_dim_0')
print(f"Total log-likelihood range: {total_ll.min():.2f} to {total_ll.max():.2f}")

Log-Prior Calculation

# API signature stub — documents PyMC's public helper (no body here).
def compute_log_prior(idata=None, *, model=None, var_names=None,
                     extend_inferencedata=True, progressbar=True):
    """
    Compute log-prior density values.

    Parameters:
    - idata: InferenceData object with posterior samples
    - model: PyMC model (default: current context)
    - var_names (list, optional): Variables to compute log-prior for
    - extend_inferencedata (bool): Add results to InferenceData in place
    - progressbar (bool): Show progress bar

    Returns:
    - log_prior: Log-prior values for each variable and posterior sample
    """

# Compute log-prior
log_prior = pm.compute_log_prior(trace, model=model)

# Access log-prior values
prior_values = trace.log_prior
print("Log-prior components:")
for var_name in prior_values.data_vars:
    values = prior_values[var_name]
    print(f"  {var_name}: mean = {values.mean():.3f}, std = {values.std():.3f}")

# Grand-total log-prior: note that .sum() with no dim argument reduces
# over chains and draws as well, so this is a single scalar, not a
# per-sample total.
total_prior = sum(prior_values[var].sum() for var in prior_values.data_vars)

Posterior Analysis Utilities

Summary Statistics

# Summary statistics through ArviZ integration
summary_stats = pm.summary(trace, var_names=['alpha', 'beta'])
print("Posterior Summary:")
print(summary_stats)

# Custom summary: append extra statistics (median and median absolute
# deviation) to the default columns via stat_funcs + extend=True.
# Requires numpy imported as np.
custom_summary = pm.summary(trace, 
                           stat_funcs={'median': np.median,
                                     'mad': lambda x: np.median(np.abs(x - np.median(x)))},
                           extend=True)

# Round summary for reporting
rounded_summary = pm.summary(trace, round_to=3)

Posterior Predictive Checks

# Posterior predictive sampling for model checking
with pm.Model() as model:
    # Model definition...
    trace = pm.sample()
    
    # Posterior predictive samples
    # NOTE(review): with predictions=True, recent PyMC stores the draws
    # in a `predictions` group, not `posterior_predictive` — confirm
    # which group the installed version populates.
    post_pred = pm.sample_posterior_predictive(trace, predictions=True)

# Compare observed vs predicted
observed = post_pred.observed_data['y_obs']
predicted = post_pred.posterior_predictive['y_obs']

# Test statistic for the posterior predictive check: the one-sample
# t statistic of y against zero. The original version subtracted
# observed.mean() inside the statistic, which forced t_obs == 0 and
# made the Bayesian p-value below degenerate (always 1). Computing each
# sample's own t statistic keeps the discrepancy well-defined for both
# the observed and the replicated data.
def t_statistic(y):
    """Return mean(y) / (std(y) / sqrt(n)), the one-sample t statistic of y."""
    return y.mean() / (y.std() / np.sqrt(len(y)))

# Compute test statistic for observed and predicted
# NOTE(review): the reshape assumes the trailing axis of `predicted`
# holds the observations so each row is one replicated dataset — confirm
# against the posterior_predictive dimensions.
t_obs = t_statistic(observed.values)
t_pred = [t_statistic(pred_sample) for pred_sample in predicted.values.reshape(-1, len(observed))]

# Bayesian p-value: fraction of replicated statistics at least as extreme
# as the observed one; values near 0 or 1 indicate model misfit.
p_value = np.mean(np.abs(t_pred) >= np.abs(t_obs))
print(f"Bayesian p-value for mean difference: {p_value:.3f}")

Advanced Diagnostics

Energy Diagnostics

# Access sampler statistics for energy diagnostics
# NOTE(review): get_sampler_stats is a MultiTrace-style accessor; for an
# InferenceData return value the stats live in `trace.sample_stats`
# instead — confirm which object this page assumes.
sampler_stats = trace.get_sampler_stats()

# Energy statistics (Hamiltonian energy recorded by NUTS at each step)
energy = sampler_stats['energy']
energy_diff = np.diff(energy, axis=1)  # Energy differences between steps

# Check for energy problems
# NOTE(review): the 0.2 threshold is this page's heuristic, not a PyMC
# constant — verify before relying on it.
mean_energy_diff = energy_diff.mean()
if abs(mean_energy_diff) > 0.2:
    print(f"Warning: Large energy differences (mean = {mean_energy_diff:.3f})")

# Divergences: any divergent transition signals the sampler could not
# follow the posterior geometry in that region.
diverging = sampler_stats['diverging']
n_diverging = diverging.sum()
if n_diverging > 0:
    print(f"Warning: {n_diverging} divergent transitions detected")
    
# Tree depth
# NOTE(review): stat key names vary across PyMC versions
# (e.g. 'tree_depth', 'reached_max_treedepth') — confirm.
treedepth = sampler_stats['treedepth'] 
max_treedepth = sampler_stats['max_treedepth']
saturated_trees = (treedepth >= max_treedepth).sum()
if saturated_trees > 0:
    print(f"Warning: {saturated_trees} saturated trees (increase max_treedepth)")

Custom Diagnostics

def compute_split_rhat(trace, var_name):
    """Compute split R-hat manually for understanding.

    Splits each chain in half, treats the halves as independent chains,
    and forms the classic Gelman-Rubin ratio of the pooled variance
    estimate to the within-chain variance.

    Parameters:
    - trace: InferenceData-like object with a ``posterior`` group
    - var_name (str): name of the (scalar) posterior variable

    Returns:
    - rhat: split R-hat; values near 1.0 indicate convergence
    """
    # Get samples for variable; shape: (chains, draws, ...)
    samples = trace.posterior[var_name].values
    n_draws = samples.shape[1]
    half = n_draws // 2

    # Split each chain in half. For an odd draw count, drop the middle
    # draw so both halves have equal length — the original slicing
    # (samples[:, half:]) produced unequal halves for odd n_draws and
    # made np.concatenate raise. Even draw counts are unchanged.
    first_half = samples[:, :half]
    second_half = samples[:, n_draws - half:]

    # Stack the halves as 2 * n_chains independent chains of `half` draws.
    split_samples = np.concatenate([first_half, second_half], axis=0)

    # Between-chain variance (B): spread of the per-chain means.
    chain_means = split_samples.mean(axis=1)
    B = half * np.var(chain_means, ddof=1)

    # Within-chain variance (W): average per-chain sample variance.
    chain_vars = split_samples.var(axis=1, ddof=1)
    W = chain_vars.mean()

    # Pooled marginal posterior variance estimate.
    var_hat = (half - 1) / half * W + B / half

    # R-hat
    rhat = np.sqrt(var_hat / W)

    return rhat

# Usage: values near 1.0 indicate the chains agree (converged).
manual_rhat = compute_split_rhat(trace, 'alpha')
print(f"Manual R-hat calculation: {manual_rhat:.4f}")

Rank Normalization Diagnostics

# API signature stub (no body here).
# NOTE(review): `pm.rank_normalized_split_rhat` is not a public PyMC name
# in all versions; rank-normalized R-hat is usually obtained via
# arviz.rhat(..., method='rank') — confirm against the installed release.
def rank_normalized_split_rhat(data, var_names=None):
    """
    Compute rank-normalized split R-hat (more robust version).

    Parameters:
    - data: InferenceData object
    - var_names (list, optional): Variables to analyze

    Returns:
    - rhat_rank: Rank-normalized R-hat values
    """

# More robust R-hat using rank normalization
# NOTE(review): see the caveat above — this attribute may not exist on
# the installed PyMC; pm.rhat(trace) already defaults to method='rank'.
rhat_rank = pm.rank_normalized_split_rhat(trace)
print("Rank-normalized R-hat:")
for var, rhat_val in rhat_rank.items():
    print(f"  {var}: {rhat_val:.4f}")
    if rhat_val > 1.01:
        print(f"    Warning: {var} may not have converged")

Diagnostic Workflows

Comprehensive Convergence Check

def full_convergence_check(trace, model_name="Model"):
    """Run the standard convergence checks and report a pass/fail verdict.

    Prints max R-hat, min bulk/tail ESS, and the divergence count, then
    returns a truthy value only when every criterion is satisfied.
    """

    print(f"=== Convergence Diagnostics for {model_name} ===")

    # Worst R-hat over all variables (should be < 1.01).
    worst_rhat = max(pm.rhat(trace).values())
    print(f"Max R-hat: {worst_rhat:.4f}")

    # Smallest effective sample sizes in the bulk and the tails
    # (both should exceed 400 for reliable estimates).
    lowest_bulk_ess = min(pm.ess(trace, method='bulk').values())
    lowest_tail_ess = min(pm.ess(trace, method='tail').values())
    print(f"Min ESS (bulk): {lowest_bulk_ess:.0f}")
    print(f"Min ESS (tail): {lowest_tail_ess:.0f}")

    # Any divergent transition is a red flag for the sampler.
    divergences = trace.get_sampler_stats('diverging').sum()
    print(f"Diverging transitions: {divergences}")

    # All criteria must hold simultaneously.
    passed = (worst_rhat < 1.01 and lowest_bulk_ess > 400 and
              lowest_tail_ess > 400 and divergences == 0)

    print(f"Overall convergence: {'✓ PASS' if passed else '✗ FAIL'}")

    return passed

# Usage: truthy only if every diagnostic in the check passes.
convergence_ok = full_convergence_check(trace, "Regression Model")

Model Quality Assessment

def assess_model_quality(trace, observed_data, model):
    """Report information criteria, PSIS-LOO reliability, and residuals.

    Prints LOO/WAIC ELPD estimates, warns about unreliable Pareto k
    values, runs a posterior predictive residual check, and returns the
    (loo, waic, residuals) triple.
    """

    print("=== Model Quality Assessment ===")

    # Information criteria (LOO-CV and WAIC).
    ic_loo = pm.loo(trace)
    ic_waic = pm.waic(trace)

    print(f"LOO ELPD: {ic_loo.elpd_loo:.2f} ± {ic_loo.se:.2f}")
    print(f"WAIC ELPD: {ic_waic.elpd_waic:.2f} ± {ic_waic.se:.2f}")

    # Observations whose importance-sampling estimate is unreliable (k > 0.7).
    n_bad_k = (ic_loo.pareto_k > 0.7).sum()
    if n_bad_k > 0:
        print(f"Warning: {n_bad_k} observations have unreliable LOO estimates")

    # Replicated data for a simple posterior predictive residual check.
    replicated = pm.sample_posterior_predictive(trace, model=model)

    # Residuals of the observed data against the posterior predictive mean.
    predictive_mean = replicated.posterior_predictive['y_obs'].mean(dim=['chain', 'draw'])
    residuals = observed_data - predictive_mean

    print(f"Mean absolute residual: {np.abs(residuals).mean():.3f}")
    print(f"Residual std: {residuals.std():.3f}")

    return ic_loo, ic_waic, residuals

# Usage. Note these local names do not clash with pm.loo / pm.waic,
# which remain accessible through the pm namespace.
loo, waic, residuals = assess_model_quality(trace, y_data, model)

PyMC's statistics and diagnostics framework, built on ArviZ integration, provides essential tools for validating Bayesian models and ensuring reliable inference results.

Install with Tessl CLI

npx tessl i tessl/pypi-pymc

docs

data.md

distributions.md

gp.md

index.md

math.md

model.md

ode.md

sampling.md

stats.md

variational.md

tile.json