Probabilistic Programming in Python: Bayesian Modeling and Probabilistic Machine Learning with PyTensor
—
PyMC provides comprehensive statistical functions and convergence diagnostics for Bayesian analysis, primarily through integration with ArviZ. The library offers tools for model validation, convergence assessment, and posterior analysis.
PyMC exposes key diagnostic functions from ArviZ for assessing MCMC convergence:
def rhat(data, var_names=None, method='rank', dask_kwargs=None):
"""
Compute R-hat convergence diagnostic (Gelman-Rubin statistic).
Parameters:
- data: InferenceData object or trace
- var_names (list, optional): Variables to analyze
- method (str): Method for computation ('rank', 'split', 'folded')
- dask_kwargs (dict, optional): Dask computation options
Returns:
- rhat_values: R-hat statistics for each variable
"""
import pymc as pm
# Compute R-hat for all variables
with pm.Model() as model:
# Model definition and sampling...
trace = pm.sample()
rhat_stats = pm.rhat(trace)
print("R-hat diagnostics:")
for var, rhat_val in rhat_stats.items():
print(f" {var}: {rhat_val:.4f}")
# R-hat for specific variables only
rhat_subset = pm.rhat(trace, var_names=['alpha', 'beta'])
# Check convergence (R-hat should be < 1.01)
converged = all(rhat_val < 1.01 for rhat_val in rhat_stats.values())

def effective_sample_size(data, var_names=None, method='bulk',
relative=False, dask_kwargs=None):
"""
Compute effective sample size (ESS).
Parameters:
- data: InferenceData object or trace
- var_names (list, optional): Variables to analyze
- method (str): ESS method ('bulk', 'tail', 'quantile', 'mean', 'sd')
- relative (bool): Return relative ESS (ESS/N)
- dask_kwargs (dict, optional): Dask computation options
Returns:
- ess_values: Effective sample size for each variable
"""
# Bulk ESS (measures efficiency in central posterior)
bulk_ess = pm.ess(trace, method='bulk')
# Tail ESS (measures efficiency in posterior tails)
tail_ess = pm.ess(trace, method='tail')
# Relative ESS (as fraction of total samples)
rel_ess = pm.ess(trace, relative=True)
print("Effective Sample Size (bulk):")
for var, ess_val in bulk_ess.items():
print(f" {var}: {ess_val:.0f}")
# Check adequacy (ESS should be > 400 for reliable inference)
adequate_ess = all(ess_val > 400 for ess_val in bulk_ess.values())

def mcse(data, var_names=None, method='mean', dask_kwargs=None):
"""
Compute Monte Carlo standard error.
Parameters:
- data: InferenceData object or trace
- var_names (list, optional): Variables to analyze
- method (str): Statistic to compute MCSE for ('mean', 'sd', 'quantile')
- dask_kwargs (dict, optional): Dask computation options
Returns:
- mcse_values: Monte Carlo standard errors
"""
# MCSE for posterior means
mcse_mean = pm.mcse(trace, method='mean')
# MCSE for posterior standard deviations
mcse_sd = pm.mcse(trace, method='sd')
# MCSE for quantiles
mcse_quantile = pm.mcse(trace, method='quantile')
print("Monte Carlo Standard Error (mean):")
for var, mcse_val in mcse_mean.items():
print(f" {var}: {mcse_val:.6f}")

def loo(data, var_name=None, reff=None, scale=None, pointwise=False,
dask_kwargs=None):
"""
Compute leave-one-out (LOO) cross-validation using Pareto smoothed importance sampling.
Parameters:
- data: InferenceData object with log_likelihood group
- var_name (str, optional): Variable name for likelihood
- reff (array, optional): Relative effective sample size
- scale (str): Scale for IC ('log', 'negative_log', 'deviance')
- pointwise (bool): Return pointwise LOO values
- dask_kwargs (dict, optional): Dask computation options
Returns:
- loo_result: LOO-CV results with ELPD, SE, and diagnostics
"""
# Compute log-likelihood for LOO
with pm.Model() as model:
# Model definition...
trace = pm.sample()
# Compute log-likelihood
log_likelihood = pm.compute_log_likelihood(trace, model=model)
# LOO cross-validation
loo_result = pm.loo(trace)
print(f"LOO ELPD: {loo_result.elpd_loo:.2f} ± {loo_result.se:.2f}")
print(f"LOO IC: {loo_result.loo:.2f}")
print(f"p_loo (effective parameters): {loo_result.p_loo:.2f}")
# Check Pareto k diagnostic
high_k = (loo_result.pareto_k > 0.7).sum()
if high_k > 0:
print(f"Warning: {high_k} observations have high Pareto k values")

def waic(data, var_name=None, scale=None, pointwise=False, dask_kwargs=None):
"""
Compute Watanabe-Akaike Information Criterion (WAIC).
Parameters:
- data: InferenceData object with log_likelihood group
- var_name (str, optional): Variable name for likelihood
- scale (str): Scale for IC ('log', 'negative_log', 'deviance')
- pointwise (bool): Return pointwise WAIC values
- dask_kwargs (dict, optional): Dask computation options
Returns:
- waic_result: WAIC results with ELPD, SE, and effective parameters
"""
# WAIC computation
waic_result = pm.waic(trace)
print(f"WAIC ELPD: {waic_result.elpd_waic:.2f} ± {waic_result.se:.2f}")
print(f"WAIC: {waic_result.waic:.2f}")
print(f"p_waic (effective parameters): {waic_result.p_waic:.2f}")
# Pointwise WAIC for outlier detection
waic_pointwise = pm.waic(trace, pointwise=True)
outlier_threshold = waic_pointwise.waic_i.mean() + 2 * waic_pointwise.waic_i.std()
outliers = waic_pointwise.waic_i > outlier_threshold
print(f"Potential outliers: {outliers.sum()} observations")

def compare(compare_dict, ic=None, method='stacking', b_samples=1000,
alpha=0.05, seed=None, scale=None):
"""
Compare models using information criteria.
Parameters:
- compare_dict (dict): Dictionary of {model_name: InferenceData}
- ic (str): Information criterion ('loo', 'waic')
- method (str): Comparison method ('stacking', 'BB-pseudo-BMA', 'pseudo-BMA')
- b_samples (int): Bootstrap samples for SE estimation
- alpha (float): Significance level for intervals
- seed (int): Random seed
- scale (str): Scale for IC reporting
Returns:
- comparison_df: DataFrame with model comparison results
"""
# Compare multiple models
models = {
'linear': linear_trace,
'quadratic': quadratic_trace,
'cubic': cubic_trace
}
comparison = pm.compare(models, ic='loo')
print("Model Comparison (LOO):")
print(comparison)
# Model weights from stacking
print("\nModel weights:")
for model, weight in zip(comparison.index, comparison.weight):
print(f" {model}: {weight:.3f}")
# Automatically select best model
best_model = comparison.index[0] # First row is best
print(f"\nBest model: {best_model}")

def compute_log_likelihood(idata=None, *, model=None, var_names=None,
extend_inferencedata=True, progressbar=True):
"""
Compute pointwise log-likelihood values.
Parameters:
- idata: InferenceData object with posterior samples
- model: PyMC model (default: current context)
- var_names (list, optional): Observed variables to compute likelihood for
- extend_inferencedata (bool): Add results to InferenceData
- progressbar (bool): Show progress bar
Returns:
- log_likelihood: Log-likelihood values for each observation and posterior sample
"""
with pm.Model() as model:
# Model with likelihood...
trace = pm.sample()
# Compute log-likelihood
log_lik = pm.compute_log_likelihood(trace, model=model)
# Access log-likelihood values
ll_values = trace.log_likelihood # Added to InferenceData
print(f"Log-likelihood shape: {ll_values['y_obs'].shape}") # (chains, draws, observations)
# Total log-likelihood per sample
total_ll = ll_values['y_obs'].sum(dim='y_obs_dim_0')
print(f"Total log-likelihood range: {total_ll.min():.2f} to {total_ll.max():.2f}")

def compute_log_prior(idata=None, *, model=None, var_names=None,
extend_inferencedata=True, progressbar=True):
"""
Compute log-prior density values.
Parameters:
- idata: InferenceData object with posterior samples
- model: PyMC model (default: current context)
- var_names (list, optional): Variables to compute log-prior for
- extend_inferencedata (bool): Add results to InferenceData
- progressbar (bool): Show progress bar
Returns:
- log_prior: Log-prior values for each variable and posterior sample
"""
# Compute log-prior
log_prior = pm.compute_log_prior(trace, model=model)
# Access log-prior values
prior_values = trace.log_prior
print("Log-prior components:")
for var_name in prior_values.data_vars:
values = prior_values[var_name]
print(f" {var_name}: mean = {values.mean():.3f}, std = {values.std():.3f}")
# Total log-prior per sample
total_prior = sum(prior_values[var].sum() for var in prior_values.data_vars)

# Summary statistics through ArviZ integration
summary_stats = pm.summary(trace, var_names=['alpha', 'beta'])
print("Posterior Summary:")
print(summary_stats)
# Custom summary with specific quantiles
custom_summary = pm.summary(trace,
stat_funcs={'median': np.median,
'mad': lambda x: np.median(np.abs(x - np.median(x)))},
extend=True)
# Round summary for reporting
rounded_summary = pm.summary(trace, round_to=3)

# Posterior predictive sampling for model checking
with pm.Model() as model:
# Model definition...
trace = pm.sample()
# Posterior predictive samples
post_pred = pm.sample_posterior_predictive(trace, predictions=True)
# Compare observed vs predicted
observed = post_pred.observed_data['y_obs']
predicted = post_pred.posterior_predictive['y_obs']
# T-test statistic for checking
def t_statistic(y):
return (y.mean() - observed.mean()) / (y.std() / np.sqrt(len(y)))
# Compute test statistic for observed and predicted
t_obs = t_statistic(observed.values)
t_pred = [t_statistic(pred_sample) for pred_sample in predicted.values.reshape(-1, len(observed))]
# Bayesian p-value
p_value = np.mean(np.abs(t_pred) >= np.abs(t_obs))
print(f"Bayesian p-value for mean difference: {p_value:.3f}")

# Access sampler statistics for energy diagnostics
sampler_stats = trace.get_sampler_stats()
# Energy statistics
energy = sampler_stats['energy']
energy_diff = np.diff(energy, axis=1) # Energy differences between steps
# Check for energy problems
mean_energy_diff = energy_diff.mean()
if abs(mean_energy_diff) > 0.2:
print(f"Warning: Large energy differences (mean = {mean_energy_diff:.3f})")
# Divergences
diverging = sampler_stats['diverging']
n_diverging = diverging.sum()
if n_diverging > 0:
print(f"Warning: {n_diverging} divergent transitions detected")
# Tree depth
treedepth = sampler_stats['treedepth']
max_treedepth = sampler_stats['max_treedepth']
saturated_trees = (treedepth >= max_treedepth).sum()
if saturated_trees > 0:
print(f"Warning: {saturated_trees} saturated trees (increase max_treedepth)")

def compute_split_rhat(trace, var_name):
"""Compute split R-hat manually for understanding."""
# Get samples for variable
samples = trace.posterior[var_name].values # Shape: (chains, draws, ...)
n_chains, n_draws = samples.shape[:2]
# Split each chain in half
first_half = samples[:, :n_draws//2]
second_half = samples[:, n_draws//2:]
# Combine split chains
split_samples = np.concatenate([first_half, second_half], axis=0)
# Between-chain variance
chain_means = split_samples.mean(axis=1)
overall_mean = chain_means.mean()
B = n_draws//2 * np.var(chain_means, ddof=1)
# Within-chain variance
chain_vars = split_samples.var(axis=1, ddof=1)
W = chain_vars.mean()
# Marginal posterior variance estimate
var_hat = (n_draws//2 - 1) / (n_draws//2) * W + B / (n_draws//2)
# R-hat
rhat = np.sqrt(var_hat / W)
return rhat
# Usage
manual_rhat = compute_split_rhat(trace, 'alpha')
print(f"Manual R-hat calculation: {manual_rhat:.4f}")

def rank_normalized_split_rhat(data, var_names=None):
"""
Compute rank-normalized split R-hat (more robust version).
Parameters:
- data: InferenceData object
- var_names (list, optional): Variables to analyze
Returns:
- rhat_rank: Rank-normalized R-hat values
"""
# More robust R-hat using rank normalization
rhat_rank = pm.rank_normalized_split_rhat(trace)
print("Rank-normalized R-hat:")
for var, rhat_val in rhat_rank.items():
print(f" {var}: {rhat_val:.4f}")
if rhat_val > 1.01:
print(f" Warning: {var} may not have converged")

def full_convergence_check(trace, model_name="Model"):
"""Comprehensive convergence assessment."""
print(f"=== Convergence Diagnostics for {model_name} ===")
# R-hat
rhat_vals = pm.rhat(trace)
max_rhat = max(rhat_vals.values())
print(f"Max R-hat: {max_rhat:.4f}")
# Effective sample size
ess_bulk = pm.ess(trace, method='bulk')
ess_tail = pm.ess(trace, method='tail')
min_ess_bulk = min(ess_bulk.values())
min_ess_tail = min(ess_tail.values())
print(f"Min ESS (bulk): {min_ess_bulk:.0f}")
print(f"Min ESS (tail): {min_ess_tail:.0f}")
# Sampler diagnostics
n_diverging = trace.get_sampler_stats('diverging').sum()
print(f"Diverging transitions: {n_diverging}")
# Overall assessment
converged = (max_rhat < 1.01 and min_ess_bulk > 400 and
min_ess_tail > 400 and n_diverging == 0)
print(f"Overall convergence: {'✓ PASS' if converged else '✗ FAIL'}")
return converged
# Usage
convergence_ok = full_convergence_check(trace, "Regression Model")

def assess_model_quality(trace, observed_data, model):
"""Comprehensive model quality assessment."""
print("=== Model Quality Assessment ===")
# Information criteria
loo_result = pm.loo(trace)
waic_result = pm.waic(trace)
print(f"LOO ELPD: {loo_result.elpd_loo:.2f} ± {loo_result.se:.2f}")
print(f"WAIC ELPD: {waic_result.elpd_waic:.2f} ± {waic_result.se:.2f}")
# Check for high Pareto k values
high_k = (loo_result.pareto_k > 0.7).sum()
if high_k > 0:
print(f"Warning: {high_k} observations have unreliable LOO estimates")
# Posterior predictive checks
post_pred = pm.sample_posterior_predictive(trace, model=model)
# Simple residual check
y_obs = observed_data
y_pred_mean = post_pred.posterior_predictive['y_obs'].mean(dim=['chain', 'draw'])
residuals = y_obs - y_pred_mean
print(f"Mean absolute residual: {np.abs(residuals).mean():.3f}")
print(f"Residual std: {residuals.std():.3f}")
return loo_result, waic_result, residuals
# Usage
loo, waic, residuals = assess_model_quality(trace, y_data, model)

PyMC's statistics and diagnostics framework, built on ArviZ integration, provides essential tools for validating Bayesian models and ensuring reliable inference results.
Install with Tessl CLI
npx tessl i tessl/pypi-pymc