# Statistics and Plotting (ArviZ Integration)

PyMC3 integrates tightly with ArviZ for comprehensive Bayesian analysis, model diagnostics, and publication-quality visualizations. The stats and plots modules delegate to ArviZ while providing PyMC3-specific functionality and convenient aliases for common workflows.

## Capabilities

### Convergence Diagnostics

Functions for assessing MCMC convergence and sample quality through `pymc3.stats.*`.

```python { .api }
def r_hat(trace, var_names=None, method='rank'):
    """
    Compute R-hat convergence diagnostic.

    Measures between-chain and within-chain variance to assess
    convergence across multiple MCMC chains. Values close to 1.0
    indicate good convergence.

    Parameters:
    - trace: InferenceData or MultiTrace, posterior samples
    - var_names: list, variables to analyze (all if None)
    - method: str, computation method ('rank', 'split', 'folded')

    Returns:
    - dict or array: R-hat values by variable

    Interpretation:
    - R_hat < 1.01: Excellent convergence
    - R_hat < 1.1: Good convergence
    - R_hat > 1.1: Poor convergence, need more samples
    """

def ess(trace, var_names=None, method='bulk'):
    """
    Compute effective sample size.

    Estimates the number of independent samples, accounting for
    autocorrelation in MCMC chains. Higher values indicate better
    mixing and more efficient sampling.

    Parameters:
    - trace: InferenceData or MultiTrace, posterior samples
    - var_names: list, variables to analyze
    - method: str, ESS type ('bulk', 'tail', 'mean', 'sd', 'quantile')

    Returns:
    - dict or array: effective sample sizes

    Guidelines:
    - ESS > 400: Generally sufficient for posterior inference
    - ESS > 100: Minimum for reasonable estimates
    - ESS < 100: Increase sampling or improve model
    """

def mcse(trace, var_names=None, method='mean', prob=None):
    """
    Monte Carlo standard error of estimates.

    Measures uncertainty in posterior estimates due to finite
    sampling, helping determine if more samples are needed.

    Parameters:
    - trace: InferenceData or MultiTrace, posterior samples
    - var_names: list, variables to analyze
    - method: str, estimate type ('mean', 'sd', 'quantile')
    - prob: float, probability for quantile MCSE

    Returns:
    - dict: MCSE values by variable
    """

def geweke(trace, var_names=None, first=0.1, last=0.5, intervals=20):
    """
    Geweke convergence diagnostic.

    Compares means from early and late portions of chains to
    assess within-chain convergence and stationarity.

    Parameters:
    - trace: InferenceData or MultiTrace, posterior samples
    - var_names: list, variables to test
    - first: float, fraction for early portion
    - last: float, fraction for late portion
    - intervals: int, number of test intervals

    Returns:
    - dict: Geweke statistics by variable
    """
```
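To make the between-chain versus within-chain comparison concrete, here is a minimal pure-Python sketch of the classic split R-hat. It is an illustration only: the helper name `split_rhat` is hypothetical, the chains are synthetic, and ArviZ's default `'rank'` method additionally rank-normalizes the draws, which this sketch does not do.

```python
import random
import statistics


def split_rhat(chains):
    """Classic split R-hat for a list of equal-length chains (lists of floats)."""
    half = len(chains[0]) // 2
    # Split every chain in half so within-chain drift shows up as between-chain variance
    parts = [c[:half] for c in chains] + [c[half:2 * half] for c in chains]
    n = half
    part_means = [statistics.fmean(p) for p in parts]
    within = statistics.fmean(statistics.variance(p) for p in parts)  # W
    between = n * statistics.variance(part_means)                     # B
    var_hat = (n - 1) / n * within + between / n
    return (var_hat / within) ** 0.5


# Synthetic chains: four well-mixed chains, then the same with one chain shifted
rng = random.Random(0)
mixed = [[rng.gauss(0.0, 1.0) for _ in range(1000)] for _ in range(4)]
stuck = mixed[:3] + [[x + 3.0 for x in mixed[3]]]

print(f"well-mixed chains: {split_rhat(mixed):.3f}")
print(f"one shifted chain: {split_rhat(stuck):.3f}")
```

The shifted chain inflates the between-chain term, which pushes the ratio well above the 1.1 threshold listed in the interpretation guide.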

### Model Comparison

Information criteria and cross-validation for Bayesian model selection.

```python { .api }
def compare(models, ic='waic', method='stacking', b_samples=1000,
            alpha=1, seed=None, round_to=2):
    """
    Compare multiple models using information criteria.

    Ranks models by predictive performance using WAIC or LOO-CV,
    with model weights and standard errors.

    Parameters:
    - models: dict, mapping model names to InferenceData objects
    - ic: str, information criterion ('waic', 'loo')
    - method: str, weighting method ('stacking', 'BB-pseudo-BMA', 'pseudo-BMA')
    - b_samples: int, number of Bayesian bootstrap samples ('BB-pseudo-BMA' only)
    - alpha: float, shape parameter of the Dirichlet distribution used for
      the Bayesian bootstrap ('BB-pseudo-BMA' only)
    - seed: int, random seed for reproducibility
    - round_to: int, decimal places for results

    Returns:
    - DataFrame: model comparison results with ranks and weights

    Columns:
    - rank: model ranking (0 = best)
    - elpd_*: expected log pointwise predictive density
    - p_*: effective number of parameters
    - d_*: difference from best model
    - weight: model averaging weights
    - se: standard error of the information criterion estimate
    - dse: standard error of the difference from the best model
    """

def waic(trace, model=None, pointwise=False, scale='deviance'):
    """
    Watanabe-Akaike Information Criterion.

    Estimates out-of-sample predictive performance using
    within-sample log-likelihood with penalty for overfitting.

    Parameters:
    - trace: InferenceData or MultiTrace, posterior samples
    - model: Model, model context (current if None)
    - pointwise: bool, return pointwise WAIC values
    - scale: str, return scale ('deviance' or 'log')

    Returns:
    - ELPDData: WAIC results with components and diagnostics

    Components:
    - elpd_waic: expected log pointwise predictive density
    - p_waic: effective number of parameters
    - waic: -2 * elpd_waic on the deviance scale (lower is better)
    - se: standard error of WAIC
    """

def loo(trace, model=None, pointwise=False, reff=None, scale='deviance'):
    """
    Pareto Smoothed Importance Sampling Leave-One-Out Cross-Validation.

    Estimates out-of-sample performance using leave-one-out
    cross-validation approximated by importance sampling.

    Parameters:
    - trace: InferenceData or MultiTrace, posterior samples
    - model: Model, model context
    - pointwise: bool, return pointwise LOO values
    - reff: array, relative effective sample sizes
    - scale: str, return scale ('deviance' or 'log')

    Returns:
    - ELPDData: LOO-CV results with Pareto diagnostics

    Diagnostics:
    - Pareto k < 0.5: Good approximation
    - Pareto k < 0.7: Okay approximation
    - Pareto k > 0.7: Poor approximation, use exact CV
    """

def loo_pit(idata, y=None, y_hat=None, log_weights=None):
    """
    Leave-one-out probability integral transform.

    Calibration check for posterior predictive distributions
    using LOO-PIT values that should be uniform if well-calibrated.

    Parameters:
    - idata: InferenceData, posterior and predictions
    - y: array, observed data (from idata if None)
    - y_hat: array, posterior predictive samples
    - log_weights: array, importance sampling weights

    Returns:
    - array: LOO-PIT values for calibration assessment
    """
```
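The WAIC components listed above can be computed by hand from a matrix of pointwise log-likelihood values. The following is a minimal sketch under stated assumptions: `waic_from_loglik` is a hypothetical helper, the input is a plain list of lists indexed as `log_lik[draw][observation]`, and the numbers are made up. It omits the standard error and the warnings that a full implementation reports.

```python
import math
import statistics


def waic_from_loglik(log_lik):
    """log_lik[s][i]: log-likelihood of observation i under posterior draw s."""
    n_draws = len(log_lik)
    n_obs = len(log_lik[0])
    lppd = 0.0
    p_waic = 0.0
    for i in range(n_obs):
        column = [log_lik[s][i] for s in range(n_draws)]
        # Log of the mean likelihood, computed stably with log-sum-exp
        top = max(column)
        lppd += top + math.log(sum(math.exp(v - top) for v in column) / n_draws)
        # Penalty: posterior variance of the pointwise log-likelihood
        p_waic += statistics.variance(column)
    elpd_waic = lppd - p_waic
    return {"elpd_waic": elpd_waic, "p_waic": p_waic, "waic": -2.0 * elpd_waic}


# Hypothetical log-likelihood values: 3 posterior draws, 2 observations
example = [[-1.0, -2.0], [-1.2, -1.8], [-0.9, -2.1]]
print(waic_from_loglik(example))
```

The `waic` entry here is on the deviance scale, matching the `-2 * elpd_waic` convention in the component list.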

### Summary Statistics

Posterior summary and descriptive statistics.

```python { .api }
def summary(trace, var_names=None, stat_funcs=None, extend=True,
            credible_interval=0.94, round_to=2, kind='stats'):
    """
    Comprehensive posterior summary statistics.

    Provides means, standard deviations, credible intervals,
    and convergence diagnostics for all model parameters.

    Parameters:
    - trace: InferenceData or MultiTrace, posterior samples
    - var_names: list, variables to summarize (all if None)
    - stat_funcs: dict, custom summary functions
    - extend: bool, include convergence diagnostics
    - credible_interval: float, credible interval width
    - round_to: int, decimal places
    - kind: str, summary type ('stats', 'diagnostics')

    Returns:
    - DataFrame: comprehensive parameter summary

    Columns:
    - mean: posterior mean
    - sd: posterior standard deviation
    - hdi_3%/hdi_97%: highest density interval bounds
    - mcse_mean: MCSE of mean
    - mcse_sd: MCSE of standard deviation
    - ess_bulk/ess_tail: effective sample sizes
    - r_hat: R-hat convergence diagnostic
    """

def describe(trace, var_names=None, include_ci=True, ci_prob=0.94):
    """
    Descriptive statistics for posterior distributions.

    Parameters:
    - trace: InferenceData or MultiTrace, posterior samples
    - var_names: list, variables to describe
    - include_ci: bool, include credible intervals
    - ci_prob: float, credible interval probability

    Returns:
    - DataFrame: descriptive statistics
    """

def quantiles(x, qlist=(0.025, 0.25, 0.5, 0.75, 0.975)):
    """
    Compute quantiles of posterior samples.

    Parameters:
    - x: array, samples
    - qlist: tuple, quantile probabilities

    Returns:
    - dict: quantile values
    """

def hdi(x, credible_interval=0.94, circular=False):
    """
    Highest Density Interval (HDI).

    Computes the shortest interval containing specified
    probability mass of the posterior distribution.

    Parameters:
    - x: array, posterior samples
    - credible_interval: float, interval probability
    - circular: bool, circular data (angles)

    Returns:
    - array: [lower_bound, upper_bound]
    """
```
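The "shortest interval" definition of the HDI translates directly into a sliding-window search over sorted samples. This is a minimal sketch for the unimodal, non-circular case; `hdi_from_samples` is a hypothetical name and the draws are made-up values chosen so the result is easy to check by eye.

```python
import math


def hdi_from_samples(samples, prob=0.94):
    """Shortest interval spanning roughly `prob` of the samples (unimodal case)."""
    ordered = sorted(samples)
    n = len(ordered)
    span = int(math.floor(prob * n))  # index distance covered by the interval
    n_candidates = n - span
    if span < 1 or n_candidates < 1:
        raise ValueError("need more samples for this probability")
    # Slide a window of fixed index width and keep the narrowest one
    start = min(range(n_candidates), key=lambda i: ordered[i + span] - ordered[i])
    return ordered[start], ordered[start + span]


draws = [0, 10, 11, 12, 13, 14, 15, 30, 40, 50]
print(hdi_from_samples(draws, prob=0.5))  # (10, 15)
```

Unlike an equal-tailed interval, the window settles on the dense cluster between 10 and 15 and ignores the sparse tail.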

### Posterior Analysis

Advanced posterior analysis and derived quantities.

```python { .api }
def autocorr(trace, var_names=None, max_lag=100):
    """
    Autocorrelation function of MCMC chains.

    Measures correlation between samples at different lags
    to assess mixing and effective sample size.

    Parameters:
    - trace: InferenceData or MultiTrace, samples
    - var_names: list, variables to analyze
    - max_lag: int, maximum lag to compute

    Returns:
    - dict: autocorrelation functions by variable
    """

def make_ufunc(func, nin=1, nout=1, **kwargs):
    """
    Create universal function for posterior analysis.

    Converts regular functions into universal functions
    that work efficiently on posterior sample arrays.

    Parameters:
    - func: callable, function to convert
    - nin: int, number of inputs
    - nout: int, number of outputs
    - kwargs: additional ufunc arguments

    Returns:
    - ufunc: universal function
    """

def from_dict(posterior_dict, coords=None, dims=None):
    """
    Create InferenceData from dictionary of arrays.

    Parameters:
    - posterior_dict: dict, posterior samples by variable
    - coords: dict, coordinate values
    - dims: dict, dimension names by variable

    Returns:
    - InferenceData: formatted inference data
    """
```
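The link between autocorrelation and effective sample size can be sketched for a single chain as follows. This is a deliberately crude estimator, assuming truncation at the first non-positive autocorrelation; the helper names are hypothetical, the chains are synthetic, and ArviZ's bulk and tail ESS use a more careful multi-chain, rank-normalized procedure.

```python
import random
import statistics


def autocorrelation(chain, lag):
    """Sample autocorrelation of one chain at a given lag."""
    n = len(chain)
    mean = statistics.fmean(chain)
    var = sum((x - mean) ** 2 for x in chain) / n
    cov = sum((chain[t] - mean) * (chain[t + lag] - mean) for t in range(n - lag)) / n
    return cov / var


def ess_single_chain(chain, max_lag=None):
    """Crude ESS: N / (1 + 2 * sum of autocorrelations before the first non-positive one)."""
    n = len(chain)
    max_lag = max_lag or n // 2
    rho_sum = 0.0
    for lag in range(1, max_lag):
        rho = autocorrelation(chain, lag)
        if rho <= 0.0:  # stop once the estimate is dominated by noise
            break
        rho_sum += rho
    return n / (1.0 + 2.0 * rho_sum)


# Synthetic chains: independent draws versus a strongly autocorrelated AR(1) process
rng = random.Random(1)
iid = [rng.gauss(0.0, 1.0) for _ in range(2000)]
ar1 = [0.0]
for _ in range(1999):
    ar1.append(0.9 * ar1[-1] + rng.gauss(0.0, 1.0))

print(f"ESS, independent draws: {ess_single_chain(iid):.0f}")
print(f"ESS, AR(1) with 0.9:    {ess_single_chain(ar1):.0f}")
```

Both chains hold 2000 draws, but the autocorrelated one carries far fewer effectively independent samples.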

### Plotting Functions

Comprehensive visualization capabilities through ArviZ integration via `pymc3.plots.*`.

```python { .api }
def plot_trace(trace, var_names=None, coords=None, divergences='auto',
               figsize=None, rug=False, lines=None, compact=True,
               combined=False, legend=False, plot_kwargs=None,
               fill_kwargs=None, rug_kwargs=None, **kwargs):
    """
    Trace plots showing MCMC sampling paths and marginal distributions.

    Essential diagnostic plot combining time series of samples
    with marginal posterior distributions for visual convergence assessment.

    Parameters:
    - trace: InferenceData, posterior samples
    - var_names: list, variables to plot
    - coords: dict, coordinate slices for multidimensional variables
    - divergences: str or bool, highlight divergent transitions
    - figsize: tuple, figure size
    - rug: bool, add rug plot to marginals
    - lines: dict, reference lines to overlay
    - compact: bool, compact layout
    - combined: bool, combine all chains
    - legend: bool, show chain legend
    - plot_kwargs: dict, line plot arguments
    - fill_kwargs: dict, density fill arguments
    - rug_kwargs: dict, rug plot arguments

    Returns:
    - matplotlib axes: plot axes array
    """

def plot_posterior(trace, var_names=None, coords=None, figsize=None,
                   textsize=None, hdi_prob=0.94, multimodal=False,
                   skipna=False, ref_val=None, rope=None, point_estimate='mean',
                   round_to=2, credible_interval=None, **kwargs):
    """
    Posterior distribution plots with summary statistics.

    Shows marginal posterior distributions with credible intervals,
    point estimates, and optional reference values or ROPE.

    Parameters:
    - trace: InferenceData, posterior samples
    - var_names: list, variables to plot
    - coords: dict, coordinate selections
    - figsize: tuple, figure size
    - textsize: float, text size for annotations
    - hdi_prob: float, HDI probability
    - multimodal: bool, detect and handle multimodal distributions
    - skipna: bool, skip missing values
    - ref_val: dict, reference values by variable
    - rope: dict, region of practical equivalence bounds
    - point_estimate: str, point estimate type ('mean', 'median', 'mode')
    - round_to: int, decimal places for annotations
    - credible_interval: float, deprecated alias for hdi_prob

    Returns:
    - matplotlib axes: plot axes
    """

def plot_forest(trace, var_names=None, coords=None, figsize=None,
                textsize=None, ropestyle='top', ropes=None, credible_interval=0.94,
                quartiles=True, r_hat=True, ess=True, combined=False,
                colors='cycle', **kwargs):
    """
    Forest plot showing parameter estimates with uncertainty intervals.

    Horizontal plot displaying point estimates and credible intervals
    for multiple parameters, useful for coefficient comparison.

    Parameters:
    - trace: InferenceData, posterior samples
    - var_names: list, variables to include
    - coords: dict, coordinate selections
    - figsize: tuple, figure size
    - textsize: float, text size
    - ropestyle: str, ROPE display style ('top', 'bottom', None)
    - ropes: dict, ROPE bounds by variable
    - credible_interval: float, interval probability
    - quartiles: bool, show quartile markers
    - r_hat: bool, show R-hat values
    - ess: bool, show effective sample size
    - combined: bool, combine chains before plotting
    - colors: str or list, color specification

    Returns:
    - matplotlib axes: plot axes
    """

def plot_autocorr(trace, var_names=None, coords=None, figsize=None,
                  textsize=None, max_lag=100, combined=False, **kwargs):
    """
    Autocorrelation plots for assessing chain mixing.

    Shows autocorrelation function to diagnose slow mixing
    and estimate effective sample sizes visually.

    Parameters:
    - trace: InferenceData, posterior samples
    - var_names: list, variables to plot
    - coords: dict, coordinate selections
    - figsize: tuple, figure size
    - textsize: float, text size
    - max_lag: int, maximum lag to plot
    - combined: bool, combine chains

    Returns:
    - matplotlib axes: plot axes
    """

def plot_rank(trace, var_names=None, coords=None, figsize=None,
              bins=20, kind='bars', **kwargs):
    """
    Rank plots for MCMC diagnostics.

    Shows rank statistics across chains to identify mixing
    problems and between-chain differences.

    Parameters:
    - trace: InferenceData, posterior samples
    - var_names: list, variables to plot
    - coords: dict, coordinate selections
    - figsize: tuple, figure size
    - bins: int, number of rank bins
    - kind: str, plot type ('bars', 'vlines')

    Returns:
    - matplotlib axes: plot axes
    """

def plot_energy(trace, figsize=None, **kwargs):
    """
    Energy plot for HMC/NUTS diagnostics.

    Compares the marginal energy distribution with the energy
    transition distribution; a large mismatch between the two
    indicates that the sampler explores the posterior inefficiently.

    Parameters:
    - trace: InferenceData, posterior samples with energy info
    - figsize: tuple, figure size

    Returns:
    - matplotlib axes: plot axes
    """

def plot_pair(trace, var_names=None, coords=None, figsize=None,
              textsize=None, kind='scatter', gridsize='auto',
              colorbar=True, divergences=False, **kwargs):
    """
    Pairwise parameter plots showing correlations and structure.

    Matrix of bivariate plots revealing posterior correlations,
    multimodality, and geometric structure.

    Parameters:
    - trace: InferenceData, posterior samples
    - var_names: list, variables to include
    - coords: dict, coordinate selections
    - figsize: tuple, figure size
    - textsize: float, text size
    - kind: str, plot type ('scatter', 'kde', 'hexbin')
    - gridsize: int or 'auto', grid resolution for kde/hexbin
    - colorbar: bool, show colorbar for density plots
    - divergences: bool, highlight divergent samples

    Returns:
    - matplotlib axes: plot axes matrix
    """

def plot_parallel(trace, var_names=None, coords=None, figsize=None,
                  colornd='k', colord='r', shadend=0.025, **kwargs):
    """
    Parallel coordinates plot for high-dimensional visualization.

    Shows sample paths across multiple parameters to identify
    correlations and outliers in high-dimensional posteriors.

    Parameters:
    - trace: InferenceData, posterior samples
    - var_names: list, variables to include
    - coords: dict, coordinate selections
    - figsize: tuple, figure size
    - colornd: color for non-divergent samples
    - colord: color for divergent samples
    - shadend: float, transparency for non-divergent samples

    Returns:
    - matplotlib axes: plot axes
    """

def plot_violin(trace, var_names=None, coords=None, figsize=None,
                textsize=None, credible_interval=0.94, quartiles=True,
                rug=False, **kwargs):
    """
    Violin plots showing posterior distribution shapes.

    Kernel density estimates with optional quartiles and
    credible intervals for comparing parameter distributions.

    Parameters:
    - trace: InferenceData, posterior samples
    - var_names: list, variables to plot
    - coords: dict, coordinate selections
    - figsize: tuple, figure size
    - textsize: float, text size
    - credible_interval: float, interval to mark
    - quartiles: bool, show quartile lines
    - rug: bool, add rug plot

    Returns:
    - matplotlib axes: plot axes
    """

def plot_kde(values, values2=None, cumulative=False, rug=False,
             label=None, bw='scott', adaptive=False, extend=True,
             gridsize=None, clip=None, alpha=0.7, **kwargs):
    """
    Kernel density estimation plots.

    Smooth density estimates for continuous distributions
    with options for cumulative plots and comparisons.

    Parameters:
    - values: array, samples to plot
    - values2: array, optional second sample for comparison
    - cumulative: bool, plot cumulative density
    - rug: bool, add rug plot
    - label: str, plot label
    - bw: str or float, bandwidth selection method
    - adaptive: bool, use adaptive bandwidth
    - extend: bool, extend domain beyond data range
    - gridsize: int, evaluation grid size
    - clip: tuple, domain bounds
    - alpha: float, transparency
    """
```
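To show what a rank plot summarizes, the sketch below computes the per-chain histograms of pooled ranks that such a plot would draw as bars. It is an illustration under stated assumptions: `rank_histograms` is a hypothetical helper, ties are broken by chain index, and no plotting is done.

```python
def rank_histograms(chains, bins=5):
    """Per-chain histogram of ranks computed over the pooled draws."""
    # Pool all draws, remembering which chain each one came from
    pooled = sorted(
        (value, chain_idx) for chain_idx, chain in enumerate(chains) for value in chain
    )
    total = len(pooled)
    counts = [[0] * bins for _ in chains]
    for rank, (_, chain_idx) in enumerate(pooled):
        counts[chain_idx][rank * bins // total] += 1
    return counts


# Two chains stuck in different regions: each fills only one rank bin
print(rank_histograms([[1, 2, 3, 4], [5, 6, 7, 8]], bins=2))  # [[4, 0], [0, 4]]

# Two interleaved (well-mixed) chains: ranks are spread evenly
print(rank_histograms([[1, 3, 5, 7], [2, 4, 6, 8]], bins=2))  # [[2, 2], [2, 2]]
```

Roughly uniform histograms for every chain suggest good mixing, while lopsided ones point to between-chain differences.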

### Posterior Predictive Checking

Functions for model validation through posterior predictive distributions.

```python { .api }
def plot_ppc(trace, kind='kde', alpha=0.05, figsize=None, textsize=None,
             data_pairs=None, var_names=None, coords=None, flatten=None,
             flatten_pp=None, num_pp_samples=100, random_seed=None,
             jitter=None, mean=True, observed=True, **kwargs):
    """
    Posterior predictive check plots.

    Compares observed data with posterior predictive samples
    to assess model fit and identify systematic deviations.

    Parameters:
    - trace: InferenceData, with posterior_predictive group
    - kind: str, plot type ('kde', 'cumulative', 'scatter')
    - alpha: float, transparency for predictive samples
    - figsize: tuple, figure size
    - textsize: float, text size
    - data_pairs: dict, observed data by variable name
    - var_names: list, variables to plot
    - coords: dict, coordinate selections
    - flatten: list, dimensions to flatten
    - flatten_pp: list, posterior predictive dimensions to flatten
    - num_pp_samples: int, number of predictive samples to show
    - random_seed: int, random seed for sample selection
    - jitter: float, jitter amount for discrete data
    - mean: bool, show predictive mean
    - observed: bool, show observed data

    Returns:
    - matplotlib axes: plot axes
    """

def plot_loo_pit(idata, y=None, y_hat=None, log_weights=None,
                 ecdf=False, ecdf_fill=True, use_hdi=True,
                 credible_interval=0.99, figsize=None, **kwargs):
    """
    Leave-one-out probability integral transform plots.

    Diagnostic plots for posterior predictive calibration
    using LOO-PIT values that should be uniform if well-calibrated.

    Parameters:
    - idata: InferenceData, inference results
    - y: array, observed values
    - y_hat: array, posterior predictive samples
    - log_weights: array, importance weights
    - ecdf: bool, overlay empirical CDF
    - ecdf_fill: bool, fill ECDF confidence band
    - use_hdi: bool, use HDI for confidence bands
    - credible_interval: float, confidence level
    - figsize: tuple, figure size

    Returns:
    - matplotlib axes: plot axes
    """
```
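The probability integral transform behind these calibration plots can be sketched in a few lines. Note the simplification: this computes the ordinary PIT, where every predictive draw counts equally, whereas LOO-PIT reweights the draws with leave-one-out importance weights. The helper name and the data are hypothetical.

```python
def pit_values(y_obs, y_rep):
    """y_rep[s][i] is predictive draw s for observation i; returns one PIT value per observation."""
    n_draws = len(y_rep)
    return [
        # Fraction of predictive draws that fall at or below the observed value
        sum(1 for s in range(n_draws) if y_rep[s][i] <= y) / n_draws
        for i, y in enumerate(y_obs)
    ]


# Hypothetical data: 2 observations, 4 predictive draws each
y_obs = [0.5, 2.0]
y_rep = [[0.0, 1.0], [1.0, 3.0], [0.2, 1.5], [0.9, 2.5]]
print(pit_values(y_obs, y_rep))  # [0.5, 0.5]
```

For a well-calibrated model the PIT values across many observations should look approximately uniform on [0, 1].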

### Model Comparison Plots

Visualization for comparing multiple models.

```python { .api }
def plot_compare(comp_df, insample_dev=True, plot_ic_diff=True,
                 order_by_rank=True, figsize=None, textsize=None, **kwargs):
    """
    Model comparison plot showing information criteria.

    Visual comparison of models using WAIC/LOO with
    standard errors and ranking information.

    Parameters:
    - comp_df: DataFrame, results from az.compare()
    - insample_dev: bool, plot in-sample deviance
    - plot_ic_diff: bool, plot differences from best model
    - order_by_rank: bool, order models by rank
    - figsize: tuple, figure size
    - textsize: float, text size

    Returns:
    - matplotlib axes: plot axes
    """

def plot_elpd(comp_df, xlabels=False, figsize=None, textsize=None,
              color='C0', **kwargs):
    """
    Expected log predictive density comparison plot.

    Parameters:
    - comp_df: DataFrame, comparison results
    - xlabels: bool, show x-axis labels
    - figsize: tuple, figure size
    - textsize: float, text size
    - color: color specification

    Returns:
    - matplotlib axes: plot axes
    """

def plot_khat(khats, bins=None, figsize=None, ax=None, **kwargs):
    """
    Pareto k diagnostic plot for LOO reliability.

    Shows distribution of Pareto k values to assess
    reliability of LOO approximation.

    Parameters:
    - khats: array, Pareto k values from loo()
    - bins: int, histogram bins
    - figsize: tuple, figure size
    - ax: matplotlib axes, existing axes

    Returns:
    - matplotlib axes: plot axes
    """
```
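A Pareto k plot is usually read against fixed thresholds. As a small sketch, the function below buckets k values using the cut-offs quoted in the `loo` diagnostics earlier in this document (0.5 and 0.7); the helper name, the labels, and the example values are hypothetical.

```python
from collections import Counter


def classify_pareto_k(k_values):
    """Bucket Pareto k values into good / ok / poor using the 0.5 and 0.7 cut-offs."""
    def label(k):
        if k < 0.5:
            return "good"
        if k <= 0.7:
            return "ok"
        return "poor"
    return Counter(label(k) for k in k_values)


# Hypothetical k values for five observations
counts = classify_pareto_k([0.1, 0.55, 0.69, 0.71, 1.2])
print(dict(counts))  # {'good': 1, 'ok': 2, 'poor': 2}
```

Observations in the "poor" bucket are the ones for which the importance-sampling approximation to leave-one-out is least trustworthy.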

## Usage Examples

### Comprehensive Model Diagnostics

```python
import pymc3 as pm
import numpy as np
import matplotlib.pyplot as plt
import arviz as az

# Synthetic observations for illustration (`data` is not defined elsewhere in this example)
rng = np.random.default_rng(42)
data = rng.normal(loc=1.0, scale=2.0, size=100)

# Example model and sampling
with pm.Model() as diagnostic_model:
    mu = pm.Normal('mu', mu=0, sigma=10)
    sigma = pm.HalfNormal('sigma', sigma=5)
    y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=data)

    # Sample with multiple chains for diagnostics
    trace = pm.sample(1000, tune=1000, chains=4,
                      target_accept=0.95, return_inferencedata=True)

# Convergence diagnostics
print("=== Convergence Diagnostics ===")
r_hat_values = az.rhat(trace)  # the ArviZ function is named `rhat`
print("R-hat values:", r_hat_values)

ess_bulk = az.ess(trace, method='bulk')
ess_tail = az.ess(trace, method='tail')
print("Effective sample size (bulk):", ess_bulk)
print("Effective sample size (tail):", ess_tail)

mcse_values = az.mcse(trace)
print("Monte Carlo standard errors:", mcse_values)

# Comprehensive summary
summary_stats = az.summary(trace)
print("\n=== Posterior Summary ===")
print(summary_stats)

# Visual diagnostics. These plots need one or more axes per variable
# (and per chain for autocorrelation), so let ArviZ create each figure.

# Trace plots
az.plot_trace(trace)

# Rank plots
az.plot_rank(trace)

# Autocorrelation
az.plot_autocorr(trace, max_lag=50)

plt.show()
```

### Model Comparison Workflow

```python
# Continues from the previous example: pm, np, plt and az are already imported.

# Synthetic regression data for illustration (x_data and y_data are not defined elsewhere)
rng = np.random.default_rng(0)
x_data = np.linspace(-2, 2, 100)
y_data = 1.0 + 2.0 * x_data + rng.normal(scale=0.5, size=100)

# Multiple models for comparison
models = {}
traces = {}

# Model 1: Simple linear
with pm.Model() as model1:
    alpha1 = pm.Normal('alpha', mu=0, sigma=10)
    beta1 = pm.Normal('beta', mu=0, sigma=10)
    sigma1 = pm.HalfNormal('sigma', sigma=1)

    mu1 = alpha1 + beta1 * x_data
    y_obs1 = pm.Normal('y_obs', mu=mu1, sigma=sigma1, observed=y_data)

    trace1 = pm.sample(1000, tune=1000, return_inferencedata=True)

models['Linear'] = model1
traces['Linear'] = trace1

# Model 2: Quadratic
with pm.Model() as model2:
    alpha2 = pm.Normal('alpha', mu=0, sigma=10)
    beta1_2 = pm.Normal('beta1', mu=0, sigma=10)
    beta2_2 = pm.Normal('beta2', mu=0, sigma=10)
    sigma2 = pm.HalfNormal('sigma', sigma=1)

    mu2 = alpha2 + beta1_2 * x_data + beta2_2 * x_data**2
    y_obs2 = pm.Normal('y_obs', mu=mu2, sigma=sigma2, observed=y_data)

    trace2 = pm.sample(1000, tune=1000, return_inferencedata=True)

models['Quadratic'] = model2
traces['Quadratic'] = trace2

# Model 3: Robust (Student's t)
with pm.Model() as model3:
    alpha3 = pm.Normal('alpha', mu=0, sigma=10)
    beta3 = pm.Normal('beta', mu=0, sigma=10)
    sigma3 = pm.HalfNormal('sigma', sigma=1)
    nu = pm.Gamma('nu', alpha=2, beta=0.1)

    mu3 = alpha3 + beta3 * x_data
    y_obs3 = pm.StudentT('y_obs', nu=nu, mu=mu3, sigma=sigma3, observed=y_data)

    trace3 = pm.sample(1000, tune=1000, return_inferencedata=True)

models['Robust'] = model3
traces['Robust'] = trace3

# Compute information criteria
waic_results = {}
loo_results = {}

for name, trace in traces.items():
    waic_results[name] = az.waic(trace)
    # pointwise=True keeps the per-observation Pareto k values used below
    loo_results[name] = az.loo(trace, pointwise=True)

# Model comparison
comparison_waic = az.compare(traces, ic='waic')
comparison_loo = az.compare(traces, ic='loo')

print("=== Model Comparison (WAIC) ===")
print(comparison_waic)

print("\n=== Model Comparison (LOO) ===")
print(comparison_loo)

# Visualization
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

az.plot_compare(comparison_waic, ax=axes[0])
axes[0].set_title('WAIC Comparison')

az.plot_compare(comparison_loo, ax=axes[1])
axes[1].set_title('LOO Comparison')

plt.tight_layout()
plt.show()

# Check LOO reliability
for name, loo_result in loo_results.items():
    k_values = loo_result.pareto_k.values.flatten()
    n_high_k = np.sum(k_values > 0.7)
    print(f"{name}: {n_high_k} observations with high Pareto k (> 0.7)")
```
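The `weight` column of a comparison table depends on the chosen `method`. As a rough sketch of the simplest option, plain pseudo-BMA, the weights are proportional to the exponentiated ELPD of each model. The helper name and the ELPD values below are hypothetical, and the Bayesian-bootstrap and stacking variants are not covered.

```python
import math


def pseudo_bma_weights(elpd_by_model):
    """Weights proportional to exp(elpd), normalized to sum to one."""
    # Subtract the best ELPD first so the exponentials stay in a safe range
    best = max(elpd_by_model.values())
    unnormalized = {name: math.exp(elpd - best) for name, elpd in elpd_by_model.items()}
    total = sum(unnormalized.values())
    return {name: value / total for name, value in unnormalized.items()}


# Hypothetical ELPD values on the log scale (higher is better)
weights = pseudo_bma_weights({"Linear": -142.3, "Quadratic": -143.1, "Robust": -150.0})
for name, weight in weights.items():
    print(f"{name}: {weight:.3f}")
```

A gap of a few ELPD units already concentrates almost all of the weight on the better model, which is one reason the standard error of the difference (`dse`) should be read alongside the weights.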

### Posterior Predictive Checking

```python
# Generate posterior predictive samples
with models['Linear']:  # Use best model from comparison
    ppc = pm.sample_posterior_predictive(traces['Linear'], samples=100)

    # Add posterior predictive to InferenceData (inside the model context
    # so the converter can find the model)
    traces['Linear'].extend(az.from_pymc3(posterior_predictive=ppc))

# Posterior predictive checks
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Basic PPC plot
az.plot_ppc(traces['Linear'], ax=axes[0, 0], kind='kde')
axes[0, 0].set_title('Posterior Predictive Check (KDE)')

# Cumulative PPC
az.plot_ppc(traces['Linear'], ax=axes[0, 1], kind='cumulative')
axes[0, 1].set_title('Cumulative PPC')

# LOO-PIT for calibration; `y` names the observed variable to check
az.plot_loo_pit(traces['Linear'], y='y_obs', ax=axes[1, 0])
axes[1, 0].set_title('LOO-PIT Calibration')

# Custom PPC statistics
def ppc_statistics(y_obs, y_pred):
    """Custom statistics for PPC, one value per row (dataset) of y_pred."""
    return {
        'mean': np.mean(y_pred, axis=1),
        'std': np.std(y_pred, axis=1),
        'min': np.min(y_pred, axis=1),
        'max': np.max(y_pred, axis=1)
    }

# Compute statistics (the observed data is treated as a single dataset)
obs_stats = ppc_statistics(y_data, y_data.reshape(1, -1))
pred_stats = ppc_statistics(y_data, ppc['y_obs'])

# Plot statistics comparison
statistics = ['mean', 'std', 'min', 'max']
obs_values = [obs_stats[stat][0] for stat in statistics]
pred_means = [np.mean(pred_stats[stat]) for stat in statistics]
pred_stds = [np.std(pred_stats[stat]) for stat in statistics]

x_pos = np.arange(len(statistics))
axes[1, 1].bar(x_pos - 0.2, obs_values, 0.4, label='Observed', alpha=0.7)
axes[1, 1].errorbar(x_pos + 0.2, pred_means, yerr=pred_stds,
                    fmt='o', label='Predicted', capsize=5)
axes[1, 1].set_xticks(x_pos)
axes[1, 1].set_xticklabels(statistics)
axes[1, 1].legend()
axes[1, 1].set_title('Summary Statistics Comparison')

plt.tight_layout()
plt.show()
```
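A custom statistic like the ones above is often condensed into a posterior predictive p-value: the share of replicated datasets whose statistic is at least as large as the observed one. The sketch below uses plain lists and made-up numbers; the function name is hypothetical and it is not part of PyMC3 or ArviZ.

```python
def posterior_predictive_pvalue(y_obs, y_rep, statistic):
    """Fraction of replicated datasets whose statistic is >= the observed statistic."""
    observed = statistic(y_obs)
    exceed = sum(1 for rep in y_rep if statistic(rep) >= observed)
    return exceed / len(y_rep)


# Hypothetical observed data and four replicated datasets
y_observed = [1, 2, 3]
y_replicated = [[1, 2, 2], [3, 3, 3], [0, 5, 1], [2, 2, 2]]
print(posterior_predictive_pvalue(y_observed, y_replicated, max))  # 0.5
```

Values close to 0 or 1 indicate that the model rarely reproduces that feature of the data, while values near 0.5 indicate no detectable mismatch for the chosen statistic.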

### Advanced Visualization

```python
# Synthetic design matrix and response for illustration
# (X_multi and y_multi are not defined elsewhere in this example)
rng = np.random.default_rng(1)
X_multi = rng.normal(size=(100, 3))
y_multi = 0.5 + X_multi @ np.array([1.0, -0.5, 0.25]) + rng.normal(scale=0.5, size=100)

# Multi-parameter visualization
with pm.Model() as multivariate_model:
    # Jointly modeled parameters (identity prior covariance)
    theta = pm.MvNormal('theta',
                        mu=np.zeros(4),
                        cov=np.eye(4),
                        shape=4)

    # Split into intercept and coefficients
    alpha = pm.Deterministic('alpha', theta[0])
    beta = pm.Deterministic('beta', theta[1:])

    # Model prediction
    mu = alpha + pm.math.dot(beta, X_multi.T)
    y_obs = pm.Normal('y_obs', mu=mu, sigma=0.5, observed=y_multi)

    trace_mv = pm.sample(1000, tune=1000, return_inferencedata=True)

# Trace and pair plots need a grid of axes, so ArviZ creates their figures

# Trace plots
az.plot_trace(trace_mv, var_names=['alpha'])

# Pairwise relationships
az.plot_pair(trace_mv, var_names=['alpha', 'beta'],
             coords={'beta_dim_0': slice(0, 2)})

# Plots that fit on a single axes can share one figure
fig, axes = plt.subplots(2, 3, figsize=(16, 8))

# Posterior distributions
az.plot_posterior(trace_mv, var_names=['alpha'], ax=axes[0, 0])

# Forest plot for coefficients
az.plot_forest(trace_mv, var_names=['beta'], ax=axes[0, 1])

# Energy diagnostic
az.plot_energy(trace_mv, ax=axes[0, 2])

# Parallel coordinates
az.plot_parallel(trace_mv, var_names=['alpha', 'beta'], ax=axes[1, 0])

# Rank plot
az.plot_rank(trace_mv, var_names=['alpha'], ax=axes[1, 1])

axes[1, 2].axis('off')  # unused panel

plt.tight_layout()
plt.show()
```

### Custom Diagnostic Workflow

```python
# Custom convergence assessment
def comprehensive_diagnostics(trace, var_names=None):
    """Comprehensive diagnostic assessment (worst case over each variable's elements)."""

    if var_names is None:
        var_names = list(trace.posterior.data_vars)

    # Total draws per scalar element: chains x draws
    n_samples = trace.posterior.sizes['chain'] * trace.posterior.sizes['draw']

    diagnostics = {}

    for var in var_names:
        var_diagnostics = {}

        # Basic convergence metrics; vector-valued variables are reduced
        # to their worst element so that float() receives a scalar
        var_diagnostics['r_hat'] = float(az.rhat(trace, var_names=[var])[var].max())
        var_diagnostics['ess_bulk'] = float(az.ess(trace, var_names=[var], method='bulk')[var].min())
        var_diagnostics['ess_tail'] = float(az.ess(trace, var_names=[var], method='tail')[var].min())
        var_diagnostics['mcse_mean'] = float(az.mcse(trace, var_names=[var], method='mean')[var].max())

        # Effective sample size ratios
        var_diagnostics['ess_bulk_ratio'] = var_diagnostics['ess_bulk'] / n_samples
        var_diagnostics['ess_tail_ratio'] = var_diagnostics['ess_tail'] / n_samples

        # Convergence flags
        var_diagnostics['converged'] = (
            var_diagnostics['r_hat'] < 1.01 and
            var_diagnostics['ess_bulk'] > 400 and
            var_diagnostics['ess_tail'] > 400
        )

        diagnostics[var] = var_diagnostics

    return diagnostics

# Run diagnostics
diag_results = comprehensive_diagnostics(trace_mv)

print("=== Comprehensive Diagnostics ===")
for var, diag in diag_results.items():
    status = "✓ PASS" if diag['converged'] else "✗ FAIL"
    print(f"\n{var} {status}")
    print(f"  R-hat: {diag['r_hat']:.4f}")
    print(f"  ESS bulk: {diag['ess_bulk']:.0f} ({diag['ess_bulk_ratio']:.2f})")
    print(f"  ESS tail: {diag['ess_tail']:.0f} ({diag['ess_tail_ratio']:.2f})")
    print(f"  MCSE mean: {diag['mcse_mean']:.4f}")

# Summary convergence status
all_converged = all(diag['converged'] for diag in diag_results.values())
print(f"\nOverall convergence: {'✓ PASS' if all_converged else '✗ FAIL'}")

if not all_converged:
    print("\nRecommendations:")
    print("- Increase number of samples")
    print("- Check model parameterization")
    print("- Consider different step size or sampler settings")
```
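When the recommendation is to increase the number of samples, the usual relation MCSE of the mean ≈ posterior SD / sqrt(ESS) gives a rough target. The sketch below applies that relation in both directions; the function names and numbers are hypothetical, and the result is an effective sample size, not a raw draw count.

```python
import math


def mcse_of_mean(posterior_sd, ess):
    """Monte Carlo standard error of the posterior mean for a given ESS."""
    return posterior_sd / math.sqrt(ess)


def ess_needed(posterior_sd, target_mcse):
    """Smallest ESS for which the MCSE of the mean drops to the target."""
    return math.ceil((posterior_sd / target_mcse) ** 2)


# Hypothetical posterior SD of 2.0
print(f"MCSE at ESS=400: {mcse_of_mean(2.0, 400):.3f}")
print(f"ESS needed for MCSE 0.05: {ess_needed(2.0, 0.05)}")
```

Because the error shrinks with the square root of the ESS, halving the MCSE requires roughly four times as many effective samples.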

### Publication-Ready Plots

```python
import pandas as pd  # used for the convergence summary panel

# Create publication-quality figures
plt.rcParams.update({
    'font.size': 12,
    'axes.labelsize': 14,
    'axes.titlesize': 16,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    'legend.fontsize': 12,
    'figure.titlesize': 18
})

# Multi-panel figure for publication
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
fig.suptitle('Bayesian Linear Regression Analysis', fontsize=18, y=0.98)

# Panel A: Posterior distributions
az.plot_posterior(trace_mv, var_names=['alpha'], ax=axes[0, 0],
                  hdi_prob=0.95, point_estimate='mean')
axes[0, 0].set_title('A. Intercept Posterior')

# Panel B: Coefficient forest plot
az.plot_forest(trace_mv, var_names=['beta'], ax=axes[0, 1],
               hdi_prob=0.95, quartiles=False)
axes[0, 1].set_title('B. Coefficient Estimates')

# Panel C: Model comparison
az.plot_compare(comparison_waic, ax=axes[0, 2])
axes[0, 2].set_title('C. Model Comparison (WAIC)')

# Panel D: Posterior predictive check
az.plot_ppc(traces['Linear'], ax=axes[1, 0], kind='kde',
            alpha=0.1, num_pp_samples=50)
axes[1, 0].set_title('D. Posterior Predictive Check')

# Panel E: Residual analysis (custom)
# Extract posterior mean predictions
post_pred = ppc['y_obs'].mean(axis=0)
residuals = y_data - post_pred

axes[1, 1].scatter(post_pred, residuals, alpha=0.6)
axes[1, 1].axhline(y=0, color='red', linestyle='--')
axes[1, 1].set_xlabel('Fitted Values')
axes[1, 1].set_ylabel('Residuals')
axes[1, 1].set_title('E. Residual Analysis')

# Panel F: Convergence diagnostics summary
# Cast to float because the mixed bool/float results produce object columns
convergence_summary = pd.DataFrame(diag_results).T[['r_hat', 'ess_bulk_ratio']].astype(float)
convergence_summary.plot(kind='bar', ax=axes[1, 2])
axes[1, 2].set_title('F. Convergence Summary')
axes[1, 2].set_ylabel('Diagnostic Value')
axes[1, 2].tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.savefig('bayesian_analysis.png', dpi=300, bbox_inches='tight')
plt.show()
```