tessl/pypi-pymc3

Probabilistic Programming in Python: Bayesian Modeling and Probabilistic Machine Learning with Theano


Statistics and Plotting (ArviZ Integration)

PyMC3 integrates tightly with ArviZ for comprehensive Bayesian analysis, model diagnostics, and publication-quality visualizations. The stats and plots modules delegate to ArviZ while providing PyMC3-specific functionality and convenient aliases for common workflows.

Capabilities

Convergence Diagnostics

Functions for assessing MCMC convergence and sample quality through pymc3.stats.*.

def rhat(trace, var_names=None, method='rank'):
    """
    Compute R-hat convergence diagnostic.
    
    Measures between-chain and within-chain variance to assess
    convergence across multiple MCMC chains. Values close to 1.0
    indicate good convergence.
    
    Parameters:
    - trace: InferenceData or MultiTrace, posterior samples
    - var_names: list, variables to analyze (all if None)
    - method: str, computation method ('rank', 'split', 'folded')
    
    Returns:
    - dict or array: R-hat values by variable
    
    Interpretation:
    - R_hat < 1.01: Excellent convergence
    - R_hat < 1.1: Good convergence  
    - R_hat > 1.1: Poor convergence, need more samples
    """
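For intuition, the split R-hat computation described above can be sketched in plain NumPy. This is an illustrative simplification (ArviZ's version additionally rank-normalizes the draws), and `split_rhat` is a hypothetical helper, not part of PyMC3:

```python
import numpy as np

def split_rhat(chains):
    """Simplified split-R-hat for an array of shape (n_chains, n_draws):
    each chain is halved, then the classic between-/within-chain
    variance ratio is computed over the resulting sub-chains."""
    n_chains, n_draws = chains.shape
    half = n_draws // 2
    split = chains[:, :2 * half].reshape(n_chains * 2, half)
    chain_means = split.mean(axis=1)
    chain_vars = split.var(axis=1, ddof=1)
    between = half * chain_means.var(ddof=1)  # between-chain variance B
    within = chain_vars.mean()                # within-chain variance W
    var_hat = (half - 1) / half * within + between / half
    return float(np.sqrt(var_hat / within))

# Well-mixed chains sampling the same distribution give R-hat near 1
rng = np.random.default_rng(0)
good = rng.normal(size=(4, 1000))
print(round(split_rhat(good), 3))  # ~1.0
```

Shifting one chain's mean inflates the between-chain variance and pushes R-hat well above 1.1, which is exactly the failure the diagnostic is designed to catch.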

def ess(trace, var_names=None, method='bulk'):
    """
    Compute effective sample size.
    
    Estimates the number of independent samples, accounting for
    autocorrelation in MCMC chains. Higher values indicate better
    mixing and more efficient sampling.
    
    Parameters:
    - trace: InferenceData or MultiTrace, posterior samples
    - var_names: list, variables to analyze
    - method: str, ESS type ('bulk', 'tail', 'mean', 'sd', 'quantile')
    
    Returns:
    - dict or array: effective sample sizes
    
    Guidelines:
    - ESS > 400: Generally sufficient for posterior inference
    - ESS > 100: Minimum for reasonable estimates
    - ESS < 100: Increase sampling or improve model
    """
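The link between autocorrelation and effective sample size can be shown with a small NumPy sketch. `ess_sketch` is illustrative only; ArviZ's `ess` uses rank-normalization and Geyer's initial-monotone truncation rather than the crude cutoff below:

```python
import numpy as np

def ess_sketch(chain, max_lag=200):
    """Rough ESS for a single chain: n / (1 + 2 * sum of autocorrelations),
    truncating the sum once the autocorrelation has died out."""
    x = chain - chain.mean()
    n = len(x)
    acov = np.correlate(x, x, mode='full')[n - 1:] / n
    rho = acov / acov[0]            # normalized autocorrelation function
    tau = 1.0
    for lag in range(1, min(max_lag, n)):
        if rho[lag] < 0.05:         # crude truncation rule
            break
        tau += 2.0 * rho[lag]
    return n / tau

rng = np.random.default_rng(1)
iid = rng.normal(size=4000)         # independent draws: ESS close to n
ar = np.zeros(4000)                 # AR(1) chain with strong memory
for t in range(1, 4000):
    ar[t] = 0.9 * ar[t - 1] + rng.normal()
print(round(ess_sketch(iid)), round(ess_sketch(ar)))
```

The autocorrelated chain yields a far smaller ESS despite having the same number of draws, which is why raw sample counts overstate the information in a poorly mixing sampler.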

def mcse(trace, var_names=None, method='mean', prob=None):
    """
    Monte Carlo standard error of estimates.
    
    Measures uncertainty in posterior estimates due to finite
    sampling, helping determine if more samples are needed.
    
    Parameters:
    - trace: InferenceData or MultiTrace, posterior samples
    - var_names: list, variables to analyze
    - method: str, estimate type ('mean', 'sd', 'quantile')
    - prob: float, probability for quantile MCSE
    
    Returns:
    - dict: MCSE values by variable
    """

def geweke(trace, var_names=None, first=0.1, last=0.5, intervals=20):
    """
    Geweke convergence diagnostic.
    
    Compares means from early and late portions of chains to
    assess within-chain convergence and stationarity.
    
    Parameters:
    - trace: InferenceData or MultiTrace, posterior samples
    - var_names: list, variables to test
    - first: float, fraction for early portion
    - last: float, fraction for late portion  
    - intervals: int, number of test intervals
    
    Returns:
    - dict: Geweke statistics by variable
    """
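The idea behind the Geweke diagnostic reduces to a z-score comparing early and late chain segments. The real diagnostic estimates the variances from spectral densities; the hypothetical `geweke_z` below uses plain sample variances for illustration:

```python
import numpy as np

def geweke_z(chain, first=0.1, last=0.5):
    """Z-score comparing the mean of the first 10% of a chain with the
    mean of the last 50%. |z| > 2 hints at non-stationarity."""
    n = len(chain)
    a = chain[: int(first * n)]
    b = chain[-int(last * n):]
    se = np.sqrt(a.var(ddof=1) / len(a) + b.var(ddof=1) / len(b))
    return float((a.mean() - b.mean()) / se)

rng = np.random.default_rng(2)
stationary = rng.normal(size=2000)
drifting = stationary + np.linspace(0, 3, 2000)  # mean drifts upward
print(round(geweke_z(stationary), 2), round(geweke_z(drifting), 2))
```

A stationary chain gives |z| comfortably below 2, while the drifting chain's early and late means differ by many standard errors.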

Model Comparison

Information criteria and cross-validation for Bayesian model selection.

def compare(models, ic='waic', method='stacking', b_samples=1000, 
           alpha=1, seed=None, round_to=2):
    """
    Compare multiple models using information criteria.
    
    Ranks models by predictive performance using WAIC, LOO-CV,
    or other criteria, with model weights and standard errors.
    
    Parameters:
    - models: dict, mapping model names to InferenceData objects
    - ic: str, information criterion ('waic', 'loo')
    - method: str, weighting method ('stacking', 'BB-pseudo-BMA', 'pseudo-BMA')
    - b_samples: int, samples for Bootstrap weighting
    - alpha: float, concentration parameter for pseudo-BMA
    - seed: int, random seed for reproducibility
    - round_to: int, decimal places for results
    
    Returns:
    - DataFrame: model comparison results with ranks and weights
    
    Columns:
    - rank: model ranking (0 = best)
    - elpd_*: expected log pointwise predictive density
    - p_*: effective number of parameters
    - d_*: difference from best model
    - weight: model averaging weights
    - se: standard error of differences
    - dse: standard error of difference from best
    """

def waic(trace, model=None, pointwise=False, scale='deviance'):
    """
    Watanabe-Akaike Information Criterion.
    
    Estimates out-of-sample predictive performance using
    within-sample log-likelihood with penalty for overfitting.
    
    Parameters:
    - trace: InferenceData or MultiTrace, posterior samples
    - model: Model, model context (current if None)
    - pointwise: bool, return pointwise WAIC values
    - scale: str, return scale ('deviance' or 'log')
    
    Returns:
    - ELPDData: WAIC results with components and diagnostics
    
    Components:
    - elpd_waic: expected log pointwise predictive density
    - p_waic: effective number of parameters
    - waic: -2 * elpd_waic (lower is better)
    - se: standard error of WAIC
    """
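The WAIC components listed above follow directly from a pointwise log-likelihood matrix. A minimal NumPy sketch of the standard formula (not the ArviZ implementation; `waic_sketch` is illustrative):

```python
import numpy as np

def waic_sketch(log_lik):
    """WAIC from a (n_samples, n_obs) pointwise log-likelihood matrix:
    lppd (log pointwise predictive density) minus the variance penalty."""
    s = log_lik.shape[0]
    # log-mean-exp over posterior samples for each observation
    lppd = np.sum(np.logaddexp.reduce(log_lik, axis=0) - np.log(s))
    p_waic = np.sum(log_lik.var(axis=0, ddof=1))  # effective parameters
    elpd = lppd - p_waic
    return {'elpd_waic': elpd, 'p_waic': p_waic, 'waic': -2 * elpd}

# Toy example: normal log-likelihoods under posterior draws of the mean
rng = np.random.default_rng(3)
y = rng.normal(size=50)
mu_draws = rng.normal(0.0, 0.1, size=500)
log_lik = -0.5 * (y[None, :] - mu_draws[:, None]) ** 2 - 0.5 * np.log(2 * np.pi)
print(waic_sketch(log_lik)['p_waic'] > 0)  # the penalty is always positive
```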

def loo(trace, model=None, pointwise=False, reff=None, scale='deviance'):
    """
    Pareto Smoothed Importance Sampling Leave-One-Out Cross-Validation.
    
    Estimates out-of-sample performance using leave-one-out
    cross-validation approximated by importance sampling.
    
    Parameters:
    - trace: InferenceData or MultiTrace, posterior samples
    - model: Model, model context
    - pointwise: bool, return pointwise LOO values
    - reff: array, relative effective sample sizes
    - scale: str, return scale ('deviance' or 'log')
    
    Returns:
    - ELPDData: LOO-CV results with Pareto diagnostics
    
    Diagnostics:
    - Pareto k < 0.5: Good approximation
    - Pareto k < 0.7: Okay approximation
    - Pareto k > 0.7: Poor approximation, use exact CV
    """
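For intuition, plain importance-sampling LOO fits in a few lines of NumPy. ArviZ additionally stabilizes the largest importance weights with a generalized Pareto fit (the "PSIS" part, whose k values drive the diagnostics above); the illustrative `loo_is_sketch` omits that step:

```python
import numpy as np

def loo_is_sketch(log_lik):
    """Plain IS-LOO from a (n_samples, n_obs) pointwise log-likelihood
    matrix: elpd_i = -log mean_s exp(-log_lik[s, i]) per observation.
    Without Pareto smoothing this can be noisy when the implied weights
    are heavy-tailed (large Pareto k)."""
    s = log_lik.shape[0]
    elpd_i = -(np.logaddexp.reduce(-log_lik, axis=0) - np.log(s))
    return float(elpd_i.sum())

rng = np.random.default_rng(6)
y = rng.normal(size=50)
mu_draws = rng.normal(0.0, 0.1, size=500)
log_lik = -0.5 * (y[None, :] - mu_draws[:, None]) ** 2 - 0.5 * np.log(2 * np.pi)
lppd = float(np.sum(np.logaddexp.reduce(log_lik, axis=0) - np.log(500)))
print(loo_is_sketch(log_lik) < lppd)  # True: LOO elpd sits below in-sample lppd
```

The gap between the in-sample lppd and the LOO estimate plays the same role as the `p_waic` penalty: it corrects for evaluating the model on the data it was fit to.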

def loo_pit(idata, y=None, y_hat=None, log_weights=None):
    """
    Leave-one-out probability integral transform.
    
    Calibration check for posterior predictive distributions
    using LOO-PIT values that should be uniform if well-calibrated.
    
    Parameters:
    - idata: InferenceData, posterior and predictions
    - y: array, observed data (from idata if None)
    - y_hat: array, posterior predictive samples
    - log_weights: array, importance sampling weights
    
    Returns:
    - array: LOO-PIT values for calibration assessment
    """
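The calibration idea is easy to demonstrate: a PIT value is the fraction of predictive draws at or below the observation, and for a calibrated model these values are roughly uniform on [0, 1]. A NumPy sketch without the importance-weighting step (`loo_pit_sketch` is illustrative):

```python
import numpy as np

def loo_pit_sketch(y, y_hat):
    """PIT values from observations y (n_obs,) and predictive draws
    y_hat (n_samples, n_obs); uniform PIT values indicate calibration."""
    return (y_hat <= y[None, :]).mean(axis=0)

rng = np.random.default_rng(4)
y = rng.normal(size=200)
y_hat = rng.normal(size=(1000, 200))  # draws from the true data distribution
pit = loo_pit_sketch(y, y_hat)
print(round(float(pit.mean()), 1))    # centered near 0.5 when calibrated
```

A U-shaped PIT histogram signals an under-dispersed predictive distribution, an inverse-U an over-dispersed one.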

Summary Statistics

Posterior summary and descriptive statistics.

def summary(trace, var_names=None, stat_funcs=None, extend=True, 
           credible_interval=0.94, round_to=2, kind='stats'):
    """
    Comprehensive posterior summary statistics.
    
    Provides means, standard deviations, credible intervals,
    and convergence diagnostics for all model parameters.
    
    Parameters:
    - trace: InferenceData or MultiTrace, posterior samples
    - var_names: list, variables to summarize (all if None)
    - stat_funcs: dict, custom summary functions
    - extend: bool, include convergence diagnostics
    - credible_interval: float, credible interval width
    - round_to: int, decimal places
    - kind: str, summary type ('stats', 'diagnostics')
    
    Returns:
    - DataFrame: comprehensive parameter summary
    
    Columns:
    - mean: posterior mean
    - sd: posterior standard deviation  
    - hdi_3%/hdi_97%: highest density interval bounds
    - mcse_mean: MCSE of mean
    - mcse_sd: MCSE of standard deviation
    - ess_bulk/ess_tail: effective sample sizes
    - r_hat: R-hat convergence diagnostic
    """

def describe(trace, var_names=None, include_ci=True, ci_prob=0.94):
    """
    Descriptive statistics for posterior distributions.
    
    Parameters:
    - trace: InferenceData or MultiTrace, posterior samples
    - var_names: list, variables to describe
    - include_ci: bool, include credible intervals
    - ci_prob: float, credible interval probability
    
    Returns:
    - DataFrame: descriptive statistics
    """

def quantiles(x, qlist=(0.025, 0.25, 0.5, 0.75, 0.975)):
    """
    Compute quantiles of posterior samples.
    
    Parameters:
    - x: array, samples
    - qlist: tuple, quantile probabilities
    
    Returns:
    - dict: quantile values
    """

def hdi(x, credible_interval=0.94, circular=False):
    """
    Highest Density Interval (HDI).
    
    Computes the shortest interval containing specified
    probability mass of the posterior distribution.
    
    Parameters:
    - x: array, posterior samples
    - credible_interval: float, interval probability
    - circular: bool, circular data (angles)
    
    Returns:
    - array: [lower_bound, upper_bound]
    """
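The "shortest interval" definition translates directly into a sliding-window pass over sorted samples. A minimal sketch for unimodal, non-circular data (`hdi_sketch` is illustrative, not PyMC3's implementation):

```python
import numpy as np

def hdi_sketch(samples, credible_interval=0.94):
    """Shortest interval containing the requested probability mass:
    slide a fixed-size window over the sorted samples and keep the
    narrowest one. Assumes a unimodal distribution."""
    x = np.sort(samples)
    n = len(x)
    window = int(np.floor(credible_interval * n))
    widths = x[window:] - x[:n - window]
    i = int(np.argmin(widths))        # narrowest window wins
    return np.array([x[i], x[i + window]])

rng = np.random.default_rng(5)
draws = rng.normal(size=10_000)
lo, hi = hdi_sketch(draws, 0.94)
# For a standard normal the 94% HDI is roughly [-1.88, 1.88]
print(round(float(lo), 1), round(float(hi), 1))
```

For skewed posteriors the HDI differs from the equal-tailed interval returned by `quantiles`, which is why it is the default summary for asymmetric distributions.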

Posterior Analysis

Advanced posterior analysis and derived quantities.

def autocorr(trace, var_names=None, max_lag=100):
    """
    Autocorrelation function of MCMC chains.
    
    Measures correlation between samples at different lags
    to assess mixing and effective sample size.
    
    Parameters:
    - trace: InferenceData or MultiTrace, samples
    - var_names: list, variables to analyze
    - max_lag: int, maximum lag to compute
    
    Returns:
    - dict: autocorrelation functions by variable
    """

def make_ufunc(func, nin=1, nout=1, **kwargs):
    """
    Create universal function for posterior analysis.
    
    Converts regular functions into universal functions
    that work efficiently on posterior sample arrays.
    
    Parameters:
    - func: callable, function to convert
    - nin: int, number of inputs
    - nout: int, number of outputs
    - kwargs: additional ufunc arguments
    
    Returns:
    - ufunc: universal function
    """

def from_dict(posterior_dict, coords=None, dims=None):
    """
    Create InferenceData from dictionary of arrays.
    
    Parameters:
    - posterior_dict: dict, posterior samples by variable
    - coords: dict, coordinate values
    - dims: dict, dimension names by variable
    
    Returns:
    - InferenceData: formatted inference data
    """

Plotting Functions

Comprehensive visualization capabilities, provided through ArviZ and exposed as pymc3.plots.*.

def plot_trace(trace, var_names=None, coords=None, divergences='auto',
              figsize=None, rug=False, lines=None, compact=True, 
              combined=False, legend=False, plot_kwargs=None, 
              fill_kwargs=None, rug_kwargs=None, **kwargs):
    """
    Trace plots showing MCMC sampling paths and marginal distributions.
    
    Essential diagnostic plot combining time series of samples
    with marginal posterior distributions for visual convergence assessment.
    
    Parameters:
    - trace: InferenceData, posterior samples
    - var_names: list, variables to plot
    - coords: dict, coordinate slices for multidimensional variables
    - divergences: str or bool, highlight divergent transitions
    - figsize: tuple, figure size
    - rug: bool, add rug plot to marginals
    - lines: dict, reference lines to overlay
    - compact: bool, compact layout
    - combined: bool, combine all chains
    - legend: bool, show chain legend
    - plot_kwargs: dict, line plot arguments
    - fill_kwargs: dict, density fill arguments
    - rug_kwargs: dict, rug plot arguments
    
    Returns:
    - matplotlib axes: plot axes array
    """

def plot_posterior(trace, var_names=None, coords=None, figsize=None,
                  textsize=None, hdi_prob=0.94, multimodal=False, 
                  skipna=False, ref_val=None, rope=None, point_estimate='mean',
                  round_to=2, credible_interval=None, **kwargs):
    """
    Posterior distribution plots with summary statistics.
    
    Shows marginal posterior distributions with credible intervals,
    point estimates, and optional reference values or ROPE.
    
    Parameters:
    - trace: InferenceData, posterior samples
    - var_names: list, variables to plot
    - coords: dict, coordinate selections
    - figsize: tuple, figure size
    - textsize: float, text size for annotations
    - hdi_prob: float, HDI probability
    - multimodal: bool, detect and handle multimodal distributions
    - skipna: bool, skip missing values
    - ref_val: dict, reference values by variable
    - rope: dict, region of practical equivalence bounds
    - point_estimate: str, point estimate type ('mean', 'median', 'mode')
    - round_to: int, decimal places for annotations
    - credible_interval: float, deprecated alias for hdi_prob
    
    Returns:
    - matplotlib axes: plot axes
    """

def plot_forest(trace, var_names=None, coords=None, figsize=None,
               textsize=None, ropestyle='top', ropes=None, credible_interval=0.94,
               quartiles=True, r_hat=True, ess=True, combined=False, 
               colors='cycle', **kwargs):
    """
    Forest plot showing parameter estimates with uncertainty intervals.
    
    Horizontal plot displaying point estimates and credible intervals
    for multiple parameters, useful for coefficient comparison.
    
    Parameters:
    - trace: InferenceData, posterior samples
    - var_names: list, variables to include
    - coords: dict, coordinate selections
    - figsize: tuple, figure size
    - textsize: float, text size
    - ropestyle: str, ROPE display style ('top', 'bottom', None)
    - ropes: dict, ROPE bounds by variable
    - credible_interval: float, interval probability
    - quartiles: bool, show quartile markers
    - r_hat: bool, show R-hat values
    - ess: bool, show effective sample size
    - combined: bool, combine chains before plotting
    - colors: str or list, color specification
    
    Returns:
    - matplotlib axes: plot axes
    """

def plot_autocorr(trace, var_names=None, coords=None, figsize=None,
                 textsize=None, max_lag=100, combined=False, **kwargs):
    """
    Autocorrelation plots for assessing chain mixing.
    
    Shows autocorrelation function to diagnose slow mixing
    and estimate effective sample sizes visually.
    
    Parameters:
    - trace: InferenceData, posterior samples
    - var_names: list, variables to plot
    - coords: dict, coordinate selections
    - figsize: tuple, figure size
    - textsize: float, text size
    - max_lag: int, maximum lag to plot
    - combined: bool, combine chains
    
    Returns:
    - matplotlib axes: plot axes
    """

def plot_rank(trace, var_names=None, coords=None, figsize=None,
             bins=20, kind='bars', **kwargs):
    """
    Rank plots for MCMC diagnostics.
    
    Shows rank statistics across chains to identify mixing
    problems and between-chain differences.
    
    Parameters:
    - trace: InferenceData, posterior samples
    - var_names: list, variables to plot
    - coords: dict, coordinate selections
    - figsize: tuple, figure size
    - bins: int, number of rank bins
    - kind: str, plot type ('bars', 'vlines')
    
    Returns:
    - matplotlib axes: plot axes
    """

def plot_energy(trace, figsize=None, **kwargs):
    """
    Energy plot for HMC/NUTS diagnostics.
    
    Compares energy distributions between tuning and sampling
    phases to identify potential sampling problems.
    
    Parameters:
    - trace: InferenceData, posterior samples with energy info
    - figsize: tuple, figure size
    
    Returns:
    - matplotlib axes: plot axes
    """

def plot_pair(trace, var_names=None, coords=None, figsize=None,
             textsize=None, kind='scatter', gridsize='auto', 
             colorbar=True, divergences=False, **kwargs):
    """
    Pairwise parameter plots showing correlations and structure.
    
    Matrix of bivariate plots revealing posterior correlations,
    multimodality, and geometric structure.
    
    Parameters:
    - trace: InferenceData, posterior samples
    - var_names: list, variables to include
    - coords: dict, coordinate selections  
    - figsize: tuple, figure size
    - textsize: float, text size
    - kind: str, plot type ('scatter', 'kde', 'hexbin')
    - gridsize: int or 'auto', grid resolution for kde/hexbin
    - colorbar: bool, show colorbar for density plots
    - divergences: bool, highlight divergent samples
    
    Returns:
    - matplotlib axes: plot axes matrix
    """

def plot_parallel(trace, var_names=None, coords=None, figsize=None,
                 colornd='k', colord='r', shadend=0.025, **kwargs):
    """
    Parallel coordinates plot for high-dimensional visualization.
    
    Shows sample paths across multiple parameters to identify
    correlations and outliers in high-dimensional posteriors.
    
    Parameters:
    - trace: InferenceData, posterior samples
    - var_names: list, variables to include
    - coords: dict, coordinate selections
    - figsize: tuple, figure size
    - colornd: color for non-divergent samples
    - colord: color for divergent samples  
    - shadend: float, transparency for non-divergent samples
    
    Returns:
    - matplotlib axes: plot axes
    """

def plot_violin(trace, var_names=None, coords=None, figsize=None,
               textsize=None, credible_interval=0.94, quartiles=True,
               rug=False, **kwargs):
    """
    Violin plots showing posterior distribution shapes.
    
    Kernel density estimates with optional quartiles and
    credible intervals for comparing parameter distributions.
    
    Parameters:
    - trace: InferenceData, posterior samples
    - var_names: list, variables to plot
    - coords: dict, coordinate selections
    - figsize: tuple, figure size
    - textsize: float, text size
    - credible_interval: float, interval to mark
    - quartiles: bool, show quartile lines
    - rug: bool, add rug plot
    
    Returns:
    - matplotlib axes: plot axes
    """

def plot_kde(values, values2=None, cumulative=False, rug=False,
            label=None, bw='scott', adaptive=False, extend=True,
            gridsize=None, clip=None, alpha=0.7, **kwargs):
    """
    Kernel density estimation plots.
    
    Smooth density estimates for continuous distributions
    with options for cumulative plots and comparisons.
    
    Parameters:
    - values: array, samples to plot
    - values2: array, optional second sample for comparison
    - cumulative: bool, plot cumulative density
    - rug: bool, add rug plot
    - label: str, plot label
    - bw: str or float, bandwidth selection method
    - adaptive: bool, use adaptive bandwidth
    - extend: bool, extend domain beyond data range
    - gridsize: int, evaluation grid size
    - clip: tuple, domain bounds
    - alpha: float, transparency
    
    Returns:
    - matplotlib axes: plot axes
    """

Posterior Predictive Checking

Functions for model validation through posterior predictive distributions.

def plot_ppc(trace, kind='kde', alpha=0.05, figsize=None, textsize=None,
            data_pairs=None, var_names=None, coords=None, flatten=None,
            flatten_pp=None, num_pp_samples=100, random_seed=None, 
            jitter=None, mean=True, observed=True, **kwargs):
    """
    Posterior predictive check plots.
    
    Compares observed data with posterior predictive samples
    to assess model fit and identify systematic deviations.
    
    Parameters:
    - trace: InferenceData, with posterior_predictive group
    - kind: str, plot type ('kde', 'cumulative', 'scatter')
    - alpha: float, transparency for predictive samples
    - figsize: tuple, figure size
    - textsize: float, text size
    - data_pairs: dict, observed data by variable name
    - var_names: list, variables to plot
    - coords: dict, coordinate selections
    - flatten: list, dimensions to flatten
    - flatten_pp: list, posterior predictive dimensions to flatten
    - num_pp_samples: int, number of predictive samples to show
    - random_seed: int, random seed for sample selection
    - jitter: float, jitter amount for discrete data
    - mean: bool, show predictive mean
    - observed: bool, show observed data
    
    Returns:
    - matplotlib axes: plot axes
    """

def plot_loo_pit(idata, y=None, y_hat=None, log_weights=None, 
                ecdf=False, ecdf_fill=True, use_hdi=True, 
                credible_interval=0.99, figsize=None, **kwargs):
    """
    Leave-one-out probability integral transform plots.
    
    Diagnostic plots for posterior predictive calibration
    using LOO-PIT values that should be uniform if well-calibrated.
    
    Parameters:
    - idata: InferenceData, inference results
    - y: array, observed values
    - y_hat: array, posterior predictive samples
    - log_weights: array, importance weights
    - ecdf: bool, overlay empirical CDF
    - ecdf_fill: bool, fill ECDF confidence band
    - use_hdi: bool, use HDI for confidence bands
    - credible_interval: float, confidence level
    - figsize: tuple, figure size
    
    Returns:
    - matplotlib axes: plot axes
    """

Model Comparison Plots

Visualization for comparing multiple models.

def plot_compare(comp_df, insample_dev=True, plot_ic_diff=True, 
                order_by_rank=True, figsize=None, textsize=None, **kwargs):
    """
    Model comparison plot showing information criteria.
    
    Visual comparison of models using WAIC/LOO with
    standard errors and ranking information.
    
    Parameters:
    - comp_df: DataFrame, results from az.compare()
    - insample_dev: bool, plot in-sample deviance
    - plot_ic_diff: bool, plot differences from best model
    - order_by_rank: bool, order models by rank
    - figsize: tuple, figure size
    - textsize: float, text size
    
    Returns:
    - matplotlib axes: plot axes
    """

def plot_elpd(comp_df, xlabels=False, figsize=None, textsize=None, 
             color='C0', **kwargs):
    """
    Expected log predictive density comparison plot.
    
    Parameters:
    - comp_df: DataFrame, comparison results
    - xlabels: bool, show x-axis labels
    - figsize: tuple, figure size  
    - textsize: float, text size
    - color: color specification
    
    Returns:
    - matplotlib axes: plot axes
    """

def plot_khat(khats, bins=None, figsize=None, ax=None, **kwargs):
    """
    Pareto k diagnostic plot for LOO reliability.
    
    Shows distribution of Pareto k values to assess
    reliability of LOO approximation.
    
    Parameters:
    - khats: array, Pareto k values from loo()
    - bins: int, histogram bins
    - figsize: tuple, figure size
    - ax: matplotlib axes, existing axes
    
    Returns:
    - matplotlib axes: plot axes
    """

Usage Examples

Comprehensive Model Diagnostics

import pymc3 as pm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import arviz as az

# Synthetic observed data for the example
data = np.random.normal(loc=1.0, scale=2.0, size=200)

# Example model and sampling
with pm.Model() as diagnostic_model:
    mu = pm.Normal('mu', mu=0, sigma=10)
    sigma = pm.HalfNormal('sigma', sigma=5)
    y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=data)
    
    # Sample with multiple chains for diagnostics
    trace = pm.sample(1000, tune=1000, chains=4, 
                     target_accept=0.95, return_inferencedata=True)

# Convergence diagnostics
print("=== Convergence Diagnostics ===")
r_hat_values = az.rhat(trace)
print("R-hat values:", r_hat_values)

ess_bulk = az.ess(trace, method='bulk')
ess_tail = az.ess(trace, method='tail')
print("Effective sample size (bulk):", ess_bulk)
print("Effective sample size (tail):", ess_tail)

mcse_values = az.mcse(trace)
print("Monte Carlo standard errors:", mcse_values)

# Comprehensive summary
summary_stats = az.summary(trace)
print("\n=== Posterior Summary ===")
print(summary_stats)

# Visual diagnostics (each ArviZ plotting function manages its own
# figure and axes grid)

# Trace plots
az.plot_trace(trace)

# Rank plots
az.plot_rank(trace)

# Autocorrelation
az.plot_autocorr(trace, max_lag=50)

plt.show()

Model Comparison Workflow

# Synthetic data shared by the candidate models
x_data = np.linspace(0, 1, 100)
y_data = 1.0 + 2.0 * x_data + np.random.normal(0, 0.5, size=100)

# Multiple models for comparison
models = {}
traces = {}

# Model 1: Simple linear
with pm.Model() as model1:
    alpha1 = pm.Normal('alpha', mu=0, sigma=10)
    beta1 = pm.Normal('beta', mu=0, sigma=10)
    sigma1 = pm.HalfNormal('sigma', sigma=1)
    
    mu1 = alpha1 + beta1 * x_data
    y_obs1 = pm.Normal('y_obs', mu=mu1, sigma=sigma1, observed=y_data)
    
    trace1 = pm.sample(1000, tune=1000, return_inferencedata=True)

models['Linear'] = model1
traces['Linear'] = trace1

# Model 2: Quadratic
with pm.Model() as model2:
    alpha2 = pm.Normal('alpha', mu=0, sigma=10)
    beta1_2 = pm.Normal('beta1', mu=0, sigma=10)
    beta2_2 = pm.Normal('beta2', mu=0, sigma=10)
    sigma2 = pm.HalfNormal('sigma', sigma=1)
    
    mu2 = alpha2 + beta1_2 * x_data + beta2_2 * x_data**2
    y_obs2 = pm.Normal('y_obs', mu=mu2, sigma=sigma2, observed=y_data)
    
    trace2 = pm.sample(1000, tune=1000, return_inferencedata=True)

models['Quadratic'] = model2  
traces['Quadratic'] = trace2

# Model 3: Robust (Student's t)
with pm.Model() as model3:
    alpha3 = pm.Normal('alpha', mu=0, sigma=10)
    beta3 = pm.Normal('beta', mu=0, sigma=10)
    sigma3 = pm.HalfNormal('sigma', sigma=1)
    nu = pm.Gamma('nu', alpha=2, beta=0.1)
    
    mu3 = alpha3 + beta3 * x_data
    y_obs3 = pm.StudentT('y_obs', nu=nu, mu=mu3, sigma=sigma3, observed=y_data)
    
    trace3 = pm.sample(1000, tune=1000, return_inferencedata=True)

models['Robust'] = model3
traces['Robust'] = trace3

# Compute information criteria
waic_results = {}
loo_results = {}

for name, trace in traces.items():
    waic_results[name] = az.waic(trace)
    loo_results[name] = az.loo(trace)

# Model comparison
comparison_waic = az.compare(traces, ic='waic')
comparison_loo = az.compare(traces, ic='loo')

print("=== Model Comparison (WAIC) ===")
print(comparison_waic)

print("\n=== Model Comparison (LOO) ===")
print(comparison_loo)

# Visualization
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

az.plot_compare(comparison_waic, ax=axes[0])
axes[0].set_title('WAIC Comparison')

az.plot_compare(comparison_loo, ax=axes[1])
axes[1].set_title('LOO Comparison')

plt.tight_layout()
plt.show()

# Check LOO reliability
for name, loo_result in loo_results.items():
    k_values = loo_result.pareto_k.values.flatten()
    n_high_k = np.sum(k_values > 0.7)
    print(f"{name}: {n_high_k} observations with high Pareto k (> 0.7)")

Posterior Predictive Checking

# Generate posterior predictive samples (inside the model context)
with models['Linear']:  # use the best model from the comparison
    ppc = pm.sample_posterior_predictive(traces['Linear'], samples=100)
    # Add the posterior predictive group to the InferenceData
    traces['Linear'].extend(az.from_pymc3(posterior_predictive=ppc))

# Posterior predictive checks
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Basic PPC plot
az.plot_ppc(traces['Linear'], ax=axes[0, 0], kind='kde')
axes[0, 0].set_title('Posterior Predictive Check (KDE)')

# Cumulative PPC
az.plot_ppc(traces['Linear'], ax=axes[0, 1], kind='cumulative')
axes[0, 1].set_title('Cumulative PPC')

# LOO-PIT for calibration
az.plot_loo_pit(traces['Linear'], y='y_obs', ax=axes[1, 0])
axes[1, 0].set_title('LOO-PIT Calibration')

# Custom PPC statistics
def ppc_statistics(y_obs, y_pred):
    """Custom statistics for PPC."""
    return {
        'mean': np.mean(y_pred, axis=1),
        'std': np.std(y_pred, axis=1),  
        'min': np.min(y_pred, axis=1),
        'max': np.max(y_pred, axis=1)
    }

# Compute statistics
obs_stats = ppc_statistics(y_data, y_data.reshape(1, -1))
pred_stats = ppc_statistics(y_data, ppc['y_obs'])

# Plot statistics comparison
statistics = ['mean', 'std', 'min', 'max']
obs_values = [obs_stats[stat][0] for stat in statistics]
pred_means = [np.mean(pred_stats[stat]) for stat in statistics]
pred_stds = [np.std(pred_stats[stat]) for stat in statistics]

x_pos = np.arange(len(statistics))
axes[1, 1].bar(x_pos - 0.2, obs_values, 0.4, label='Observed', alpha=0.7)
axes[1, 1].errorbar(x_pos + 0.2, pred_means, yerr=pred_stds, 
                   fmt='o', label='Predicted', capsize=5)
axes[1, 1].set_xticks(x_pos)
axes[1, 1].set_xticklabels(statistics)
axes[1, 1].legend()
axes[1, 1].set_title('Summary Statistics Comparison')

plt.tight_layout()
plt.show()

Advanced Visualization

# Synthetic design matrix and response for the example
X_multi = np.random.normal(size=(100, 3))
y_multi = 0.5 + X_multi @ np.array([1.0, -0.5, 0.25]) + np.random.normal(0, 0.5, size=100)

# Multi-parameter visualization
with pm.Model() as multivariate_model:
    # Correlated parameters
    theta = pm.MvNormal('theta', 
                       mu=np.zeros(4), 
                       cov=np.eye(4), 
                       shape=4)
    
    # Transform for identifiability
    alpha = pm.Deterministic('alpha', theta[0])
    beta = pm.Deterministic('beta', theta[1:])
    
    # Model prediction
    mu = alpha + pm.math.dot(beta, X_multi.T)
    y_obs = pm.Normal('y_obs', mu=mu, sigma=0.5, observed=y_multi)
    
    trace_mv = pm.sample(1000, tune=1000, return_inferencedata=True)

# Comprehensive visualization suite. Most ArviZ plots create their own
# figure or axes grid, so they are called individually rather than
# packed into one shared subplot grid.

# Trace plots
az.plot_trace(trace_mv, var_names=['alpha'])

# Posterior distributions
az.plot_posterior(trace_mv, var_names=['alpha'])

# Forest plot for coefficients
az.plot_forest(trace_mv, var_names=['beta'])

# Pairwise relationships
az.plot_pair(trace_mv, var_names=['alpha', 'beta'],
             coords={'beta_dim_0': slice(0, 2)})

# Energy diagnostic
az.plot_energy(trace_mv)

# Parallel coordinates
az.plot_parallel(trace_mv, var_names=['alpha', 'beta'])

# Rank plot
az.plot_rank(trace_mv, var_names=['alpha'])

plt.show()

Custom Diagnostic Workflow

# Custom convergence assessment
def comprehensive_diagnostics(trace, var_names=None):
    """Comprehensive diagnostic assessment."""
    
    if var_names is None:
        var_names = list(trace.posterior.data_vars)
    
    diagnostics = {}
    
    for var in var_names:
        var_diagnostics = {}
        
        # Basic convergence metrics
        var_diagnostics['r_hat'] = float(az.rhat(trace, var_names=[var])[var])
        var_diagnostics['ess_bulk'] = float(az.ess(trace, var_names=[var], method='bulk')[var])
        var_diagnostics['ess_tail'] = float(az.ess(trace, var_names=[var], method='tail')[var])
        var_diagnostics['mcse_mean'] = float(az.mcse(trace, var_names=[var], method='mean')[var])
        
        # Effective sample size ratios
        n_samples = trace.posterior[var].size
        var_diagnostics['ess_bulk_ratio'] = var_diagnostics['ess_bulk'] / n_samples
        var_diagnostics['ess_tail_ratio'] = var_diagnostics['ess_tail'] / n_samples
        
        # Convergence flags
        var_diagnostics['converged'] = (
            var_diagnostics['r_hat'] < 1.01 and 
            var_diagnostics['ess_bulk'] > 400 and
            var_diagnostics['ess_tail'] > 400
        )
        
        diagnostics[var] = var_diagnostics
    
    return diagnostics

# Run diagnostics
diag_results = comprehensive_diagnostics(trace_mv)

print("=== Comprehensive Diagnostics ===")
for var, diag in diag_results.items():
    status = "✓ PASS" if diag['converged'] else "✗ FAIL"
    print(f"\n{var} {status}")
    print(f"  R-hat: {diag['r_hat']:.4f}")
    print(f"  ESS bulk: {diag['ess_bulk']:.0f} ({diag['ess_bulk_ratio']:.2f})")
    print(f"  ESS tail: {diag['ess_tail']:.0f} ({diag['ess_tail_ratio']:.2f})")
    print(f"  MCSE mean: {diag['mcse_mean']:.4f}")

# Summary convergence status
all_converged = all(diag['converged'] for diag in diag_results.values())
print(f"\nOverall convergence: {'✓ PASS' if all_converged else '✗ FAIL'}")

if not all_converged:
    print("\nRecommendations:")
    print("- Increase number of samples")
    print("- Check model parameterization")
    print("- Consider different step size or sampler settings")

Publication-Ready Plots

# Create publication-quality figures
plt.rcParams.update({
    'font.size': 12,
    'axes.labelsize': 14,
    'axes.titlesize': 16,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    'legend.fontsize': 12,
    'figure.titlesize': 18
})

# Multi-panel figure for publication
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
fig.suptitle('Bayesian Linear Regression Analysis', fontsize=18, y=0.98)

# Panel A: Posterior distributions
az.plot_posterior(trace_mv, var_names=['alpha'], ax=axes[0, 0], 
                 hdi_prob=0.95, point_estimate='mean')
axes[0, 0].set_title('A. Intercept Posterior')

# Panel B: Coefficient forest plot  
az.plot_forest(trace_mv, var_names=['beta'], ax=axes[0, 1],
              credible_interval=0.95, quartiles=False)
axes[0, 1].set_title('B. Coefficient Estimates')

# Panel C: Model comparison
az.plot_compare(comparison_waic, ax=axes[0, 2])
axes[0, 2].set_title('C. Model Comparison (WAIC)')

# Panel D: Posterior predictive check
az.plot_ppc(traces['Linear'], ax=axes[1, 0], kind='kde', 
           alpha=0.1, num_pp_samples=50)
axes[1, 0].set_title('D. Posterior Predictive Check')

# Panel E: Residual analysis (custom)
# Extract posterior mean predictions
post_pred = ppc['y_obs'].mean(axis=0)
residuals = y_data - post_pred

axes[1, 1].scatter(post_pred, residuals, alpha=0.6)
axes[1, 1].axhline(y=0, color='red', linestyle='--')
axes[1, 1].set_xlabel('Fitted Values')
axes[1, 1].set_ylabel('Residuals')
axes[1, 1].set_title('E. Residual Analysis')

# Panel F: Convergence diagnostics summary
convergence_summary = pd.DataFrame(diag_results).T[['r_hat', 'ess_bulk_ratio']]
convergence_summary.plot(kind='bar', ax=axes[1, 2])
axes[1, 2].set_title('F. Convergence Summary')
axes[1, 2].set_ylabel('Diagnostic Value')
axes[1, 2].tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.savefig('bayesian_analysis.png', dpi=300, bbox_inches='tight')
plt.show()

Install with Tessl CLI

npx tessl i tessl/pypi-pymc3
