# PyMC Statistics and Diagnostics

PyMC provides comprehensive statistical functions and convergence diagnostics for Bayesian analysis, primarily through integration with ArviZ. The library offers tools for model validation, convergence assessment, and posterior analysis.

## Convergence Diagnostics

PyMC exposes key diagnostic functions from ArviZ for assessing MCMC convergence:

### R-hat Statistic
```python { .api }
def rhat(data, var_names=None, method='rank', dask_kwargs=None):
    """
    Compute R-hat convergence diagnostic (Gelman-Rubin statistic).

    Parameters:
    - data: InferenceData object or trace
    - var_names (list, optional): Variables to analyze
    - method (str): Method for computation ('rank', 'split', 'folded')
    - dask_kwargs (dict, optional): Dask computation options

    Returns:
    - rhat_values: R-hat statistics for each variable
    """

import pymc as pm

# Compute R-hat for all variables
with pm.Model() as model:
    # Model definition and sampling...
    trace = pm.sample()

# pm.rhat returns an xarray.Dataset with one entry per variable;
# float() assumes scalar model variables (vector variables yield
# per-element R-hat values).
rhat_stats = pm.rhat(trace)
print("R-hat diagnostics:")
for var, rhat_val in rhat_stats.items():
    print(f"  {var}: {float(rhat_val):.4f}")

# R-hat for specific variables only
rhat_subset = pm.rhat(trace, var_names=['alpha', 'beta'])

# Check convergence (R-hat should be < 1.01)
converged = all(float(rhat_val) < 1.01 for rhat_val in rhat_stats.values())
```
### Effective Sample Size
```python { .api }
def effective_sample_size(data, var_names=None, method='bulk',
                          relative=False, dask_kwargs=None):
    """
    Compute effective sample size (ESS).

    Parameters:
    - data: InferenceData object or trace
    - var_names (list, optional): Variables to analyze
    - method (str): ESS method ('bulk', 'tail', 'quantile', 'mean', 'sd')
    - relative (bool): Return relative ESS (ESS/N)
    - dask_kwargs (dict, optional): Dask computation options

    Returns:
    - ess_values: Effective sample size for each variable
    """

# NOTE: this function is exposed as ``pm.ess`` in PyMC.

# Bulk ESS (measures sampling efficiency in the central posterior)
bulk_ess = pm.ess(trace, method='bulk')

# Tail ESS (measures sampling efficiency in the posterior tails)
tail_ess = pm.ess(trace, method='tail')

# Relative ESS (as a fraction of the total number of samples)
rel_ess = pm.ess(trace, relative=True)

print("Effective Sample Size (bulk):")
for var, ess_val in bulk_ess.items():
    print(f"  {var}: {float(ess_val):.0f}")

# Check adequacy (ESS should be > 400 for reliable inference)
adequate_ess = all(float(ess_val) > 400 for ess_val in bulk_ess.values())
```
### Monte Carlo Standard Error
```python { .api }
def mcse(data, var_names=None, method='mean', dask_kwargs=None):
    """
    Compute Monte Carlo standard error.

    Parameters:
    - data: InferenceData object or trace
    - var_names (list, optional): Variables to analyze
    - method (str): Statistic to compute MCSE for ('mean', 'sd', 'quantile')
    - dask_kwargs (dict, optional): Dask computation options

    Returns:
    - mcse_values: Monte Carlo standard errors
    """

# MCSE for posterior means
mcse_mean = pm.mcse(trace, method='mean')

# MCSE for posterior standard deviations
mcse_sd = pm.mcse(trace, method='sd')

# MCSE for quantiles (ArviZ requires ``prob=`` for this method)
mcse_quantile = pm.mcse(trace, method='quantile', prob=0.95)

print("Monte Carlo Standard Error (mean):")
for var, mcse_val in mcse_mean.items():
    print(f"  {var}: {float(mcse_val):.6f}")
```
## Model Comparison

### Leave-One-Out Cross-Validation
```python { .api }
def loo(data, var_name=None, reff=None, scale=None, pointwise=False,
        dask_kwargs=None):
    """
    Compute leave-one-out (LOO) cross-validation using Pareto smoothed
    importance sampling.

    Parameters:
    - data: InferenceData object with log_likelihood group
    - var_name (str, optional): Variable name for likelihood
    - reff (array, optional): Relative effective sample size
    - scale (str): Scale for IC ('log', 'negative_log', 'deviance')
    - pointwise (bool): Return pointwise LOO values
    - dask_kwargs (dict, optional): Dask computation options

    Returns:
    - loo_result: LOO-CV results with ELPD, SE, and diagnostics
    """

# Compute log-likelihood for LOO
with pm.Model() as model:
    # Model definition...
    trace = pm.sample()

    # Compute log-likelihood (required for LOO/WAIC)
    log_likelihood = pm.compute_log_likelihood(trace, model=model)

# LOO cross-validation; pointwise=True also returns the per-observation
# Pareto-k diagnostic used below.
loo_result = pm.loo(trace, pointwise=True)
print(f"LOO ELPD: {loo_result.elpd_loo:.2f} ± {loo_result.se:.2f}")
# On the deviance scale: LOO-IC = -2 * ELPD
print(f"LOO IC: {-2 * loo_result.elpd_loo:.2f}")
print(f"p_loo (effective parameters): {loo_result.p_loo:.2f}")

# Check Pareto k diagnostic (k > 0.7 flags unreliable importance weights)
high_k = int((loo_result.pareto_k > 0.7).sum())
if high_k > 0:
    print(f"Warning: {high_k} observations have high Pareto k values")
```
### Watanabe-Akaike Information Criterion
```python { .api }
def waic(data, var_name=None, scale=None, pointwise=False, dask_kwargs=None):
    """
    Compute Watanabe-Akaike Information Criterion (WAIC).

    Parameters:
    - data: InferenceData object with log_likelihood group
    - var_name (str, optional): Variable name for likelihood
    - scale (str): Scale for IC ('log', 'negative_log', 'deviance')
    - pointwise (bool): Return pointwise WAIC values
    - dask_kwargs (dict, optional): Dask computation options

    Returns:
    - waic_result: WAIC results with ELPD, SE, and effective parameters
    """

# WAIC computation
waic_result = pm.waic(trace)
print(f"WAIC ELPD: {waic_result.elpd_waic:.2f} ± {waic_result.se:.2f}")
# On the deviance scale: WAIC = -2 * ELPD
print(f"WAIC: {-2 * waic_result.elpd_waic:.2f}")
print(f"p_waic (effective parameters): {waic_result.p_waic:.2f}")

# Pointwise WAIC for outlier detection
waic_pointwise = pm.waic(trace, pointwise=True)
outlier_threshold = waic_pointwise.waic_i.mean() + 2 * waic_pointwise.waic_i.std()
outliers = waic_pointwise.waic_i > outlier_threshold
print(f"Potential outliers: {int(outliers.sum())} observations")
```
### Model Comparison Framework
```python { .api }
def compare(compare_dict, ic=None, method='stacking', b_samples=1000,
            alpha=0.05, seed=None, scale=None):
    """
    Compare models using information criteria.

    Parameters:
    - compare_dict (dict): Dictionary of {model_name: InferenceData}
    - ic (str): Information criterion ('loo', 'waic')
    - method (str): Comparison method ('stacking', 'BB-pseudo-BMA', 'pseudo-BMA')
    - b_samples (int): Bootstrap samples for SE estimation
    - alpha (float): Significance level for intervals
    - seed (int): Random seed
    - scale (str): Scale for IC reporting

    Returns:
    - comparison_df: DataFrame with model comparison results
    """

# Compare multiple models; each value is an InferenceData that must
# contain a log_likelihood group.
models = {
    'linear': linear_trace,
    'quadratic': quadratic_trace,
    'cubic': cubic_trace,
}

comparison = pm.compare(models, ic='loo')
print("Model Comparison (LOO):")
print(comparison)

# Model weights from stacking
print("\nModel weights:")
for model, weight in zip(comparison.index, comparison.weight):
    print(f"  {model}: {weight:.3f}")

# The result is sorted by rank, so the first row is the best model
best_model = comparison.index[0]
print(f"\nBest model: {best_model}")
```
## Log-Likelihood and Prior Computation

### Log-Likelihood Calculation
```python { .api }
def compute_log_likelihood(idata=None, *, model=None, var_names=None,
                           extend_inferencedata=True, progressbar=True):
    """
    Compute pointwise log-likelihood values.

    Parameters:
    - idata: InferenceData object with posterior samples
    - model: PyMC model (default: current context)
    - var_names (list, optional): Observed variables to compute likelihood for
    - extend_inferencedata (bool): Add results to InferenceData
    - progressbar (bool): Show progress bar

    Returns:
    - log_likelihood: Log-likelihood values for each observation and posterior sample
    """

with pm.Model() as model:
    # Model with likelihood...
    trace = pm.sample()

    # Compute log-likelihood (added to ``trace`` because
    # extend_inferencedata defaults to True)
    log_lik = pm.compute_log_likelihood(trace, model=model)

# Access log-likelihood values
ll_values = trace.log_likelihood
print(f"Log-likelihood shape: {ll_values['y_obs'].shape}")  # (chains, draws, observations)

# Total log-likelihood per posterior sample (sum over observations)
total_ll = ll_values['y_obs'].sum(dim='y_obs_dim_0')
print(f"Total log-likelihood range: {float(total_ll.min()):.2f} to {float(total_ll.max()):.2f}")
```
### Log-Prior Calculation
```python { .api }
def compute_log_prior(idata=None, *, model=None, var_names=None,
                      extend_inferencedata=True, progressbar=True):
    """
    Compute log-prior density values.

    Parameters:
    - idata: InferenceData object with posterior samples
    - model: PyMC model (default: current context)
    - var_names (list, optional): Variables to compute log-prior for
    - extend_inferencedata (bool): Add results to InferenceData
    - progressbar (bool): Show progress bar

    Returns:
    - log_prior: Log-prior values for each variable and posterior sample
    """

# Compute log-prior (added to the InferenceData in place when
# extend_inferencedata=True)
log_prior = pm.compute_log_prior(trace, model=model)

# Access log-prior values
prior_values = trace.log_prior
print("Log-prior components:")
for var_name in prior_values.data_vars:
    values = prior_values[var_name]
    print(f"  {var_name}: mean = {float(values.mean()):.3f}, std = {float(values.std()):.3f}")

# Total log-prior summed over all variables and samples
total_prior = sum(float(prior_values[var].sum()) for var in prior_values.data_vars)
```
## Posterior Analysis Utilities

### Summary Statistics
```python { .api }
# Summary statistics through ArviZ integration
summary_stats = pm.summary(trace, var_names=['alpha', 'beta'])
print("Posterior Summary:")
print(summary_stats)

# Custom summary: extra statistics are appended to the defaults
# because extend=True.
custom_summary = pm.summary(
    trace,
    stat_funcs={'median': np.median,
                'mad': lambda x: np.median(np.abs(x - np.median(x)))},
    extend=True,
)

# Round summary values for reporting
rounded_summary = pm.summary(trace, round_to=3)
```
### Posterior Predictive Checks
```python { .api }
# Posterior predictive sampling for model checking
with pm.Model() as model:
    # Model definition...
    trace = pm.sample()

    # Posterior predictive samples.  The default (predictions=False)
    # stores the draws in the ``posterior_predictive`` group used below;
    # predictions=True would put them in a separate ``predictions`` group.
    post_pred = pm.sample_posterior_predictive(trace)

# Compare observed vs predicted
observed = post_pred.observed_data['y_obs']
predicted = post_pred.posterior_predictive['y_obs']

# Test statistic for the check (standardised mean difference)
def t_statistic(y):
    return (y.mean() - observed.mean()) / (y.std() / np.sqrt(len(y)))

# Compute the statistic for the observed data and for every predictive draw
t_obs = t_statistic(observed.values)
t_pred = [t_statistic(pred_sample)
          for pred_sample in predicted.values.reshape(-1, len(observed))]

# Bayesian p-value: how often the predictive statistic is at least as
# extreme as the observed one
p_value = np.mean(np.abs(t_pred) >= np.abs(t_obs))
print(f"Bayesian p-value for mean difference: {p_value:.3f}")
```
## Advanced Diagnostics

### Energy Diagnostics
```python { .api }
# Access sampler statistics for energy diagnostics.  With an
# InferenceData trace these live in the ``sample_stats`` group
# (``get_sampler_stats`` only exists on legacy MultiTrace objects).
sampler_stats = trace.sample_stats

# Energy statistics
energy = sampler_stats['energy'].values      # shape: (chains, draws)
energy_diff = np.diff(energy, axis=1)        # energy change between successive steps

# Check for energy problems
mean_energy_diff = energy_diff.mean()
if abs(mean_energy_diff) > 0.2:
    print(f"Warning: Large energy differences (mean = {mean_energy_diff:.3f})")

# Divergences
diverging = sampler_stats['diverging'].values
n_diverging = int(diverging.sum())
if n_diverging > 0:
    print(f"Warning: {n_diverging} divergent transitions detected")

# Tree depth: PyMC records the per-step depth and a flag for hitting the cap.
# NOTE(review): stat names vary between samplers/versions — confirm against
# ``trace.sample_stats.data_vars``.
treedepth = sampler_stats['tree_depth'].values
saturated_trees = int(sampler_stats['reached_max_treedepth'].values.sum())
if saturated_trees > 0:
    print(f"Warning: {saturated_trees} saturated trees (increase max_treedepth)")
```
### Custom Diagnostics
```python { .api }
def compute_split_rhat(trace, var_name):
    """Compute split R-hat manually (for understanding the diagnostic).

    Parameters:
    - trace: object whose ``posterior`` maps variable names to array-likes
      with ``.values`` of shape (chains, draws, ...)
    - var_name (str): Variable to analyze

    Returns:
    - rhat (float): split R-hat estimate
    """
    # Get samples for the variable; shape: (chains, draws, ...)
    samples = trace.posterior[var_name].values
    n_chains, n_draws = samples.shape[:2]
    half = n_draws // 2

    # Split each chain in half and treat the halves as separate chains.
    # Truncating the second half to ``half`` draws keeps the split chains
    # equal-length even when n_draws is odd.
    first_half = samples[:, :half]
    second_half = samples[:, half:2 * half]
    split_samples = np.concatenate([first_half, second_half], axis=0)

    # Between-chain variance
    chain_means = split_samples.mean(axis=1)
    B = half * np.var(chain_means, ddof=1)

    # Within-chain variance
    chain_vars = split_samples.var(axis=1, ddof=1)
    W = chain_vars.mean()

    # Marginal posterior variance estimate (weighted mix of W and B)
    var_hat = (half - 1) / half * W + B / half

    # R-hat
    return np.sqrt(var_hat / W)

# Usage:
#   manual_rhat = compute_split_rhat(trace, 'alpha')
#   print(f"Manual R-hat calculation: {manual_rhat:.4f}")
```
### Rank Normalization Diagnostics
```python { .api }
def rank_normalized_split_rhat(data, var_names=None):
    """
    Compute rank-normalized split R-hat (more robust version).

    Parameters:
    - data: InferenceData object
    - var_names (list, optional): Variables to analyze

    Returns:
    - rhat_rank: Rank-normalized R-hat values
    """

# PyMC/ArviZ expose this through ``pm.rhat``: method='rank' (the default)
# computes exactly the rank-normalized split R-hat.
rhat_rank = pm.rhat(trace, method='rank')
print("Rank-normalized R-hat:")
for var, rhat_val in rhat_rank.items():
    val = float(rhat_val)
    print(f"  {var}: {val:.4f}")
    if val > 1.01:
        print(f"    Warning: {var} may not have converged")
```
## Diagnostic Workflows

### Comprehensive Convergence Check
```python { .api }
def full_convergence_check(trace, model_name="Model"):
    """Comprehensive convergence assessment.

    Checks R-hat, bulk/tail ESS and divergence counts against the
    standard thresholds (R-hat < 1.01, ESS > 400, zero divergences)
    and returns True when all checks pass.
    """
    print(f"=== Convergence Diagnostics for {model_name} ===")

    # R-hat: pm.rhat returns an xarray.Dataset, so reduce each
    # variable to a scalar before taking the overall maximum.
    rhat_vals = pm.rhat(trace)
    max_rhat = max(float(v.max()) for v in rhat_vals.values())
    print(f"Max R-hat: {max_rhat:.4f}")

    # Effective sample size (bulk and tail)
    ess_bulk = pm.ess(trace, method='bulk')
    ess_tail = pm.ess(trace, method='tail')
    min_ess_bulk = min(float(v.min()) for v in ess_bulk.values())
    min_ess_tail = min(float(v.min()) for v in ess_tail.values())
    print(f"Min ESS (bulk): {min_ess_bulk:.0f}")
    print(f"Min ESS (tail): {min_ess_tail:.0f}")

    # Sampler diagnostics: divergences live in the ``sample_stats``
    # group of an InferenceData (get_sampler_stats is MultiTrace-only).
    n_diverging = int(trace.sample_stats['diverging'].sum())
    print(f"Diverging transitions: {n_diverging}")

    # Overall assessment
    converged = (max_rhat < 1.01 and min_ess_bulk > 400 and
                 min_ess_tail > 400 and n_diverging == 0)
    print(f"Overall convergence: {'✓ PASS' if converged else '✗ FAIL'}")
    return converged

# Usage:
#   convergence_ok = full_convergence_check(trace, "Regression Model")
```
### Model Quality Assessment
```python { .api }
def assess_model_quality(trace, observed_data, model):
    """Comprehensive model quality assessment.

    Reports LOO/WAIC, flags unreliable Pareto-k values and runs a
    simple posterior-predictive residual check.

    Returns:
    - (loo_result, waic_result, residuals)
    """
    print("=== Model Quality Assessment ===")

    # Information criteria; pointwise=True exposes per-observation
    # Pareto-k diagnostics on the LOO result.
    loo_result = pm.loo(trace, pointwise=True)
    waic_result = pm.waic(trace)
    print(f"LOO ELPD: {loo_result.elpd_loo:.2f} ± {loo_result.se:.2f}")
    print(f"WAIC ELPD: {waic_result.elpd_waic:.2f} ± {waic_result.se:.2f}")

    # High Pareto k (> 0.7) marks observations where LOO is unreliable
    high_k = int((loo_result.pareto_k > 0.7).sum())
    if high_k > 0:
        print(f"Warning: {high_k} observations have unreliable LOO estimates")

    # Posterior predictive checks
    post_pred = pm.sample_posterior_predictive(trace, model=model)

    # Simple residual check against the posterior-predictive mean
    y_obs = observed_data
    y_pred_mean = post_pred.posterior_predictive['y_obs'].mean(dim=['chain', 'draw'])
    residuals = y_obs - y_pred_mean
    print(f"Mean absolute residual: {np.abs(residuals).mean():.3f}")
    print(f"Residual std: {residuals.std():.3f}")

    return loo_result, waic_result, residuals

# Usage (note: pick names that don't shadow ``pm.loo``/``pm.waic``):
#   loo_res, waic_res, residuals = assess_model_quality(trace, y_data, model)
```
PyMC's statistics and diagnostics framework, built on ArviZ integration, provides essential tools for validating Bayesian models and ensuring reliable inference results.