# Statistics and Plotting (ArviZ Integration)

PyMC3 integrates tightly with ArviZ for comprehensive Bayesian analysis, model diagnostics, and publication-quality visualizations. The stats and plots modules delegate to ArviZ while providing PyMC3-specific functionality and convenient aliases for common workflows.

## Capabilities

### Convergence Diagnostics

Functions for assessing MCMC convergence and sample quality through `pymc3.stats.*`.

```python { .api }
def r_hat(trace, var_names=None, method='rank'):
    """
    Compute R-hat convergence diagnostic.

    Measures between-chain and within-chain variance to assess
    convergence across multiple MCMC chains. Values close to 1.0
    indicate good convergence.

    Parameters:
    - trace: InferenceData or MultiTrace, posterior samples
    - var_names: list, variables to analyze (all if None)
    - method: str, computation method ('rank', 'split', 'folded')

    Returns:
    - dict or array: R-hat values by variable

    Interpretation:
    - R_hat < 1.01: Excellent convergence
    - R_hat < 1.1: Good convergence
    - R_hat > 1.1: Poor convergence, need more samples
    """

def ess(trace, var_names=None, method='bulk'):
    """
    Compute effective sample size.

    Estimates the number of independent samples, accounting for
    autocorrelation in MCMC chains. Higher values indicate better
    mixing and more efficient sampling.

    Parameters:
    - trace: InferenceData or MultiTrace, posterior samples
    - var_names: list, variables to analyze
    - method: str, ESS type ('bulk', 'tail', 'mean', 'sd', 'quantile')

    Returns:
    - dict or array: effective sample sizes

    Guidelines:
    - ESS > 400: Generally sufficient for posterior inference
    - ESS > 100: Minimum for reasonable estimates
    - ESS < 100: Increase sampling or improve model
    """

def mcse(trace, var_names=None, method='mean', prob=None):
    """
    Monte Carlo standard error of estimates.

    Measures uncertainty in posterior estimates due to finite
    sampling, helping determine if more samples are needed.

    Parameters:
    - trace: InferenceData or MultiTrace, posterior samples
    - var_names: list, variables to analyze
    - method: str, estimate type ('mean', 'sd', 'quantile')
    - prob: float, probability for quantile MCSE

    Returns:
    - dict: MCSE values by variable
    """

def geweke(trace, var_names=None, first=0.1, last=0.5, intervals=20):
    """
    Geweke convergence diagnostic.

    Compares means from early and late portions of chains to
    assess within-chain convergence and stationarity.

    Parameters:
    - trace: InferenceData or MultiTrace, posterior samples
    - var_names: list, variables to test
    - first: float, fraction for early portion
    - last: float, fraction for late portion
    - intervals: int, number of test intervals

    Returns:
    - dict: Geweke statistics by variable
    """
```
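
To make the R-hat interpretation above concrete, here is a minimal standard-library sketch of the classic split R-hat (within-chain versus between-chain variance). It is an illustration only: the helper name `split_r_hat` and the synthetic chains are assumptions, and the default `'rank'` method additionally rank-normalizes the draws, which this sketch does not do.

```python
import math
import random
import statistics


def split_r_hat(chains):
    """Classic split R-hat for a list of equal-length chains (lists of floats)."""
    half = len(chains[0]) // 2
    # Split each chain in two so that within-chain drift shows up
    # as between-chain variance.
    halves = [c[:half] for c in chains] + [c[half:2 * half] for c in chains]
    n = half
    chain_means = [statistics.fmean(h) for h in halves]
    within = statistics.fmean([statistics.variance(h) for h in halves])  # W
    between = n * statistics.variance(chain_means)                       # B
    var_plus = (n - 1) / n * within + between / n
    return math.sqrt(var_plus / within)


rng = random.Random(0)
# Four chains sampling the same distribution (hypothetical, well mixed)
mixed = [[rng.gauss(0.0, 1.0) for _ in range(500)] for _ in range(4)]
# Four chains stuck around two different locations (hypothetical, not mixed)
stuck = [[rng.gauss(shift, 1.0) for _ in range(500)]
         for shift in (0.0, 0.0, 3.0, 3.0)]

r_hat_mixed = split_r_hat(mixed)
r_hat_stuck = split_r_hat(stuck)
print(f"well-mixed chains: {r_hat_mixed:.3f}")
print(f"stuck chains:      {r_hat_stuck:.3f}")
```

The well-mixed chains should land very close to 1.0, while the chains centred on different values give a value far above the 1.1 threshold listed in the docstring.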

### Model Comparison

Information criteria and cross-validation for Bayesian model selection.

```python { .api }
def compare(models, ic='waic', method='stacking', b_samples=1000,
            alpha=1, seed=None, round_to=2):
    """
    Compare multiple models using information criteria.

    Ranks models by predictive performance using WAIC, LOO-CV,
    or other criteria, with model weights and standard errors.

    Parameters:
    - models: dict, mapping model names to InferenceData objects
    - ic: str, information criterion ('waic', 'loo')
    - method: str, weighting method ('stacking', 'BB-pseudo-BMA', 'pseudo-BMA')
    - b_samples: int, number of Bayesian bootstrap samples ('BB-pseudo-BMA' only)
    - alpha: float, Dirichlet concentration parameter for the Bayesian bootstrap ('BB-pseudo-BMA' only)
    - seed: int, random seed for reproducibility
    - round_to: int, decimal places for results

    Returns:
    - DataFrame: model comparison results with ranks and weights

    Columns:
    - rank: model ranking (0 = best)
    - elpd_*: expected log pointwise predictive density
    - p_*: effective number of parameters
    - d_*: difference from best model
    - weight: model averaging weights
    - se: standard error of the information criterion estimate
    - dse: standard error of difference from best
    """

def waic(trace, model=None, pointwise=False, scale='deviance'):
    """
    Watanabe-Akaike Information Criterion.

    Estimates out-of-sample predictive performance using
    within-sample log-likelihood with penalty for overfitting.

    Parameters:
    - trace: InferenceData or MultiTrace, posterior samples
    - model: Model, model context (current if None)
    - pointwise: bool, return pointwise WAIC values
    - scale: str, return scale ('deviance' or 'log')

    Returns:
    - ELPDData: WAIC results with components and diagnostics

    Components:
    - elpd_waic: expected log pointwise predictive density
    - p_waic: effective number of parameters
    - waic: -2 * elpd_waic (lower is better)
    - se: standard error of WAIC
    """

def loo(trace, model=None, pointwise=False, reff=None, scale='deviance'):
    """
    Pareto Smoothed Importance Sampling Leave-One-Out Cross-Validation.

    Estimates out-of-sample performance using leave-one-out
    cross-validation approximated by importance sampling.

    Parameters:
    - trace: InferenceData or MultiTrace, posterior samples
    - model: Model, model context
    - pointwise: bool, return pointwise LOO values
    - reff: array, relative effective sample sizes
    - scale: str, return scale ('deviance' or 'log')

    Returns:
    - ELPDData: LOO-CV results with Pareto diagnostics

    Diagnostics:
    - Pareto k < 0.5: Good approximation
    - Pareto k < 0.7: Okay approximation
    - Pareto k > 0.7: Poor approximation, use exact CV
    """

def loo_pit(idata, y=None, y_hat=None, log_weights=None):
    """
    Leave-one-out probability integral transform.

    Calibration check for posterior predictive distributions
    using LOO-PIT values that should be uniform if well-calibrated.

    Parameters:
    - idata: InferenceData, posterior and predictions
    - y: array, observed data (from idata if None)
    - y_hat: array, posterior predictive samples
    - log_weights: array, importance sampling weights

    Returns:
    - array: LOO-PIT values for calibration assessment
    """
```
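
The WAIC components listed above (`elpd_waic`, `p_waic`, `waic`) can be worked out by hand from a matrix of pointwise log-likelihoods. The sketch below is a plain-Python illustration under stated assumptions: `waic_from_log_lik` and the toy matrix are hypothetical, and the penalty uses the sample variance (n - 1 denominator), which may differ in detail from what a given library version computes.

```python
import math
import statistics


def waic_from_log_lik(log_lik):
    """log_lik[s][i]: log-likelihood of observation i under posterior draw s."""
    n_draws = len(log_lik)
    n_obs = len(log_lik[0])
    elpd = 0.0
    p_waic = 0.0
    for i in range(n_obs):
        column = [log_lik[s][i] for s in range(n_draws)]
        # Log of the posterior-mean likelihood, via a stable log-sum-exp
        top = max(column)
        lppd_i = top + math.log(sum(math.exp(v - top) for v in column) / n_draws)
        # Penalty: posterior variance of the pointwise log-likelihood
        p_i = statistics.variance(column)
        elpd += lppd_i - p_i
        p_waic += p_i
    # Deviance scale: -2 * elpd, lower is better
    return {"elpd_waic": elpd, "p_waic": p_waic, "waic": -2.0 * elpd}


# Toy matrix (made-up numbers): 3 posterior draws (rows) x 2 observations (columns)
toy = [[-1.0, -2.0],
       [-1.0, -2.5],
       [-1.0, -3.0]]
result = waic_from_log_lik(toy)
print(result)
```

The first observation has the same log-likelihood under every draw, so it contributes no penalty; all of `p_waic` comes from the second observation, whose log-likelihood varies across draws.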

### Summary Statistics

Posterior summary and descriptive statistics.

```python { .api }
def summary(trace, var_names=None, stat_funcs=None, extend=True,
            credible_interval=0.94, round_to=2, kind='stats'):
    """
    Comprehensive posterior summary statistics.

    Provides means, standard deviations, credible intervals,
    and convergence diagnostics for all model parameters.

    Parameters:
    - trace: InferenceData or MultiTrace, posterior samples
    - var_names: list, variables to summarize (all if None)
    - stat_funcs: dict, custom summary functions
    - extend: bool, append stat_funcs results to the default statistics instead of replacing them
    - credible_interval: float, credible interval width
    - round_to: int, decimal places
    - kind: str, summary type ('stats', 'diagnostics')

    Returns:
    - DataFrame: comprehensive parameter summary

    Columns:
    - mean: posterior mean
    - sd: posterior standard deviation
    - hdi_3%/hdi_97%: highest density interval bounds
    - mcse_mean: MCSE of mean
    - mcse_sd: MCSE of standard deviation
    - ess_bulk/ess_tail: effective sample sizes
    - r_hat: R-hat convergence diagnostic
    """

def describe(trace, var_names=None, include_ci=True, ci_prob=0.94):
    """
    Descriptive statistics for posterior distributions.

    Parameters:
    - trace: InferenceData or MultiTrace, posterior samples
    - var_names: list, variables to describe
    - include_ci: bool, include credible intervals
    - ci_prob: float, credible interval probability

    Returns:
    - DataFrame: descriptive statistics
    """

def quantiles(x, qlist=(0.025, 0.25, 0.5, 0.75, 0.975)):
    """
    Compute quantiles of posterior samples.

    Parameters:
    - x: array, samples
    - qlist: tuple, quantile probabilities

    Returns:
    - dict: quantile values
    """

def hdi(x, credible_interval=0.94, circular=False):
    """
    Highest Density Interval (HDI).

    Computes the shortest interval containing specified
    probability mass of the posterior distribution.

    Parameters:
    - x: array, posterior samples
    - credible_interval: float, interval probability
    - circular: bool, circular data (angles)

    Returns:
    - array: [lower_bound, upper_bound]
    """
```
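
The "shortest interval containing the specified probability mass" definition of the HDI can be sketched directly on sorted samples. This is an illustrative standard-library version for a unimodal posterior; `hdi_from_samples` is a hypothetical helper, not the library routine, and it ignores the circular and multimodal cases.

```python
import math


def hdi_from_samples(samples, prob=0.94):
    """Shortest interval containing at least `prob` of the samples."""
    ordered = sorted(samples)
    n = len(ordered)
    n_inside = max(1, math.ceil(prob * n))  # points the interval must cover
    best_low, best_high = ordered[0], ordered[n_inside - 1]
    # Slide a window of n_inside consecutive points and keep the narrowest one
    for start in range(1, n - n_inside + 1):
        low, high = ordered[start], ordered[start + n_inside - 1]
        if high - low < best_high - best_low:
            best_low, best_high = low, high
    return best_low, best_high


# Skewed toy sample: one extreme draw far from the bulk
skewed = [1.0, 2.0, 3.0, 4.0, 100.0]
interval = hdi_from_samples(skewed, prob=0.8)
print(interval)
```

Unlike an equal-tailed quantile interval, the window is free to sit wherever the samples are densest, so the extreme draw is left outside.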

### Posterior Analysis

Advanced posterior analysis and derived quantities.

```python { .api }
def autocorr(trace, var_names=None, max_lag=100):
    """
    Autocorrelation function of MCMC chains.

    Measures correlation between samples at different lags
    to assess mixing and effective sample size.

    Parameters:
    - trace: InferenceData or MultiTrace, samples
    - var_names: list, variables to analyze
    - max_lag: int, maximum lag to compute

    Returns:
    - dict: autocorrelation functions by variable
    """

def make_ufunc(func, nin=1, nout=1, **kwargs):
    """
    Create universal function for posterior analysis.

    Converts regular functions into universal functions
    that work efficiently on posterior sample arrays.

    Parameters:
    - func: callable, function to convert
    - nin: int, number of inputs
    - nout: int, number of outputs
    - kwargs: additional ufunc arguments

    Returns:
    - ufunc: universal function
    """

def from_dict(posterior_dict, coords=None, dims=None):
    """
    Create InferenceData from dictionary of arrays.

    Parameters:
    - posterior_dict: dict, posterior samples by variable
    - coords: dict, coordinate values
    - dims: dict, dimension names by variable

    Returns:
    - InferenceData: formatted inference data
    """
```
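
As a rough illustration of what "correlation between samples at different lags" means, the following standard-library sketch computes the sample autocorrelation of a single chain at one lag. The helper `autocorr_at_lag` and the alternating chain are assumptions made for the example; library implementations typically compute all lags at once with an FFT.

```python
def autocorr_at_lag(chain, lag):
    """Sample autocorrelation of one chain at a single non-negative lag."""
    n = len(chain)
    mean = sum(chain) / n
    centered = [x - mean for x in chain]
    denom = sum(c * c for c in centered)
    # Products of draws that are `lag` steps apart
    numer = sum(centered[t] * centered[t + lag] for t in range(n - lag))
    return numer / denom


# A chain that flips sign on every draw is strongly anti-correlated at lag 1
alternating = [1.0 if t % 2 == 0 else -1.0 for t in range(100)]
rho0 = autocorr_at_lag(alternating, 0)
rho1 = autocorr_at_lag(alternating, 1)
rho2 = autocorr_at_lag(alternating, 2)
print(rho0, rho1, rho2)
```

Autocorrelation that stays far from zero over many lags is what reduces the effective sample size relative to the raw number of draws.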

### Plotting Functions

Comprehensive visualization capabilities through ArviZ integration via `pymc3.plots.*`.

```python { .api }
def plot_trace(trace, var_names=None, coords=None, divergences='auto',
               figsize=None, rug=False, lines=None, compact=True,
               combined=False, legend=False, plot_kwargs=None,
               fill_kwargs=None, rug_kwargs=None, **kwargs):
    """
    Trace plots showing MCMC sampling paths and marginal distributions.

    Essential diagnostic plot combining time series of samples
    with marginal posterior distributions for visual convergence assessment.

    Parameters:
    - trace: InferenceData, posterior samples
    - var_names: list, variables to plot
    - coords: dict, coordinate slices for multidimensional variables
    - divergences: str or bool, highlight divergent transitions
    - figsize: tuple, figure size
    - rug: bool, add rug plot to marginals
    - lines: dict, reference lines to overlay
    - compact: bool, compact layout
    - combined: bool, combine all chains
    - legend: bool, show chain legend
    - plot_kwargs: dict, line plot arguments
    - fill_kwargs: dict, density fill arguments
    - rug_kwargs: dict, rug plot arguments

    Returns:
    - matplotlib axes: plot axes array
    """

def plot_posterior(trace, var_names=None, coords=None, figsize=None,
                   textsize=None, hdi_prob=0.94, multimodal=False,
                   skipna=False, ref_val=None, rope=None, point_estimate='mean',
                   round_to=2, credible_interval=None, **kwargs):
    """
    Posterior distribution plots with summary statistics.

    Shows marginal posterior distributions with credible intervals,
    point estimates, and optional reference values or ROPE.

    Parameters:
    - trace: InferenceData, posterior samples
    - var_names: list, variables to plot
    - coords: dict, coordinate selections
    - figsize: tuple, figure size
    - textsize: float, text size for annotations
    - hdi_prob: float, HDI probability
    - multimodal: bool, detect and handle multimodal distributions
    - skipna: bool, skip missing values
    - ref_val: dict, reference values by variable
    - rope: dict, region of practical equivalence bounds
    - point_estimate: str, point estimate type ('mean', 'median', 'mode')
    - round_to: int, decimal places for annotations
    - credible_interval: float, deprecated alias for hdi_prob

    Returns:
    - matplotlib axes: plot axes
    """

def plot_forest(trace, var_names=None, coords=None, figsize=None,
                textsize=None, ropestyle='top', ropes=None, credible_interval=0.94,
                quartiles=True, r_hat=True, ess=True, combined=False,
                colors='cycle', **kwargs):
    """
    Forest plot showing parameter estimates with uncertainty intervals.

    Horizontal plot displaying point estimates and credible intervals
    for multiple parameters, useful for coefficient comparison.

    Parameters:
    - trace: InferenceData, posterior samples
    - var_names: list, variables to include
    - coords: dict, coordinate selections
    - figsize: tuple, figure size
    - textsize: float, text size
    - ropestyle: str, ROPE display style ('top', 'bottom', None)
    - ropes: dict, ROPE bounds by variable
    - credible_interval: float, interval probability
    - quartiles: bool, show quartile markers
    - r_hat: bool, show R-hat values
    - ess: bool, show effective sample size
    - combined: bool, combine chains before plotting
    - colors: str or list, color specification

    Returns:
    - matplotlib axes: plot axes
    """

def plot_autocorr(trace, var_names=None, coords=None, figsize=None,
                  textsize=None, max_lag=100, combined=False, **kwargs):
    """
    Autocorrelation plots for assessing chain mixing.

    Shows autocorrelation function to diagnose slow mixing
    and estimate effective sample sizes visually.

    Parameters:
    - trace: InferenceData, posterior samples
    - var_names: list, variables to plot
    - coords: dict, coordinate selections
    - figsize: tuple, figure size
    - textsize: float, text size
    - max_lag: int, maximum lag to plot
    - combined: bool, combine chains

    Returns:
    - matplotlib axes: plot axes
    """

def plot_rank(trace, var_names=None, coords=None, figsize=None,
              bins=20, kind='bars', **kwargs):
    """
    Rank plots for MCMC diagnostics.

    Shows rank statistics across chains to identify mixing
    problems and between-chain differences.

    Parameters:
    - trace: InferenceData, posterior samples
    - var_names: list, variables to plot
    - coords: dict, coordinate selections
    - figsize: tuple, figure size
    - bins: int, number of rank bins
    - kind: str, plot type ('bars', 'vlines')

    Returns:
    - matplotlib axes: plot axes
    """

def plot_energy(trace, figsize=None, **kwargs):
    """
    Energy plot for HMC/NUTS diagnostics.

    Compares the marginal energy distribution with the energy
    transition distribution to identify potential sampling problems.

    Parameters:
    - trace: InferenceData, posterior samples with energy info
    - figsize: tuple, figure size

    Returns:
    - matplotlib axes: plot axes
    """

def plot_pair(trace, var_names=None, coords=None, figsize=None,
              textsize=None, kind='scatter', gridsize='auto',
              colorbar=True, divergences=False, **kwargs):
    """
    Pairwise parameter plots showing correlations and structure.

    Matrix of bivariate plots revealing posterior correlations,
    multimodality, and geometric structure.

    Parameters:
    - trace: InferenceData, posterior samples
    - var_names: list, variables to include
    - coords: dict, coordinate selections
    - figsize: tuple, figure size
    - textsize: float, text size
    - kind: str, plot type ('scatter', 'kde', 'hexbin')
    - gridsize: int or 'auto', grid resolution for kde/hexbin
    - colorbar: bool, show colorbar for density plots
    - divergences: bool, highlight divergent samples

    Returns:
    - matplotlib axes: plot axes matrix
    """

def plot_parallel(trace, var_names=None, coords=None, figsize=None,
                  colornd='k', colord='r', shadend=0.025, **kwargs):
    """
    Parallel coordinates plot for high-dimensional visualization.

    Shows sample paths across multiple parameters to identify
    correlations and outliers in high-dimensional posteriors.

    Parameters:
    - trace: InferenceData, posterior samples
    - var_names: list, variables to include
    - coords: dict, coordinate selections
    - figsize: tuple, figure size
    - colornd: color for non-divergent samples
    - colord: color for divergent samples
    - shadend: float, transparency for non-divergent samples

    Returns:
    - matplotlib axes: plot axes
    """

def plot_violin(trace, var_names=None, coords=None, figsize=None,
                textsize=None, credible_interval=0.94, quartiles=True,
                rug=False, **kwargs):
    """
    Violin plots showing posterior distribution shapes.

    Kernel density estimates with optional quartiles and
    credible intervals for comparing parameter distributions.

    Parameters:
    - trace: InferenceData, posterior samples
    - var_names: list, variables to plot
    - coords: dict, coordinate selections
    - figsize: tuple, figure size
    - textsize: float, text size
    - credible_interval: float, interval to mark
    - quartiles: bool, show quartile lines
    - rug: bool, add rug plot

    Returns:
    - matplotlib axes: plot axes
    """

def plot_kde(values, values2=None, cumulative=False, rug=False,
             label=None, bw='scott', adaptive=False, extend=True,
             gridsize=None, clip=None, alpha=0.7, **kwargs):
    """
    Kernel density estimation plots.

    Smooth density estimates for continuous distributions
    with options for cumulative plots and comparisons.

    Parameters:
    - values: array, samples to plot
    - values2: array, optional second sample for comparison
    - cumulative: bool, plot cumulative density
    - rug: bool, add rug plot
    - label: str, plot label
    - bw: str or float, bandwidth selection method
    - adaptive: bool, use adaptive bandwidth
    - extend: bool, extend domain beyond data range
    - gridsize: int, evaluation grid size
    - clip: tuple, domain bounds
    - alpha: float, transparency

    Returns:
    - matplotlib axes: plot axes
    """
```

### Posterior Predictive Checking

Functions for model validation through posterior predictive distributions.

```python { .api }
def plot_ppc(trace, kind='kde', alpha=0.05, figsize=None, textsize=None,
             data_pairs=None, var_names=None, coords=None, flatten=None,
             flatten_pp=None, num_pp_samples=100, random_seed=None,
             jitter=None, mean=True, observed=True, **kwargs):
    """
    Posterior predictive check plots.

    Compares observed data with posterior predictive samples
    to assess model fit and identify systematic deviations.

    Parameters:
    - trace: InferenceData, with posterior_predictive group
    - kind: str, plot type ('kde', 'cumulative', 'scatter')
    - alpha: float, transparency for predictive samples
    - figsize: tuple, figure size
    - textsize: float, text size
    - data_pairs: dict, observed data by variable name
    - var_names: list, variables to plot
    - coords: dict, coordinate selections
    - flatten: list, dimensions to flatten
    - flatten_pp: list, posterior predictive dimensions to flatten
    - num_pp_samples: int, number of predictive samples to show
    - random_seed: int, random seed for sample selection
    - jitter: float, jitter amount for discrete data
    - mean: bool, show predictive mean
    - observed: bool, show observed data

    Returns:
    - matplotlib axes: plot axes
    """

def plot_loo_pit(idata, y=None, y_hat=None, log_weights=None,
                 ecdf=False, ecdf_fill=True, use_hdi=True,
                 credible_interval=0.99, figsize=None, **kwargs):
    """
    Leave-one-out probability integral transform plots.

    Diagnostic plots for posterior predictive calibration
    using LOO-PIT values that should be uniform if well-calibrated.

    Parameters:
    - idata: InferenceData, inference results
    - y: array, observed values
    - y_hat: array, posterior predictive samples
    - log_weights: array, importance weights
    - ecdf: bool, overlay empirical CDF
    - ecdf_fill: bool, fill ECDF confidence band
    - use_hdi: bool, use HDI for confidence bands
    - credible_interval: float, confidence level
    - figsize: tuple, figure size

    Returns:
    - matplotlib axes: plot axes
    """
```
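
The calibration idea behind LOO-PIT is easier to see on the plain probability integral transform: for each observation, take the fraction of predictive draws that fall at or below the observed value. The sketch below shows only this unweighted version; `pit_values` and the toy arrays are hypothetical, and the leave-one-out variant additionally reweights the draws with importance weights.

```python
def pit_values(observed, predictive_draws):
    """predictive_draws[s][i]: draw s of the predictive distribution for observation i."""
    n_draws = len(predictive_draws)
    values = []
    for i, y in enumerate(observed):
        # Empirical predictive CDF evaluated at the observed value
        below = sum(1 for s in range(n_draws) if predictive_draws[s][i] <= y)
        values.append(below / n_draws)
    return values


# Toy data (made-up numbers): 4 predictive draws for each of 2 observations
observed = [2.5, 0.0]
draws = [[1.0, 1.0],
         [2.0, 2.0],
         [3.0, 3.0],
         [4.0, 4.0]]
pit = pit_values(observed, draws)
print(pit)
```

For a well-calibrated model these values are spread roughly uniformly over [0, 1] across observations; values piling up near 0 or 1 suggest predictive distributions that are too narrow or biased.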

### Model Comparison Plots

Visualization for comparing multiple models.

```python { .api }
def plot_compare(comp_df, insample_dev=True, plot_ic_diff=True,
                 order_by_rank=True, figsize=None, textsize=None, **kwargs):
    """
    Model comparison plot showing information criteria.

    Visual comparison of models using WAIC/LOO with
    standard errors and ranking information.

    Parameters:
    - comp_df: DataFrame, results from az.compare()
    - insample_dev: bool, plot in-sample deviance
    - plot_ic_diff: bool, plot differences from best model
    - order_by_rank: bool, order models by rank
    - figsize: tuple, figure size
    - textsize: float, text size

    Returns:
    - matplotlib axes: plot axes
    """

def plot_elpd(comp_df, xlabels=False, figsize=None, textsize=None,
              color='C0', **kwargs):
    """
    Expected log predictive density comparison plot.

    Parameters:
    - comp_df: DataFrame, comparison results
    - xlabels: bool, show x-axis labels
    - figsize: tuple, figure size
    - textsize: float, text size
    - color: color specification

    Returns:
    - matplotlib axes: plot axes
    """

def plot_khat(khats, bins=None, figsize=None, ax=None, **kwargs):
    """
    Pareto k diagnostic plot for LOO reliability.

    Shows distribution of Pareto k values to assess
    reliability of LOO approximation.

    Parameters:
    - khats: array, Pareto k values from loo()
    - bins: int, histogram bins
    - figsize: tuple, figure size
    - ax: matplotlib axes, existing axes
    """
```
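
A khat plot is usually read against the Pareto k thresholds quoted earlier for `loo` (0.5 and 0.7). As a small illustrative sketch, the same reading can be done numerically by binning the k values; `classify_pareto_k` is a hypothetical helper, the sample values are made up, and treating a value of exactly 0.7 as "ok" is an assumption.

```python
def classify_pareto_k(k_values, good=0.5, ok=0.7):
    """Count Pareto k values in the good / ok / bad bands."""
    counts = {"good": 0, "ok": 0, "bad": 0}
    for k in k_values:
        if k < good:
            counts["good"] += 1
        elif k <= ok:
            counts["ok"] += 1
        else:
            counts["bad"] += 1
    return counts


# Hypothetical pointwise Pareto k estimates for five observations
k_counts = classify_pareto_k([0.1, 0.55, 0.7, 0.9, -0.2])
print(k_counts)
```

Any observation in the "bad" band is a candidate for refitting with that point held out, since the importance-sampling approximation is unreliable there.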

## Usage Examples

### Comprehensive Model Diagnostics

```python
import pymc3 as pm
import numpy as np
import matplotlib.pyplot as plt
import arviz as az

# Placeholder observations so the example runs on its own (synthetic data)
rng = np.random.default_rng(42)
data = rng.normal(loc=1.0, scale=2.0, size=100)

# Example model and sampling
with pm.Model() as diagnostic_model:
    mu = pm.Normal('mu', mu=0, sigma=10)
    sigma = pm.HalfNormal('sigma', sigma=5)
    y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=data)

    # Sample with multiple chains for diagnostics
    trace = pm.sample(1000, tune=1000, chains=4,
                      target_accept=0.95, return_inferencedata=True)

# Convergence diagnostics
print("=== Convergence Diagnostics ===")
r_hat_values = az.r_hat(trace)
print("R-hat values:", r_hat_values)

ess_bulk = az.ess(trace, method='bulk')
ess_tail = az.ess(trace, method='tail')
print("Effective sample size (bulk):", ess_bulk)
print("Effective sample size (tail):", ess_tail)

mcse_values = az.mcse(trace)
print("Monte Carlo standard errors:", mcse_values)

# Comprehensive summary
summary_stats = az.summary(trace)
print("\n=== Posterior Summary ===")
print(summary_stats)

# Visual diagnostics
# Each ArviZ function lays out its own grid of axes (one row or panel per
# variable and, for autocorrelation, per chain), so let it create the figure.
# Note that plot_trace takes an `axes` argument, not `ax`.

# Trace plots
az.plot_trace(trace)

# Rank plots
az.plot_rank(trace)

# Autocorrelation
az.plot_autocorr(trace, max_lag=50)

plt.show()
```
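
The ESS and MCSE numbers printed by this workflow are linked by a simple rule of thumb: the Monte Carlo standard error of a posterior mean is roughly the posterior standard deviation divided by the square root of the effective sample size. The sketch below works that arithmetic through with made-up numbers; both helper names are hypothetical, and the library estimators are more refined than this approximation.

```python
import math


def mcse_of_mean(posterior_sd, ess):
    """Approximate Monte Carlo standard error of a posterior mean."""
    return posterior_sd / math.sqrt(ess)


def ess_needed(posterior_sd, target_mcse):
    """Effective sample size required to reach a target MCSE of the mean."""
    return (posterior_sd / target_mcse) ** 2


# Illustrative values: posterior sd of 2.0 estimated from 400 effective draws
current = mcse_of_mean(2.0, 400)
required = ess_needed(2.0, 0.05)
print(f"MCSE with ESS=400: {current:.3f}")
print(f"ESS needed for MCSE=0.05: {required:.0f}")
```

Because the error shrinks with the square root of ESS, halving the MCSE takes about four times as many effective draws.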

### Model Comparison Workflow

```python
# Assumes the imports from the previous example, plus x_data and y_data:
# 1-D NumPy arrays of equal length holding the predictor and the response.

# Multiple models for comparison
models = {}
traces = {}

# Model 1: Simple linear
with pm.Model() as model1:
    alpha1 = pm.Normal('alpha', mu=0, sigma=10)
    beta1 = pm.Normal('beta', mu=0, sigma=10)
    sigma1 = pm.HalfNormal('sigma', sigma=1)

    mu1 = alpha1 + beta1 * x_data
    y_obs1 = pm.Normal('y_obs', mu=mu1, sigma=sigma1, observed=y_data)

    trace1 = pm.sample(1000, tune=1000, return_inferencedata=True)

models['Linear'] = model1
traces['Linear'] = trace1

# Model 2: Quadratic
with pm.Model() as model2:
    alpha2 = pm.Normal('alpha', mu=0, sigma=10)
    beta1_2 = pm.Normal('beta1', mu=0, sigma=10)
    beta2_2 = pm.Normal('beta2', mu=0, sigma=10)
    sigma2 = pm.HalfNormal('sigma', sigma=1)

    mu2 = alpha2 + beta1_2 * x_data + beta2_2 * x_data**2
    y_obs2 = pm.Normal('y_obs', mu=mu2, sigma=sigma2, observed=y_data)

    trace2 = pm.sample(1000, tune=1000, return_inferencedata=True)

models['Quadratic'] = model2
traces['Quadratic'] = trace2

# Model 3: Robust (Student's t)
with pm.Model() as model3:
    alpha3 = pm.Normal('alpha', mu=0, sigma=10)
    beta3 = pm.Normal('beta', mu=0, sigma=10)
    sigma3 = pm.HalfNormal('sigma', sigma=1)
    nu = pm.Gamma('nu', alpha=2, beta=0.1)

    mu3 = alpha3 + beta3 * x_data
    y_obs3 = pm.StudentT('y_obs', nu=nu, mu=mu3, sigma=sigma3, observed=y_data)

    trace3 = pm.sample(1000, tune=1000, return_inferencedata=True)

models['Robust'] = model3
traces['Robust'] = trace3

# Compute information criteria
waic_results = {}
loo_results = {}

for name, idata in traces.items():
    waic_results[name] = az.waic(idata)
    # pointwise=True keeps the per-observation Pareto k values used below
    loo_results[name] = az.loo(idata, pointwise=True)

# Model comparison
comparison_waic = az.compare(traces, ic='waic')
comparison_loo = az.compare(traces, ic='loo')

print("=== Model Comparison (WAIC) ===")
print(comparison_waic)

print("\n=== Model Comparison (LOO) ===")
print(comparison_loo)

# Visualization
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

az.plot_compare(comparison_waic, ax=axes[0])
axes[0].set_title('WAIC Comparison')

az.plot_compare(comparison_loo, ax=axes[1])
axes[1].set_title('LOO Comparison')

plt.tight_layout()
plt.show()

# Check LOO reliability
for name, loo_result in loo_results.items():
    k_values = loo_result.pareto_k.values.flatten()
    n_high_k = np.sum(k_values > 0.7)
    print(f"{name}: {n_high_k} observations with high Pareto k (> 0.7)")
```
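
To give some intuition for the `weight` column in the comparison tables, here is a sketch of the simplest weighting scheme named earlier, plain pseudo-BMA, where each weight is proportional to the exponential of the model's ELPD on the log scale. The helper `pseudo_bma_weights` and the ELPD values are hypothetical; the default `'stacking'` method solves an optimization problem instead and will generally give different weights.

```python
import math


def pseudo_bma_weights(elpd_by_model):
    """Plain pseudo-BMA weights from ELPD estimates on the log scale."""
    best = max(elpd_by_model.values())
    # Subtract the best ELPD before exponentiating for numerical stability
    unnormalized = {name: math.exp(elpd - best)
                    for name, elpd in elpd_by_model.items()}
    total = sum(unnormalized.values())
    return {name: value / total for name, value in unnormalized.items()}


# Made-up ELPD estimates for the three models in the workflow above
weights = pseudo_bma_weights({"Linear": -120.0, "Quadratic": -120.0, "Robust": -125.0})
for name, weight in weights.items():
    print(f"{name}: {weight:.3f}")
```

A gap of only a few ELPD units already pushes almost all of the weight onto the better models, which is one reason to read the weights alongside `dse` rather than in isolation.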

### Posterior Predictive Checking

```python
# Generate posterior predictive samples
with models['Linear']:  # e.g. the model ranked best in the comparison
    # Keep one predictive draw per posterior draw so that LOO-PIT can pair
    # predictive draws with the stored log-likelihood draws
    ppc = pm.sample_posterior_predictive(traces['Linear'])

    # Add posterior predictive to InferenceData
    # (done inside the model context so from_pymc3 can find the model)
    traces['Linear'].extend(az.from_pymc3(posterior_predictive=ppc))

# Posterior predictive checks
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Basic PPC plot (only a subset of predictive draws is drawn)
az.plot_ppc(traces['Linear'], ax=axes[0, 0], kind='kde', num_pp_samples=100)
axes[0, 0].set_title('Posterior Predictive Check (KDE)')

# Cumulative PPC
az.plot_ppc(traces['Linear'], ax=axes[0, 1], kind='cumulative', num_pp_samples=100)
axes[0, 1].set_title('Cumulative PPC')

# LOO-PIT for calibration (y names the observed variable)
az.plot_loo_pit(traces['Linear'], y='y_obs', ax=axes[1, 0])
axes[1, 0].set_title('LOO-PIT Calibration')

# Custom PPC statistics
def ppc_statistics(y_obs, y_pred):
    """Custom statistics for PPC, one value per row (dataset) of y_pred."""
    return {
        'mean': np.mean(y_pred, axis=1),
        'std': np.std(y_pred, axis=1),
        'min': np.min(y_pred, axis=1),
        'max': np.max(y_pred, axis=1)
    }

# Compute statistics
obs_stats = ppc_statistics(y_data, y_data.reshape(1, -1))
pred_stats = ppc_statistics(y_data, ppc['y_obs'])

# Plot statistics comparison
statistics = ['mean', 'std', 'min', 'max']
obs_values = [obs_stats[stat][0] for stat in statistics]
pred_means = [np.mean(pred_stats[stat]) for stat in statistics]
pred_stds = [np.std(pred_stats[stat]) for stat in statistics]

x_pos = np.arange(len(statistics))
axes[1, 1].bar(x_pos - 0.2, obs_values, 0.4, label='Observed', alpha=0.7)
axes[1, 1].errorbar(x_pos + 0.2, pred_means, yerr=pred_stds,
                    fmt='o', label='Predicted', capsize=5)
axes[1, 1].set_xticks(x_pos)
axes[1, 1].set_xticklabels(statistics)
axes[1, 1].legend()
axes[1, 1].set_title('Summary Statistics Comparison')

plt.tight_layout()
plt.show()
```
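
The summary-statistics comparison above can also be condensed into a posterior predictive p-value: the fraction of replicated datasets whose test statistic is at least as large as the observed one. The following is a minimal standard-library sketch with toy numbers; `ppc_p_value` is a hypothetical helper and is not part of PyMC3 or ArviZ.

```python
def ppc_p_value(observed, replicated, stat):
    """Share of replicated datasets whose statistic is >= the observed one."""
    t_obs = stat(observed)
    exceed = sum(1 for rep in replicated if stat(rep) >= t_obs)
    return exceed / len(replicated)


# Toy data (made-up numbers): one observed dataset and four replicated datasets
observed = [1.0, 2.0, 3.0]
replicated = [[0.0, 1.0, 2.0],
              [1.0, 2.0, 4.0],
              [2.0, 3.0, 5.0],
              [0.5, 1.0, 2.5]]
p_max = ppc_p_value(observed, replicated, max)
print(p_max)
```

Values near 0 or 1 indicate that the model rarely reproduces the chosen feature of the data, here the maximum, while values near 0.5 indicate that the observed statistic is typical of the replications.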

### Advanced Visualization

```python
# Multi-parameter visualization
# Assumes X_multi has shape (n_obs, 3) and y_multi has shape (n_obs,)
with pm.Model() as multivariate_model:
    # Joint prior over four parameters (identity covariance, so the
    # components are independent a priori)
    theta = pm.MvNormal('theta',
                        mu=np.zeros(4),
                        cov=np.eye(4),
                        shape=4)

    # Split theta into an intercept and three coefficients
    alpha = pm.Deterministic('alpha', theta[0])
    beta = pm.Deterministic('beta', theta[1:])

    # Model prediction
    mu = alpha + pm.math.dot(beta, X_multi.T)
    y_obs = pm.Normal('y_obs', mu=mu, sigma=0.5, observed=y_multi)

    trace_mv = pm.sample(1000, tune=1000, return_inferencedata=True)

# Comprehensive visualization suite
fig = plt.figure(figsize=(16, 12))

# Trace plots: plot_trace draws two panels per variable (density and trace),
# so it takes a (1, 2) array through `axes` rather than a single `ax`
ax_density = fig.add_subplot(3, 3, 1)
ax_trace = fig.add_subplot(3, 3, 2)
az.plot_trace(trace_mv, var_names=['alpha'],
              axes=np.array([[ax_density, ax_trace]]))

# Posterior distributions
axes_post = fig.add_subplot(3, 3, 3)
az.plot_posterior(trace_mv, var_names=['alpha'], ax=axes_post)

# Forest plot for coefficients
axes_forest = fig.add_subplot(3, 3, (4, 5))
az.plot_forest(trace_mv, var_names=['beta'], ax=axes_forest)

# Pairwise relationship between the first two coefficients
# (exactly two variables fit on one axes; more would need a grid of axes)
axes_pair = fig.add_subplot(3, 3, 6)
az.plot_pair(trace_mv, var_names=['beta'],
             coords={'beta_dim_0': [0, 1]}, ax=axes_pair)

# Energy diagnostic
axes_energy = fig.add_subplot(3, 3, 7)
az.plot_energy(trace_mv, ax=axes_energy)

# Parallel coordinates
axes_parallel = fig.add_subplot(3, 3, 8)
az.plot_parallel(trace_mv, var_names=['alpha', 'beta'], ax=axes_parallel)

# Rank plot
axes_rank = fig.add_subplot(3, 3, 9)
az.plot_rank(trace_mv, var_names=['alpha'], ax=axes_rank)

plt.tight_layout()
plt.show()
```
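
The model above reads `X_multi` and `y_multi` without constructing them in this example. If they are not already defined, the following is a minimal sketch of synthetic inputs with the shapes the model expects (three predictors, one response per row). The sample size, seed, and "true" parameter values are illustrative assumptions, not values from this documentation, and the snippet would need to run before the model block.

```python
import numpy as np

# Hypothetical synthetic data; every constant here is an illustrative assumption
rng = np.random.default_rng(42)
n_obs = 200

true_alpha = 1.0
true_beta = np.array([0.5, -1.0, 2.0])

# Design matrix: one row per observation, one column per element of `beta`
X_multi = rng.normal(size=(n_obs, 3))

# Response with Gaussian noise matching the sigma=0.5 likelihood in the model
y_multi = true_alpha + X_multi @ true_beta + rng.normal(scale=0.5, size=n_obs)

print(X_multi.shape, y_multi.shape)
```

With data generated this way, the posterior for `alpha` and `beta` can be compared against the known generating values as a sanity check on the plots.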

### Custom Diagnostic Workflow

```python
# Custom convergence assessment
def comprehensive_diagnostics(trace, var_names=None):
    """Comprehensive diagnostic assessment."""

    if var_names is None:
        var_names = list(trace.posterior.data_vars)

    # Total posterior draws per parameter (chains x draws)
    n_samples = trace.posterior.sizes['chain'] * trace.posterior.sizes['draw']

    diagnostics = {}

    for var in var_names:
        var_diagnostics = {}

        # Basic convergence metrics. For vector-valued variables, report the
        # worst element: largest R-hat and MCSE, smallest ESS.
        var_diagnostics['r_hat'] = float(az.r_hat(trace, var_names=[var])[var].max())
        var_diagnostics['ess_bulk'] = float(az.ess(trace, var_names=[var], method='bulk')[var].min())
        var_diagnostics['ess_tail'] = float(az.ess(trace, var_names=[var], method='tail')[var].min())
        var_diagnostics['mcse_mean'] = float(az.mcse(trace, var_names=[var], method='mean')[var].max())

        # Effective sample size ratios
        var_diagnostics['ess_bulk_ratio'] = var_diagnostics['ess_bulk'] / n_samples
        var_diagnostics['ess_tail_ratio'] = var_diagnostics['ess_tail'] / n_samples

        # Convergence flags
        var_diagnostics['converged'] = (
            var_diagnostics['r_hat'] < 1.01 and
            var_diagnostics['ess_bulk'] > 400 and
            var_diagnostics['ess_tail'] > 400
        )

        diagnostics[var] = var_diagnostics

    return diagnostics

# Run diagnostics
diag_results = comprehensive_diagnostics(trace_mv)

print("=== Comprehensive Diagnostics ===")
for var, diag in diag_results.items():
    status = "✓ PASS" if diag['converged'] else "✗ FAIL"
    print(f"\n{var} {status}")
    print(f"  R-hat: {diag['r_hat']:.4f}")
    print(f"  ESS bulk: {diag['ess_bulk']:.0f} ({diag['ess_bulk_ratio']:.2f})")
    print(f"  ESS tail: {diag['ess_tail']:.0f} ({diag['ess_tail_ratio']:.2f})")
    print(f"  MCSE mean: {diag['mcse_mean']:.4f}")

# Summary convergence status
all_converged = all(diag['converged'] for diag in diag_results.values())
print(f"\nOverall convergence: {'✓ PASS' if all_converged else '✗ FAIL'}")

if not all_converged:
    print("\nRecommendations:")
    print("- Increase number of samples")
    print("- Check model parameterization")
    print("- Consider different step size or sampler settings")
```
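
The recommendations printed above are the same whichever check failed. As a sketch, the helper below reports which of the three thresholds each variable missed, using the dictionary layout that `comprehensive_diagnostics` returns. The function name `failed_checks` and the sample input are hypothetical and made up for illustration; they are not part of PyMC3 or ArviZ.

```python
def failed_checks(diag_results, r_hat_max=1.01, ess_min=400):
    """Return {variable: [names of failed checks]} for variables that fail."""
    failures = {}
    for var, diag in diag_results.items():
        failed = []
        # `not a < b` (rather than `a >= b`) also flags NaN values as failures
        if not diag['r_hat'] < r_hat_max:
            failed.append('r_hat')
        if not diag['ess_bulk'] > ess_min:
            failed.append('ess_bulk')
        if not diag['ess_tail'] > ess_min:
            failed.append('ess_tail')
        if failed:
            failures[var] = failed
    return failures


# Made-up diagnostics in the layout returned by comprehensive_diagnostics
example_results = {
    'alpha': {'r_hat': 1.002, 'ess_bulk': 1850.0, 'ess_tail': 1620.0},
    'beta': {'r_hat': 1.030, 'ess_bulk': 310.0, 'ess_tail': 450.0},
}

for var, failed in failed_checks(example_results).items():
    print(f"{var}: failed {', '.join(failed)}")
# prints "beta: failed r_hat, ess_bulk"
```

Knowing which check failed helps narrow the response: a high R-hat points toward chains that disagree, while a low tail ESS alone may only call for more draws.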

### Publication-Ready Plots

```python
import pandas as pd

# Create publication-quality figures
# Reuses trace_mv and diag_results from above; comparison_waic, traces, ppc,
# and y_data are assumed to come from the earlier examples
plt.rcParams.update({
    'font.size': 12,
    'axes.labelsize': 14,
    'axes.titlesize': 16,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    'legend.fontsize': 12,
    'figure.titlesize': 18
})

# Multi-panel figure for publication
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
fig.suptitle('Bayesian Linear Regression Analysis', fontsize=18, y=0.98)

# Panel A: Posterior distributions
az.plot_posterior(trace_mv, var_names=['alpha'], ax=axes[0, 0],
                  hdi_prob=0.95, point_estimate='mean')
axes[0, 0].set_title('A. Intercept Posterior')

# Panel B: Coefficient forest plot
# (hdi_prob replaces the older credible_interval argument)
az.plot_forest(trace_mv, var_names=['beta'], ax=axes[0, 1],
               hdi_prob=0.95, quartiles=False)
axes[0, 1].set_title('B. Coefficient Estimates')

# Panel C: Model comparison
az.plot_compare(comparison_waic, ax=axes[0, 2])
axes[0, 2].set_title('C. Model Comparison (WAIC)')

# Panel D: Posterior predictive check
az.plot_ppc(traces['Linear'], ax=axes[1, 0], kind='kde',
            alpha=0.1, num_pp_samples=50)
axes[1, 0].set_title('D. Posterior Predictive Check')

# Panel E: Residual analysis (custom)
# Extract posterior mean predictions (average over draws)
post_pred = ppc['y_obs'].mean(axis=0)
residuals = y_data - post_pred

axes[1, 1].scatter(post_pred, residuals, alpha=0.6)
axes[1, 1].axhline(y=0, color='red', linestyle='--')
axes[1, 1].set_xlabel('Fitted Values')
axes[1, 1].set_ylabel('Residuals')
axes[1, 1].set_title('E. Residual Analysis')

# Panel F: Convergence diagnostics summary
# Cast to float: the mixed float/bool dict yields object-dtype columns,
# which pandas refuses to plot
convergence_summary = (
    pd.DataFrame(diag_results).T[['r_hat', 'ess_bulk_ratio']].astype(float)
)
convergence_summary.plot(kind='bar', ax=axes[1, 2])
axes[1, 2].set_title('F. Convergence Summary')
axes[1, 2].set_ylabel('Diagnostic Value')
axes[1, 2].tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.savefig('bayesian_analysis.png', dpi=300, bbox_inches='tight')
plt.show()
```
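
Panel E relies on `ppc['y_obs']` being an array with one row per posterior predictive draw and one column per observation, so that averaging over `axis=0` yields one fitted value per observation. A small sketch with made-up numbers (not real model output; the names `ppc_draws` and `y_observed` are hypothetical) shows the shapes involved:

```python
import numpy as np

# Made-up posterior predictive draws: 3 draws (rows) x 4 observations (columns)
ppc_draws = np.array([
    [1.0, 2.0, 3.0, 4.0],
    [1.2, 1.8, 3.3, 3.9],
    [0.8, 2.2, 2.7, 4.1],
])
y_observed = np.array([1.1, 2.0, 2.5, 4.5])

# Average over draws (axis 0) to get one fitted value per observation
fitted = ppc_draws.mean(axis=0)

# Residual = observed minus posterior mean prediction
residuals = y_observed - fitted

print(fitted.shape, residuals.shape)
```

If the draws were stored with observations on the first axis instead, `mean(axis=0)` would average across observations and the subtraction would either fail or silently broadcast, so checking the shape before plotting is worthwhile.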