Or run:

`npx @tessl/cli init`
Log in

Version

Title

Overview

Evals

Files

Files

docs

data.md · distributions.md · gp.md · index.md · math.md · model.md · ode.md · sampling.md · stats.md · variational.md

docs/stats.md

0

# PyMC Statistics and Diagnostics

1

2

PyMC provides comprehensive statistical functions and convergence diagnostics for Bayesian analysis, primarily through integration with ArviZ. The library offers tools for model validation, convergence assessment, and posterior analysis.

3

4

## Convergence Diagnostics

5

6

PyMC exposes key diagnostic functions from ArviZ for assessing MCMC convergence:

7

8

### R-hat Statistic

9

10

```python { .api }

11

def rhat(data, var_names=None, method='rank', dask_kwargs=None):

12

"""

13

Compute R-hat convergence diagnostic (Gelman-Rubin statistic).

14

15

Parameters:

16

- data: InferenceData object or trace

17

- var_names (list, optional): Variables to analyze

18

- method (str): Method for computation ('rank', 'split', 'folded')

19

- dask_kwargs (dict, optional): Dask computation options

20

21

Returns:

22

- rhat_values: R-hat statistics for each variable

23

"""

24

25

import pymc as pm

26

27

# Compute R-hat for all variables

28

with pm.Model() as model:

29

# Model definition and sampling...

30

trace = pm.sample()

31

32

rhat_stats = pm.rhat(trace)

33

print("R-hat diagnostics:")

34

for var, rhat_val in rhat_stats.items():

35

print(f" {var}: {rhat_val:.4f}")

36

37

# R-hat for specific variables only

38

rhat_subset = pm.rhat(trace, var_names=['alpha', 'beta'])

39

40

# Check convergence (R-hat should be < 1.01)

41

converged = all(rhat_val < 1.01 for rhat_val in rhat_stats.values())

42

```

43

44

### Effective Sample Size

45

46

```python { .api }

47

def effective_sample_size(data, var_names=None, method='bulk',

48

relative=False, dask_kwargs=None):

49

"""

50

Compute effective sample size (ESS).

51

52

Parameters:

53

- data: InferenceData object or trace

54

- var_names (list, optional): Variables to analyze

55

- method (str): ESS method ('bulk', 'tail', 'quantile', 'mean', 'sd')

56

- relative (bool): Return relative ESS (ESS/N)

57

- dask_kwargs (dict, optional): Dask computation options

58

59

Returns:

60

- ess_values: Effective sample size for each variable

61

"""

62

63

# Bulk ESS (measures efficiency in central posterior)

64

bulk_ess = pm.ess(trace, method='bulk')

65

66

# Tail ESS (measures efficiency in posterior tails)

67

tail_ess = pm.ess(trace, method='tail')

68

69

# Relative ESS (as fraction of total samples)

70

rel_ess = pm.ess(trace, relative=True)

71

72

print("Effective Sample Size (bulk):")

73

for var, ess_val in bulk_ess.items():

74

print(f" {var}: {ess_val:.0f}")

75

76

# Check adequacy (ESS should be > 400 for reliable inference)

77

adequate_ess = all(ess_val > 400 for ess_val in bulk_ess.values())

78

```

79

80

### Monte Carlo Standard Error

81

82

```python { .api }

83

def mcse(data, var_names=None, method='mean', dask_kwargs=None):

84

"""

85

Compute Monte Carlo standard error.

86

87

Parameters:

88

- data: InferenceData object or trace

89

- var_names (list, optional): Variables to analyze

90

- method (str): Statistic to compute MCSE for ('mean', 'sd', 'quantile')

91

- dask_kwargs (dict, optional): Dask computation options

92

93

Returns:

94

- mcse_values: Monte Carlo standard errors

95

"""

96

97

# MCSE for posterior means

98

mcse_mean = pm.mcse(trace, method='mean')

99

100

# MCSE for posterior standard deviations

101

mcse_sd = pm.mcse(trace, method='sd')

102

103

# MCSE for quantiles

104

mcse_quantile = pm.mcse(trace, method='quantile')

105

106

print("Monte Carlo Standard Error (mean):")

107

for var, mcse_val in mcse_mean.items():

108

print(f" {var}: {mcse_val:.6f}")

109

```

110

111

## Model Comparison

112

113

### Leave-One-Out Cross-Validation

114

115

```python { .api }

116

def loo(data, var_name=None, reff=None, scale=None, pointwise=False,

117

dask_kwargs=None):

118

"""

119

Compute leave-one-out (LOO) cross-validation using Pareto smoothed importance sampling.

120

121

Parameters:

122

- data: InferenceData object with log_likelihood group

123

- var_name (str, optional): Variable name for likelihood

124

- reff (array, optional): Relative effective sample size

125

- scale (str): Scale for IC ('log', 'negative_log', 'deviance')

126

- pointwise (bool): Return pointwise LOO values

127

- dask_kwargs (dict, optional): Dask computation options

128

129

Returns:

130

- loo_result: LOO-CV results with ELPD, SE, and diagnostics

131

"""

132

133

# Compute log-likelihood for LOO

134

with pm.Model() as model:

135

# Model definition...

136

trace = pm.sample()

137

138

# Compute log-likelihood

139

log_likelihood = pm.compute_log_likelihood(trace, model=model)

140

141

# LOO cross-validation

142

loo_result = pm.loo(trace)

143

print(f"LOO ELPD: {loo_result.elpd_loo:.2f} ± {loo_result.se:.2f}")

144

print(f"LOO IC: {loo_result.loo:.2f}")

145

print(f"p_loo (effective parameters): {loo_result.p_loo:.2f}")

146

147

# Check Pareto k diagnostic

148

high_k = (loo_result.pareto_k > 0.7).sum()

149

if high_k > 0:

150

print(f"Warning: {high_k} observations have high Pareto k values")

151

```

152

153

### Watanabe-Akaike Information Criterion

154

155

```python { .api }

156

def waic(data, var_name=None, scale=None, pointwise=False, dask_kwargs=None):

157

"""

158

Compute Watanabe-Akaike Information Criterion (WAIC).

159

160

Parameters:

161

- data: InferenceData object with log_likelihood group

162

- var_name (str, optional): Variable name for likelihood

163

- scale (str): Scale for IC ('log', 'negative_log', 'deviance')

164

- pointwise (bool): Return pointwise WAIC values

165

- dask_kwargs (dict, optional): Dask computation options

166

167

Returns:

168

- waic_result: WAIC results with ELPD, SE, and effective parameters

169

"""

170

171

# WAIC computation

172

waic_result = pm.waic(trace)

173

print(f"WAIC ELPD: {waic_result.elpd_waic:.2f} ± {waic_result.se:.2f}")

174

print(f"WAIC: {waic_result.waic:.2f}")

175

print(f"p_waic (effective parameters): {waic_result.p_waic:.2f}")

176

177

# Pointwise WAIC for outlier detection

178

waic_pointwise = pm.waic(trace, pointwise=True)

179

outlier_threshold = waic_pointwise.waic_i.mean() + 2 * waic_pointwise.waic_i.std()

180

outliers = waic_pointwise.waic_i > outlier_threshold

181

print(f"Potential outliers: {outliers.sum()} observations")

182

```

183

184

### Model Comparison Framework

185

186

```python { .api }

187

def compare(compare_dict, ic=None, method='stacking', b_samples=1000,

188

alpha=0.05, seed=None, scale=None):

189

"""

190

Compare models using information criteria.

191

192

Parameters:

193

- compare_dict (dict): Dictionary of {model_name: InferenceData}

194

- ic (str): Information criterion ('loo', 'waic')

195

- method (str): Comparison method ('stacking', 'BB-pseudo-BMA', 'pseudo-BMA')

196

- b_samples (int): Bootstrap samples for SE estimation

197

- alpha (float): Significance level for intervals

198

- seed (int): Random seed

199

- scale (str): Scale for IC reporting

200

201

Returns:

202

- comparison_df: DataFrame with model comparison results

203

"""

204

205

# Compare multiple models

206

models = {

207

'linear': linear_trace,

208

'quadratic': quadratic_trace,

209

'cubic': cubic_trace

210

}

211

212

comparison = pm.compare(models, ic='loo')

213

print("Model Comparison (LOO):")

214

print(comparison)

215

216

# Model weights from stacking

217

print("\nModel weights:")

218

for model, weight in zip(comparison.index, comparison.weight):

219

print(f" {model}: {weight:.3f}")

220

221

# Automatically select best model

222

best_model = comparison.index[0] # First row is best

223

print(f"\nBest model: {best_model}")

224

```

225

226

## Log-Likelihood and Prior Computation

227

228

### Log-Likelihood Calculation

229

230

```python { .api }

231

def compute_log_likelihood(idata=None, *, model=None, var_names=None,

232

extend_inferencedata=True, progressbar=True):

233

"""

234

Compute pointwise log-likelihood values.

235

236

Parameters:

237

- idata: InferenceData object with posterior samples

238

- model: PyMC model (default: current context)

239

- var_names (list, optional): Observed variables to compute likelihood for

240

- extend_inferencedata (bool): Add results to InferenceData

241

- progressbar (bool): Show progress bar

242

243

Returns:

244

- log_likelihood: Log-likelihood values for each observation and posterior sample

245

"""

246

247

with pm.Model() as model:

248

# Model with likelihood...

249

trace = pm.sample()

250

251

# Compute log-likelihood

252

log_lik = pm.compute_log_likelihood(trace, model=model)

253

254

# Access log-likelihood values

255

ll_values = trace.log_likelihood # Added to InferenceData

256

print(f"Log-likelihood shape: {ll_values['y_obs'].shape}") # (chains, draws, observations)

257

258

# Total log-likelihood per sample

259

total_ll = ll_values['y_obs'].sum(dim='y_obs_dim_0')

260

print(f"Total log-likelihood range: {total_ll.min():.2f} to {total_ll.max():.2f}")

261

```

262

263

### Log-Prior Calculation

264

265

```python { .api }

266

def compute_log_prior(idata=None, *, model=None, var_names=None,

267

extend_inferencedata=True, progressbar=True):

268

"""

269

Compute log-prior density values.

270

271

Parameters:

272

- idata: InferenceData object with posterior samples

273

- model: PyMC model (default: current context)

274

- var_names (list, optional): Variables to compute log-prior for

275

- extend_inferencedata (bool): Add results to InferenceData

276

- progressbar (bool): Show progress bar

277

278

Returns:

279

- log_prior: Log-prior values for each variable and posterior sample

280

"""

281

282

# Compute log-prior

283

log_prior = pm.compute_log_prior(trace, model=model)

284

285

# Access log-prior values

286

prior_values = trace.log_prior

287

print("Log-prior components:")

288

for var_name in prior_values.data_vars:

289

values = prior_values[var_name]

290

print(f" {var_name}: mean = {values.mean():.3f}, std = {values.std():.3f}")

291

292

# Total log-prior per sample

293

total_prior = sum(prior_values[var].sum() for var in prior_values.data_vars)

294

```

295

296

## Posterior Analysis Utilities

297

298

### Summary Statistics

299

300

```python { .api }

301

# Summary statistics through ArviZ integration

302

summary_stats = pm.summary(trace, var_names=['alpha', 'beta'])

303

print("Posterior Summary:")

304

print(summary_stats)

305

306

# Custom summary with specific quantiles

307

custom_summary = pm.summary(trace,

308

stat_funcs={'median': np.median,

309

'mad': lambda x: np.median(np.abs(x - np.median(x)))},

310

extend=True)

311

312

# Round summary for reporting

313

rounded_summary = pm.summary(trace, round_to=3)

314

```

315

316

### Posterior Predictive Checks

317

318

```python { .api }

319

# Posterior predictive sampling for model checking

320

with pm.Model() as model:

321

# Model definition...

322

trace = pm.sample()

323

324

# Posterior predictive samples

325

post_pred = pm.sample_posterior_predictive(trace, predictions=True)

326

327

# Compare observed vs predicted

328

observed = post_pred.observed_data['y_obs']

329

predicted = post_pred.posterior_predictive['y_obs']

330

331

# T-test statistic for checking

332

def t_statistic(y):

333

return (y.mean() - observed.mean()) / (y.std() / np.sqrt(len(y)))

334

335

# Compute test statistic for observed and predicted

336

t_obs = t_statistic(observed.values)

337

t_pred = [t_statistic(pred_sample) for pred_sample in predicted.values.reshape(-1, len(observed))]

338

339

# Bayesian p-value

340

p_value = np.mean(np.abs(t_pred) >= np.abs(t_obs))

341

print(f"Bayesian p-value for mean difference: {p_value:.3f}")

342

```

343

344

## Advanced Diagnostics

345

346

### Energy Diagnostics

347

348

```python { .api }

349

# Access sampler statistics for energy diagnostics

350

sampler_stats = trace.get_sampler_stats()

351

352

# Energy statistics

353

energy = sampler_stats['energy']

354

energy_diff = np.diff(energy, axis=1) # Energy differences between steps

355

356

# Check for energy problems

357

mean_energy_diff = energy_diff.mean()

358

if abs(mean_energy_diff) > 0.2:

359

print(f"Warning: Large energy differences (mean = {mean_energy_diff:.3f})")

360

361

# Divergences

362

diverging = sampler_stats['diverging']

363

n_diverging = diverging.sum()

364

if n_diverging > 0:

365

print(f"Warning: {n_diverging} divergent transitions detected")

366

367

# Tree depth

368

treedepth = sampler_stats['treedepth']

369

max_treedepth = sampler_stats['max_treedepth']

370

saturated_trees = (treedepth >= max_treedepth).sum()

371

if saturated_trees > 0:

372

print(f"Warning: {saturated_trees} saturated trees (increase max_treedepth)")

373

```

374

375

### Custom Diagnostics

376

377

```python { .api }

378

def compute_split_rhat(trace, var_name):

379

"""Compute split R-hat manually for understanding."""

380

381

# Get samples for variable

382

samples = trace.posterior[var_name].values # Shape: (chains, draws, ...)

383

n_chains, n_draws = samples.shape[:2]

384

385

# Split each chain in half

386

first_half = samples[:, :n_draws//2]

387

second_half = samples[:, n_draws//2:]

388

389

# Combine split chains

390

split_samples = np.concatenate([first_half, second_half], axis=0)

391

392

# Between-chain variance

393

chain_means = split_samples.mean(axis=1)

394

overall_mean = chain_means.mean()

395

B = n_draws//2 * np.var(chain_means, ddof=1)

396

397

# Within-chain variance

398

chain_vars = split_samples.var(axis=1, ddof=1)

399

W = chain_vars.mean()

400

401

# Marginal posterior variance estimate

402

var_hat = (n_draws//2 - 1) / (n_draws//2) * W + B / (n_draws//2)

403

404

# R-hat

405

rhat = np.sqrt(var_hat / W)

406

407

return rhat

408

409

# Usage

410

manual_rhat = compute_split_rhat(trace, 'alpha')

411

print(f"Manual R-hat calculation: {manual_rhat:.4f}")

412

```

413

414

### Rank Normalization Diagnostics

415

416

```python { .api }

417

def rank_normalized_split_rhat(data, var_names=None):

418

"""

419

Compute rank-normalized split R-hat (more robust version).

420

421

Parameters:

422

- data: InferenceData object

423

- var_names (list, optional): Variables to analyze

424

425

Returns:

426

- rhat_rank: Rank-normalized R-hat values

427

"""

428

429

# More robust R-hat using rank normalization

430

rhat_rank = pm.rank_normalized_split_rhat(trace)

431

print("Rank-normalized R-hat:")

432

for var, rhat_val in rhat_rank.items():

433

print(f" {var}: {rhat_val:.4f}")

434

if rhat_val > 1.01:

435

print(f" Warning: {var} may not have converged")

436

```

437

438

## Diagnostic Workflows

439

440

### Comprehensive Convergence Check

441

442

```python { .api }

443

def full_convergence_check(trace, model_name="Model"):

444

"""Comprehensive convergence assessment."""

445

446

print(f"=== Convergence Diagnostics for {model_name} ===")

447

448

# R-hat

449

rhat_vals = pm.rhat(trace)

450

max_rhat = max(rhat_vals.values())

451

print(f"Max R-hat: {max_rhat:.4f}")

452

453

# Effective sample size

454

ess_bulk = pm.ess(trace, method='bulk')

455

ess_tail = pm.ess(trace, method='tail')

456

min_ess_bulk = min(ess_bulk.values())

457

min_ess_tail = min(ess_tail.values())

458

print(f"Min ESS (bulk): {min_ess_bulk:.0f}")

459

print(f"Min ESS (tail): {min_ess_tail:.0f}")

460

461

# Sampler diagnostics

462

n_diverging = trace.get_sampler_stats('diverging').sum()

463

print(f"Diverging transitions: {n_diverging}")

464

465

# Overall assessment

466

converged = (max_rhat < 1.01 and min_ess_bulk > 400 and

467

min_ess_tail > 400 and n_diverging == 0)

468

469

print(f"Overall convergence: {'✓ PASS' if converged else '✗ FAIL'}")

470

471

return converged

472

473

# Usage

474

convergence_ok = full_convergence_check(trace, "Regression Model")

475

```

476

477

### Model Quality Assessment

478

479

```python { .api }

480

def assess_model_quality(trace, observed_data, model):

481

"""Comprehensive model quality assessment."""

482

483

print("=== Model Quality Assessment ===")

484

485

# Information criteria

486

loo_result = pm.loo(trace)

487

waic_result = pm.waic(trace)

488

489

print(f"LOO ELPD: {loo_result.elpd_loo:.2f} ± {loo_result.se:.2f}")

490

print(f"WAIC ELPD: {waic_result.elpd_waic:.2f} ± {waic_result.se:.2f}")

491

492

# Check for high Pareto k values

493

high_k = (loo_result.pareto_k > 0.7).sum()

494

if high_k > 0:

495

print(f"Warning: {high_k} observations have unreliable LOO estimates")

496

497

# Posterior predictive checks

498

post_pred = pm.sample_posterior_predictive(trace, model=model)

499

500

# Simple residual check

501

y_obs = observed_data

502

y_pred_mean = post_pred.posterior_predictive['y_obs'].mean(dim=['chain', 'draw'])

503

residuals = y_obs - y_pred_mean

504

505

print(f"Mean absolute residual: {np.abs(residuals).mean():.3f}")

506

print(f"Residual std: {residuals.std():.3f}")

507

508

return loo_result, waic_result, residuals

509

510

# Usage

511

loo, waic, residuals = assess_model_quality(trace, y_data, model)

512

```

513

514

PyMC's statistics and diagnostics framework, built on ArviZ integration, provides essential tools for validating Bayesian models and ensuring reliable inference results.