Probabilistic Programming in Python: Bayesian Modeling and Probabilistic Machine Learning with Theano
PyMC3 integrates tightly with ArviZ for comprehensive Bayesian analysis, model diagnostics, and publication-quality visualizations. The stats and plots modules delegate to ArviZ while providing PyMC3-specific functionality and convenient aliases for common workflows.
Functions for assessing MCMC convergence and sample quality, exposed through pymc3.stats.*.
def r_hat(trace, var_names=None, method='rank'):
"""
Compute R-hat convergence diagnostic.
Measures between-chain and within-chain variance to assess
convergence across multiple MCMC chains. Values close to 1.0
indicate good convergence.
Parameters:
- trace: InferenceData or MultiTrace, posterior samples
- var_names: list, variables to analyze (all if None)
- method: str, computation method ('rank', 'split', 'folded')
Returns:
- dict or array: R-hat values by variable
Interpretation:
- R_hat < 1.01: Excellent convergence
- R_hat < 1.1: Good convergence
- R_hat > 1.1: Poor convergence, need more samples
"""
def ess(trace, var_names=None, method='bulk'):
"""
Compute effective sample size.
Estimates the number of independent samples, accounting for
autocorrelation in MCMC chains. Higher values indicate better
mixing and more efficient sampling.
Parameters:
- trace: InferenceData or MultiTrace, posterior samples
- var_names: list, variables to analyze
- method: str, ESS type ('bulk', 'tail', 'mean', 'sd', 'quantile')
Returns:
- dict or array: effective sample sizes
Guidelines:
- ESS > 400: Generally sufficient for posterior inference
- ESS > 100: Minimum for reasonable estimates
- ESS < 100: Increase sampling or improve model
"""
def mcse(trace, var_names=None, method='mean', prob=None):
"""
Monte Carlo standard error of estimates.
Measures uncertainty in posterior estimates due to finite
sampling, helping determine if more samples are needed.
Parameters:
- trace: InferenceData or MultiTrace, posterior samples
- var_names: list, variables to analyze
- method: str, estimate type ('mean', 'sd', 'quantile')
- prob: float, probability for quantile MCSE
Returns:
- dict: MCSE values by variable
"""
def geweke(trace, var_names=None, first=0.1, last=0.5, intervals=20):
"""
Geweke convergence diagnostic.
Compares means from early and late portions of chains to
assess within-chain convergence and stationarity.
Parameters:
- trace: InferenceData or MultiTrace, posterior samples
- var_names: list, variables to test
- first: float, fraction for early portion
- last: float, fraction for late portion
- intervals: int, number of test intervals
Returns:
- dict: Geweke statistics by variable
"""Information criteria and cross-validation for Bayesian model selection.
def compare(models, ic='waic', method='stacking', b_samples=1000,
alpha=1, seed=None, round_to=2):
"""
Compare multiple models using information criteria.
Ranks models by predictive performance using WAIC, LOO-CV,
or other criteria, with model weights and standard errors.
Parameters:
- models: dict, mapping model names to InferenceData objects
- ic: str, information criterion ('waic', 'loo')
- method: str, weighting method ('stacking', 'BB-pseudo-BMA', 'pseudo-BMA')
- b_samples: int, samples for Bootstrap weighting
- alpha: float, concentration parameter for pseudo-BMA
- seed: int, random seed for reproducibility
- round_to: int, decimal places for results
Returns:
- DataFrame: model comparison results with ranks and weights
Columns:
- rank: model ranking (0 = best)
- elpd_*: expected log pointwise predictive density
- p_*: effective number of parameters
- d_*: difference from best model
- weight: model averaging weights
- se: standard error of differences
- dse: standard error of difference from best
"""
def waic(trace, model=None, pointwise=False, scale='deviance'):
"""
Watanabe-Akaike Information Criterion.
Estimates out-of-sample predictive performance using
within-sample log-likelihood with penalty for overfitting.
Parameters:
- trace: InferenceData or MultiTrace, posterior samples
- model: Model, model context (current if None)
- pointwise: bool, return pointwise WAIC values
- scale: str, return scale ('deviance' or 'log')
Returns:
- ELPDData: WAIC results with components and diagnostics
Components:
- elpd_waic: expected log pointwise predictive density
- p_waic: effective number of parameters
- waic: -2 * elpd_waic (lower is better)
- se: standard error of WAIC
"""
def loo(trace, model=None, pointwise=False, reff=None, scale='deviance'):
"""
Pareto Smoothed Importance Sampling Leave-One-Out Cross-Validation.
Estimates out-of-sample performance using leave-one-out
cross-validation approximated by importance sampling.
Parameters:
- trace: InferenceData or MultiTrace, posterior samples
- model: Model, model context
- pointwise: bool, return pointwise LOO values
- reff: array, relative effective sample sizes
- scale: str, return scale ('deviance' or 'log')
Returns:
- ELPDData: LOO-CV results with Pareto diagnostics
Diagnostics:
- Pareto k < 0.5: Good approximation
- Pareto k < 0.7: Okay approximation
- Pareto k > 0.7: Poor approximation, use exact CV
"""
def loo_pit(idata, y=None, y_hat=None, log_weights=None):
"""
Leave-one-out probability integral transform.
Calibration check for posterior predictive distributions
using LOO-PIT values that should be uniform if well-calibrated.
Parameters:
- idata: InferenceData, posterior and predictions
- y: array, observed data (from idata if None)
- y_hat: array, posterior predictive samples
- log_weights: array, importance sampling weights
Returns:
- array: LOO-PIT values for calibration assessment
"""Posterior summary and descriptive statistics.
def summary(trace, var_names=None, stat_funcs=None, extend=True,
credible_interval=0.94, round_to=2, kind='stats'):
"""
Comprehensive posterior summary statistics.
Provides means, standard deviations, credible intervals,
and convergence diagnostics for all model parameters.
Parameters:
- trace: InferenceData or MultiTrace, posterior samples
- var_names: list, variables to summarize (all if None)
- stat_funcs: dict, custom summary functions
- extend: bool, include convergence diagnostics
- credible_interval: float, credible interval width
- round_to: int, decimal places
- kind: str, summary type ('stats', 'diagnostics')
Returns:
- DataFrame: comprehensive parameter summary
Columns:
- mean: posterior mean
- sd: posterior standard deviation
- hdi_3%/hdi_97%: highest density interval bounds
- mcse_mean: MCSE of mean
- mcse_sd: MCSE of standard deviation
- ess_bulk/ess_tail: effective sample sizes
- r_hat: R-hat convergence diagnostic
"""
def describe(trace, var_names=None, include_ci=True, ci_prob=0.94):
"""
Descriptive statistics for posterior distributions.
Parameters:
- trace: InferenceData or MultiTrace, posterior samples
- var_names: list, variables to describe
- include_ci: bool, include credible intervals
- ci_prob: float, credible interval probability
Returns:
- DataFrame: descriptive statistics
"""
def quantiles(x, qlist=(0.025, 0.25, 0.5, 0.75, 0.975)):
"""
Compute quantiles of posterior samples.
Parameters:
- x: array, samples
- qlist: tuple, quantile probabilities
Returns:
- dict: quantile values
"""
def hdi(x, credible_interval=0.94, circular=False):
"""
Highest Density Interval (HDI).
Computes the shortest interval containing specified
probability mass of the posterior distribution.
Parameters:
- x: array, posterior samples
- credible_interval: float, interval probability
- circular: bool, circular data (angles)
Returns:
- array: [lower_bound, upper_bound]
"""Advanced posterior analysis and derived quantities.
def autocorr(trace, var_names=None, max_lag=100):
"""
Autocorrelation function of MCMC chains.
Measures correlation between samples at different lags
to assess mixing and effective sample size.
Parameters:
- trace: InferenceData or MultiTrace, samples
- var_names: list, variables to analyze
- max_lag: int, maximum lag to compute
Returns:
- dict: autocorrelation functions by variable
"""
def make_ufunc(func, nin=1, nout=1, **kwargs):
"""
Create universal function for posterior analysis.
Converts regular functions into universal functions
that work efficiently on posterior sample arrays.
Parameters:
- func: callable, function to convert
- nin: int, number of inputs
- nout: int, number of outputs
- kwargs: additional ufunc arguments
Returns:
- ufunc: universal function
"""
def from_dict(posterior_dict, coords=None, dims=None):
"""
Create InferenceData from dictionary of arrays.
Parameters:
- posterior_dict: dict, posterior samples by variable
- coords: dict, coordinate values
- dims: dict, dimension names by variable
Returns:
- InferenceData: formatted inference data
"""Comprehensive visualization capabilities through ArviZ integration via pymc3.plots.*.
def plot_trace(trace, var_names=None, coords=None, divergences='auto',
figsize=None, rug=False, lines=None, compact=True,
combined=False, legend=False, plot_kwargs=None,
fill_kwargs=None, rug_kwargs=None, **kwargs):
"""
Trace plots showing MCMC sampling paths and marginal distributions.
Essential diagnostic plot combining time series of samples
with marginal posterior distributions for visual convergence assessment.
Parameters:
- trace: InferenceData, posterior samples
- var_names: list, variables to plot
- coords: dict, coordinate slices for multidimensional variables
- divergences: str or bool, highlight divergent transitions
- figsize: tuple, figure size
- rug: bool, add rug plot to marginals
- lines: dict, reference lines to overlay
- compact: bool, compact layout
- combined: bool, combine all chains
- legend: bool, show chain legend
- plot_kwargs: dict, line plot arguments
- fill_kwargs: dict, density fill arguments
- rug_kwargs: dict, rug plot arguments
Returns:
- matplotlib axes: plot axes array
"""
def plot_posterior(trace, var_names=None, coords=None, figsize=None,
textsize=None, hdi_prob=0.94, multimodal=False,
skipna=False, ref_val=None, rope=None, point_estimate='mean',
round_to=2, credible_interval=None, **kwargs):
"""
Posterior distribution plots with summary statistics.
Shows marginal posterior distributions with credible intervals,
point estimates, and optional reference values or ROPE.
Parameters:
- trace: InferenceData, posterior samples
- var_names: list, variables to plot
- coords: dict, coordinate selections
- figsize: tuple, figure size
- textsize: float, text size for annotations
- hdi_prob: float, HDI probability
- multimodal: bool, detect and handle multimodal distributions
- skipna: bool, skip missing values
- ref_val: dict, reference values by variable
- rope: dict, region of practical equivalence bounds
- point_estimate: str, point estimate type ('mean', 'median', 'mode')
- round_to: int, decimal places for annotations
- credible_interval: float, deprecated alias for hdi_prob
Returns:
- matplotlib axes: plot axes
"""
def plot_forest(trace, var_names=None, coords=None, figsize=None,
textsize=None, ropestyle='top', ropes=None, credible_interval=0.94,
quartiles=True, r_hat=True, ess=True, combined=False,
colors='cycle', **kwargs):
"""
Forest plot showing parameter estimates with uncertainty intervals.
Horizontal plot displaying point estimates and credible intervals
for multiple parameters, useful for coefficient comparison.
Parameters:
- trace: InferenceData, posterior samples
- var_names: list, variables to include
- coords: dict, coordinate selections
- figsize: tuple, figure size
- textsize: float, text size
- ropestyle: str, ROPE display style ('top', 'bottom', None)
- ropes: dict, ROPE bounds by variable
- credible_interval: float, interval probability
- quartiles: bool, show quartile markers
- r_hat: bool, show R-hat values
- ess: bool, show effective sample size
- combined: bool, combine chains before plotting
- colors: str or list, color specification
Returns:
- matplotlib axes: plot axes
"""
def plot_autocorr(trace, var_names=None, coords=None, figsize=None,
textsize=None, max_lag=100, combined=False, **kwargs):
"""
Autocorrelation plots for assessing chain mixing.
Shows autocorrelation function to diagnose slow mixing
and estimate effective sample sizes visually.
Parameters:
- trace: InferenceData, posterior samples
- var_names: list, variables to plot
- coords: dict, coordinate selections
- figsize: tuple, figure size
- textsize: float, text size
- max_lag: int, maximum lag to plot
- combined: bool, combine chains
Returns:
- matplotlib axes: plot axes
"""
def plot_rank(trace, var_names=None, coords=None, figsize=None,
bins=20, kind='bars', **kwargs):
"""
Rank plots for MCMC diagnostics.
Shows rank statistics across chains to identify mixing
problems and between-chain differences.
Parameters:
- trace: InferenceData, posterior samples
- var_names: list, variables to plot
- coords: dict, coordinate selections
- figsize: tuple, figure size
- bins: int, number of rank bins
- kind: str, plot type ('bars', 'vlines')
Returns:
- matplotlib axes: plot axes
"""
def plot_energy(trace, figsize=None, **kwargs):
"""
Energy plot for HMC/NUTS diagnostics.
Compares energy distributions between tuning and sampling
phases to identify potential sampling problems.
Parameters:
- trace: InferenceData, posterior samples with energy info
- figsize: tuple, figure size
Returns:
- matplotlib axes: plot axes
"""
def plot_pair(trace, var_names=None, coords=None, figsize=None,
textsize=None, kind='scatter', gridsize='auto',
colorbar=True, divergences=False, **kwargs):
"""
Pairwise parameter plots showing correlations and structure.
Matrix of bivariate plots revealing posterior correlations,
multimodality, and geometric structure.
Parameters:
- trace: InferenceData, posterior samples
- var_names: list, variables to include
- coords: dict, coordinate selections
- figsize: tuple, figure size
- textsize: float, text size
- kind: str, plot type ('scatter', 'kde', 'hexbin')
- gridsize: int or 'auto', grid resolution for kde/hexbin
- colorbar: bool, show colorbar for density plots
- divergences: bool, highlight divergent samples
Returns:
- matplotlib axes: plot axes matrix
"""
def plot_parallel(trace, var_names=None, coords=None, figsize=None,
colornd='k', colord='r', shadend=0.025, **kwargs):
"""
Parallel coordinates plot for high-dimensional visualization.
Shows sample paths across multiple parameters to identify
correlations and outliers in high-dimensional posteriors.
Parameters:
- trace: InferenceData, posterior samples
- var_names: list, variables to include
- coords: dict, coordinate selections
- figsize: tuple, figure size
- colornd: color for non-divergent samples
- colord: color for divergent samples
- shadend: float, transparency for non-divergent samples
Returns:
- matplotlib axes: plot axes
"""
def plot_violin(trace, var_names=None, coords=None, figsize=None,
textsize=None, credible_interval=0.94, quartiles=True,
rug=False, **kwargs):
"""
Violin plots showing posterior distribution shapes.
Kernel density estimates with optional quartiles and
credible intervals for comparing parameter distributions.
Parameters:
- trace: InferenceData, posterior samples
- var_names: list, variables to plot
- coords: dict, coordinate selections
- figsize: tuple, figure size
- textsize: float, text size
- credible_interval: float, interval to mark
- quartiles: bool, show quartile lines
- rug: bool, add rug plot
Returns:
- matplotlib axes: plot axes
"""
def plot_kde(values, values2=None, cumulative=False, rug=False,
label=None, bw='scott', adaptive=False, extend=True,
gridsize=None, clip=None, alpha=0.7, **kwargs):
"""
Kernel density estimation plots.
Smooth density estimates for continuous distributions
with options for cumulative plots and comparisons.
Parameters:
- values: array, samples to plot
- values2: array, optional second sample for comparison
- cumulative: bool, plot cumulative density
- rug: bool, add rug plot
- label: str, plot label
- bw: str or float, bandwidth selection method
- adaptive: bool, use adaptive bandwidth
- extend: bool, extend domain beyond data range
- gridsize: int, evaluation grid size
- clip: tuple, domain bounds
- alpha: float, transparency
Returns:
- matplotlib axes: plot axes
"""Functions for model validation through posterior predictive distributions.
def plot_ppc(trace, kind='kde', alpha=0.05, figsize=None, textsize=None,
data_pairs=None, var_names=None, coords=None, flatten=None,
flatten_pp=None, num_pp_samples=100, random_seed=None,
jitter=None, mean=True, observed=True, **kwargs):
"""
Posterior predictive check plots.
Compares observed data with posterior predictive samples
to assess model fit and identify systematic deviations.
Parameters:
- trace: InferenceData, with posterior_predictive group
- kind: str, plot type ('kde', 'cumulative', 'scatter')
- alpha: float, transparency for predictive samples
- figsize: tuple, figure size
- textsize: float, text size
- data_pairs: dict, observed data by variable name
- var_names: list, variables to plot
- coords: dict, coordinate selections
- flatten: list, dimensions to flatten
- flatten_pp: list, posterior predictive dimensions to flatten
- num_pp_samples: int, number of predictive samples to show
- random_seed: int, random seed for sample selection
- jitter: float, jitter amount for discrete data
- mean: bool, show predictive mean
- observed: bool, show observed data
Returns:
- matplotlib axes: plot axes
"""
def plot_loo_pit(idata, y=None, y_hat=None, log_weights=None,
ecdf=False, ecdf_fill=True, use_hdi=True,
credible_interval=0.99, figsize=None, **kwargs):
"""
Leave-one-out probability integral transform plots.
Diagnostic plots for posterior predictive calibration
using LOO-PIT values that should be uniform if well-calibrated.
Parameters:
- idata: InferenceData, inference results
- y: array, observed values
- y_hat: array, posterior predictive samples
- log_weights: array, importance weights
- ecdf: bool, overlay empirical CDF
- ecdf_fill: bool, fill ECDF confidence band
- use_hdi: bool, use HDI for confidence bands
- credible_interval: float, confidence level
- figsize: tuple, figure size
Returns:
- matplotlib axes: plot axes
"""Visualization for comparing multiple models.
def plot_compare(comp_df, insample_dev=True, plot_ic_diff=True,
order_by_rank=True, figsize=None, textsize=None, **kwargs):
"""
Model comparison plot showing information criteria.
Visual comparison of models using WAIC/LOO with
standard errors and ranking information.
Parameters:
- comp_df: DataFrame, results from az.compare()
- insample_dev: bool, plot in-sample deviance
- plot_ic_diff: bool, plot differences from best model
- order_by_rank: bool, order models by rank
- figsize: tuple, figure size
- textsize: float, text size
Returns:
- matplotlib axes: plot axes
"""
def plot_elpd(comp_df, xlabels=False, figsize=None, textsize=None,
color='C0', **kwargs):
"""
Expected log predictive density comparison plot.
Parameters:
- comp_df: DataFrame, comparison results
- xlabels: bool, show x-axis labels
- figsize: tuple, figure size
- textsize: float, text size
- color: color specification
Returns:
- matplotlib axes: plot axes
"""
def plot_khat(khats, bins=None, figsize=None, ax=None, **kwargs):
"""
Pareto k diagnostic plot for LOO reliability.
Shows distribution of Pareto k values to assess
reliability of LOO approximation.
Parameters:
- khats: array, Pareto k values from loo()
- bins: int, histogram bins
- figsize: tuple, figure size
- ax: matplotlib axes, existing axes
Returns:
- matplotlib axes: plot axes
"""import pymc3 as pm
import numpy as np
import matplotlib.pyplot as plt
import arviz as az
# Example model and sampling (synthetic observations for illustration)
data = np.random.normal(loc=1.0, scale=2.0, size=100)
with pm.Model() as diagnostic_model:
    mu = pm.Normal('mu', mu=0, sigma=10)
    sigma = pm.HalfNormal('sigma', sigma=5)
    y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=data)
    # Sample with multiple chains for diagnostics
    trace = pm.sample(1000, tune=1000, chains=4,
                      target_accept=0.95, return_inferencedata=True)
# Convergence diagnostics
print("=== Convergence Diagnostics ===")
r_hat_values = az.rhat(trace)
print("R-hat values:", r_hat_values)
ess_bulk = az.ess(trace, method='bulk')
ess_tail = az.ess(trace, method='tail')
print("Effective sample size (bulk):", ess_bulk)
print("Effective sample size (tail):", ess_tail)
mcse_values = az.mcse(trace)
print("Monte Carlo standard errors:", mcse_values)
# Comprehensive summary
summary_stats = az.summary(trace)
print("\n=== Posterior Summary ===")
print(summary_stats)
# Visual diagnostics
# (each ArviZ call below manages its own figure and axes)
az.plot_trace(trace)
az.plot_rank(trace)
az.plot_autocorr(trace, max_lag=50)
plt.show()
# Multiple models for comparison
models = {}
traces = {}
# Model 1: Simple linear
with pm.Model() as model1:
    alpha1 = pm.Normal('alpha', mu=0, sigma=10)
    beta1 = pm.Normal('beta', mu=0, sigma=10)
    sigma1 = pm.HalfNormal('sigma', sigma=1)
    mu1 = alpha1 + beta1 * x_data
    y_obs1 = pm.Normal('y_obs', mu=mu1, sigma=sigma1, observed=y_data)
    trace1 = pm.sample(1000, tune=1000, return_inferencedata=True)
models['Linear'] = model1
traces['Linear'] = trace1
# Model 2: Quadratic
with pm.Model() as model2:
    alpha2 = pm.Normal('alpha', mu=0, sigma=10)
    beta1_2 = pm.Normal('beta1', mu=0, sigma=10)
    beta2_2 = pm.Normal('beta2', mu=0, sigma=10)
    sigma2 = pm.HalfNormal('sigma', sigma=1)
    mu2 = alpha2 + beta1_2 * x_data + beta2_2 * x_data**2
    y_obs2 = pm.Normal('y_obs', mu=mu2, sigma=sigma2, observed=y_data)
    trace2 = pm.sample(1000, tune=1000, return_inferencedata=True)
models['Quadratic'] = model2
traces['Quadratic'] = trace2
# Model 3: Robust (Student's t)
with pm.Model() as model3:
    alpha3 = pm.Normal('alpha', mu=0, sigma=10)
    beta3 = pm.Normal('beta', mu=0, sigma=10)
    sigma3 = pm.HalfNormal('sigma', sigma=1)
    nu = pm.Gamma('nu', alpha=2, beta=0.1)
    mu3 = alpha3 + beta3 * x_data
    y_obs3 = pm.StudentT('y_obs', nu=nu, mu=mu3, sigma=sigma3, observed=y_data)
    trace3 = pm.sample(1000, tune=1000, return_inferencedata=True)
models['Robust'] = model3
traces['Robust'] = trace3
# Compute information criteria
waic_results = {}
loo_results = {}
for name, trace in traces.items():
    waic_results[name] = az.waic(trace)
    loo_results[name] = az.loo(trace)
# Model comparison
comparison_waic = az.compare(traces, ic='waic')
comparison_loo = az.compare(traces, ic='loo')
print("=== Model Comparison (WAIC) ===")
print(comparison_waic)
print("\n=== Model Comparison (LOO) ===")
print(comparison_loo)
# Visualization
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
az.plot_compare(comparison_waic, ax=axes[0])
axes[0].set_title('WAIC Comparison')
az.plot_compare(comparison_loo, ax=axes[1])
axes[1].set_title('LOO Comparison')
plt.tight_layout()
plt.show()
# Check LOO reliability
for name, loo_result in loo_results.items():
    k_values = loo_result.pareto_k.values.flatten()
    n_high_k = np.sum(k_values > 0.7)
    print(f"{name}: {n_high_k} observations with high Pareto k (> 0.7)")
# Generate posterior predictive samples
with models['Linear']:  # Use best model from comparison
    ppc = pm.sample_posterior_predictive(traces['Linear'], samples=100)
    # Add posterior predictive to InferenceData
    traces['Linear'].extend(az.from_pymc3(posterior_predictive=ppc))
# Posterior predictive checks
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
# Basic PPC plot
az.plot_ppc(traces['Linear'], ax=axes[0, 0], kind='kde')
axes[0, 0].set_title('Posterior Predictive Check (KDE)')
# Cumulative PPC
az.plot_ppc(traces['Linear'], ax=axes[0, 1], kind='cumulative')
axes[0, 1].set_title('Cumulative PPC')
# LOO-PIT for calibration
az.plot_loo_pit(traces['Linear'], ax=axes[1, 0])
axes[1, 0].set_title('LOO-PIT Calibration')
# Custom PPC statistics
def ppc_statistics(y_obs, y_pred):
    """Per-draw summary statistics for PPC."""
    return {
        'mean': np.mean(y_pred, axis=1),
        'std': np.std(y_pred, axis=1),
        'min': np.min(y_pred, axis=1),
        'max': np.max(y_pred, axis=1)
    }
# Compute statistics
obs_stats = ppc_statistics(y_data, y_data.reshape(1, -1))
pred_stats = ppc_statistics(y_data, ppc['y_obs'])
# Plot statistics comparison
statistics = ['mean', 'std', 'min', 'max']
obs_values = [obs_stats[stat][0] for stat in statistics]
pred_means = [np.mean(pred_stats[stat]) for stat in statistics]
pred_stds = [np.std(pred_stats[stat]) for stat in statistics]
x_pos = np.arange(len(statistics))
axes[1, 1].bar(x_pos - 0.2, obs_values, 0.4, label='Observed', alpha=0.7)
axes[1, 1].errorbar(x_pos + 0.2, pred_means, yerr=pred_stds,
fmt='o', label='Predicted', capsize=5)
axes[1, 1].set_xticks(x_pos)
axes[1, 1].set_xticklabels(statistics)
axes[1, 1].legend()
axes[1, 1].set_title('Summary Statistics Comparison')
plt.tight_layout()
plt.show()
# Multi-parameter visualization
with pm.Model() as multivariate_model:
    # Correlated parameters
    theta = pm.MvNormal('theta',
                        mu=np.zeros(4),
                        cov=np.eye(4),
                        shape=4)
    # Transform for identifiability
    alpha = pm.Deterministic('alpha', theta[0])
    beta = pm.Deterministic('beta', theta[1:])
    # Model prediction
    mu = alpha + pm.math.dot(beta, X_multi.T)
    y_obs = pm.Normal('y_obs', mu=mu, sigma=0.5, observed=y_multi)
    trace_mv = pm.sample(1000, tune=1000, return_inferencedata=True)
# Comprehensive visualization suite
fig = plt.figure(figsize=(16, 12))
# Trace plots
# (plot_trace manages its own subplot grid rather than a single axes)
az.plot_trace(trace_mv, var_names=['alpha'])
# Posterior distributions
axes_post = fig.add_subplot(3, 3, 3)
az.plot_posterior(trace_mv, var_names=['alpha'], ax=axes_post)
# Forest plot for coefficients
axes_forest = fig.add_subplot(3, 3, (4, 5))
az.plot_forest(trace_mv, var_names=['beta'], ax=axes_forest)
# Pairwise relationships
axes_pair = fig.add_subplot(3, 3, 6)
az.plot_pair(trace_mv, var_names=['alpha', 'beta'],
coords={'beta_dim_0': slice(0, 2)}, ax=axes_pair)
# Energy diagnostic
axes_energy = fig.add_subplot(3, 3, 7)
az.plot_energy(trace_mv, ax=axes_energy)
# Parallel coordinates
axes_parallel = fig.add_subplot(3, 3, 8)
az.plot_parallel(trace_mv, var_names=['alpha', 'beta'], ax=axes_parallel)
# Rank plot
axes_rank = fig.add_subplot(3, 3, 9)
az.plot_rank(trace_mv, var_names=['alpha'], ax=axes_rank)
plt.tight_layout()
plt.show()
# Custom convergence assessment
def comprehensive_diagnostics(trace, var_names=None):
    """Comprehensive diagnostic assessment."""
    if var_names is None:
        var_names = list(trace.posterior.data_vars)
    diagnostics = {}
    for var in var_names:
        var_diagnostics = {}
        # Basic convergence metrics (worst case across vector elements)
        var_diagnostics['r_hat'] = float(az.rhat(trace, var_names=[var])[var].max())
        var_diagnostics['ess_bulk'] = float(az.ess(trace, var_names=[var], method='bulk')[var].min())
        var_diagnostics['ess_tail'] = float(az.ess(trace, var_names=[var], method='tail')[var].min())
        var_diagnostics['mcse_mean'] = float(az.mcse(trace, var_names=[var], method='mean')[var].max())
        # Effective sample size ratios (relative to total draws)
        n_samples = trace.posterior[var].shape[0] * trace.posterior[var].shape[1]
        var_diagnostics['ess_bulk_ratio'] = var_diagnostics['ess_bulk'] / n_samples
        var_diagnostics['ess_tail_ratio'] = var_diagnostics['ess_tail'] / n_samples
        # Convergence flags
        var_diagnostics['converged'] = (
            var_diagnostics['r_hat'] < 1.01 and
            var_diagnostics['ess_bulk'] > 400 and
            var_diagnostics['ess_tail'] > 400
        )
        diagnostics[var] = var_diagnostics
    return diagnostics
# Run diagnostics
diag_results = comprehensive_diagnostics(trace_mv)
print("=== Comprehensive Diagnostics ===")
for var, diag in diag_results.items():
    status = "✓ PASS" if diag['converged'] else "✗ FAIL"
    print(f"\n{var} {status}")
    print(f"  R-hat: {diag['r_hat']:.4f}")
    print(f"  ESS bulk: {diag['ess_bulk']:.0f} ({diag['ess_bulk_ratio']:.2f})")
    print(f"  ESS tail: {diag['ess_tail']:.0f} ({diag['ess_tail_ratio']:.2f})")
    print(f"  MCSE mean: {diag['mcse_mean']:.4f}")
# Summary convergence status
all_converged = all(diag['converged'] for diag in diag_results.values())
print(f"\nOverall convergence: {'✓ PASS' if all_converged else '✗ FAIL'}")
if not all_converged:
    print("\nRecommendations:")
    print("- Increase number of samples")
    print("- Check model parameterization")
    print("- Consider different step size or sampler settings")
# Create publication-quality figures
plt.rcParams.update({
    'font.size': 12,
    'axes.labelsize': 14,
    'axes.titlesize': 16,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    'legend.fontsize': 12,
    'figure.titlesize': 18
})
# Multi-panel figure for publication
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
fig.suptitle('Bayesian Linear Regression Analysis', fontsize=18, y=0.98)
# Panel A: Posterior distributions
az.plot_posterior(trace_mv, var_names=['alpha'], ax=axes[0, 0],
hdi_prob=0.95, point_estimate='mean')
axes[0, 0].set_title('A. Intercept Posterior')
# Panel B: Coefficient forest plot
az.plot_forest(trace_mv, var_names=['beta'], ax=axes[0, 1],
credible_interval=0.95, quartiles=False)
axes[0, 1].set_title('B. Coefficient Estimates')
# Panel C: Model comparison
az.plot_compare(comparison_waic, ax=axes[0, 2])
axes[0, 2].set_title('C. Model Comparison (WAIC)')
# Panel D: Posterior predictive check
az.plot_ppc(traces['Linear'], ax=axes[1, 0], kind='kde',
alpha=0.1, num_pp_samples=50)
axes[1, 0].set_title('D. Posterior Predictive Check')
# Panel E: Residual analysis (custom)
# Extract posterior mean predictions
post_pred = ppc['y_obs'].mean(axis=0)
residuals = y_data - post_pred
axes[1, 1].scatter(post_pred, residuals, alpha=0.6)
axes[1, 1].axhline(y=0, color='red', linestyle='--')
axes[1, 1].set_xlabel('Fitted Values')
axes[1, 1].set_ylabel('Residuals')
axes[1, 1].set_title('E. Residual Analysis')
# Panel F: Convergence diagnostics summary
convergence_summary = pd.DataFrame(diag_results).T[['r_hat', 'ess_bulk_ratio']]
convergence_summary.plot(kind='bar', ax=axes[1, 2])
axes[1, 2].set_title('F. Convergence Summary')
axes[1, 2].set_ylabel('Diagnostic Value')
axes[1, 2].tick_params(axis='x', rotation=45)
plt.tight_layout()
plt.savefig('bayesian_analysis.png', dpi=300, bbox_inches='tight')
plt.show()