or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

diagnostics.md, distributions.md, handlers.md, index.md, inference.md, optimization.md, primitives.md, utilities.md

docs/diagnostics.md

0

# Diagnostics

1

2

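NumPyro provides diagnostic utilities for assessing MCMC convergence, computing effective sample sizes, and summarizing posterior distributions. These tools are essential for validating Bayesian inference results and ensuring reliable posterior estimates. Note that the `numpyro.diagnostics` module itself exports `autocorrelation`, `autocovariance`, `effective_sample_size`, `gelman_rubin`, `split_gelman_rubin`, `hpdi`, `summary`, and `print_summary`; the other helpers described below are not part of that module.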

3

4

## Capabilities

5

6

### Convergence Diagnostics

7

8

Functions for assessing MCMC chain convergence and mixing.

9

10

```python { .api }

11

def gelman_rubin(x: NDArray) -> NDArray:
    """
    Compute the Gelman-Rubin convergence diagnostic (R-hat statistic).

    Assesses convergence by comparing the between-chain variance with the
    within-chain variance. Values close to 1.0 indicate convergence; values
    > 1.1 suggest the chains have not converged.

    Args:
        x: MCMC samples with shape (num_chains, num_samples, ...). Axis 0 is
            the chain axis and axis 1 the draw axis. At least two chains and
            at least two draws per chain are required.

    Returns:
        R-hat statistic for each parameter, with shape x.shape[2:] (a scalar
        for a scalar parameter). Values near 1.0 indicate convergence.

    Usage:
        # Get chain-grouped samples from MCMC
        mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000, num_chains=4)
        mcmc.run(rng_key, data)
        samples = mcmc.get_samples(group_by_chain=True)

        # Compute R-hat for each parameter
        rhat = numpyro.diagnostics.gelman_rubin(samples['theta'])
        print(f"R-hat for theta: {rhat}")

        # Check convergence (should be < 1.1)
        converged = jnp.all(rhat < 1.1)
    """

38

39

def split_gelman_rubin(x: NDArray) -> NDArray:
    """
    Compute the split Gelman-Rubin diagnostic (split R-hat).

    Splits every chain into a first and a second half and computes R-hat over
    the resulting 2 * num_chains half-chains. Comparing the two halves of a
    chain also exposes trends within that chain, which plain R-hat can miss.
    Because splitting produces at least two half-chains, the statistic can
    also be computed from a single chain.

    Args:
        x: MCMC samples with shape (num_chains, num_samples, ...). Axis 0 is
            the chain axis and axis 1 the draw axis. At least four draws per
            chain are required.

    Returns:
        Split R-hat statistic for each parameter, with shape x.shape[2:].

    Usage:
        # More robust convergence assessment
        split_rhat = numpyro.diagnostics.split_gelman_rubin(samples['theta'])
        print(f"Split R-hat: {split_rhat}")

        # This is generally more reliable than regular R-hat
        converged = jnp.all(split_rhat < 1.1)
    """

60

61

def effective_sample_size(x: NDArray) -> NDArray:
    """
    Estimate the effective sample size (ESS) of MCMC draws.

    Successive MCMC draws are autocorrelated, so they carry less information
    than the same number of independent draws. The ESS is the number of
    independent draws that would estimate the posterior with equal precision.

    Args:
        x: MCMC samples shaped (num_chains, num_samples, ...)

    Returns:
        The effective sample size of every parameter

    Usage:
        # How efficiently did the sampler explore theta?
        ess = numpyro.diagnostics.effective_sample_size(samples['theta'])
        print(f"Effective sample size: {ess}")

        # Heuristics: an ESS below 100 is unreliable; above 400 is usually good
        num_chains, num_draws = samples['theta'].shape[:2]
        efficiency = ess / (num_chains * num_draws)
        print(f"Sampling efficiency: {efficiency:.2%}")
    """

85

```

86

87

### Autocorrelation Analysis

88

89

Functions for analyzing temporal correlations in MCMC samples.

90

91

```python { .api }

92

def autocorrelation(x: NDArray, axis: int = 0) -> NDArray:
    """
    Compute the autocorrelation function of a sequence of MCMC draws.

    Measures how strongly the series correlates with lagged copies of itself.
    Slowly decaying autocorrelation means the sampler explores the posterior
    slowly.

    Args:
        x: Samples, e.g. with shape (num_samples,) or (num_samples, num_features)
        axis: Axis holding the draws (default: 0). Use axis=1 for
            chain-grouped samples of shape (num_chains, num_samples, ...).

    Returns:
        Autocorrelation values with the same shape as x. Position k along
        `axis` holds the autocorrelation at lag k (lag 0 equals 1).

    Usage:
        # Analyze one chain at a time: flattening several chains into one
        # series would splice unrelated draws together at chain boundaries.
        chain_samples = samples['theta'][0]  # first chain, (num_samples,)
        autocorr = numpyro.diagnostics.autocorrelation(chain_samples)

        # Or all chains at once, along the draw axis
        autocorr_all = numpyro.diagnostics.autocorrelation(samples['theta'], axis=1)

        # Plot autocorrelation to assess mixing
        import matplotlib.pyplot as plt
        plt.plot(autocorr[:100])  # First 100 lags
        plt.xlabel('Lag')
        plt.ylabel('Autocorrelation')
        plt.title('MCMC Autocorrelation')
    """

118

119

def autocovariance(x: NDArray, axis: int = 0) -> NDArray:
    """
    Compute the autocovariance function of a sequence of MCMC draws.

    Same as the autocorrelation function but not normalized by the variance,
    so values keep the scale of the parameter (lag 0 is the sample variance).

    Args:
        x: Samples, e.g. with shape (num_samples,) or (num_samples, num_features)
        axis: Axis holding the draws (default: 0). Use axis=1 for
            chain-grouped samples of shape (num_chains, num_samples, ...).

    Returns:
        Autocovariance values with the same shape as x; position k along
        `axis` holds the autocovariance at lag k.

    Usage:
        # Compute the autocovariance of a single chain
        chain_samples = samples['theta'][0]  # first chain, (num_samples,)
        autocov = numpyro.diagnostics.autocovariance(chain_samples)

        # First value is the variance
        variance = autocov[0]
        print(f"Sample variance: {variance}")
    """

141

```

142

143

### Posterior Summary Statistics

144

145

Functions for summarizing posterior distributions.

146

147

```python { .api }

148

def hpdi(x: NDArray, prob: float = 0.9, axis: int = 0) -> NDArray:
    """
    Compute the Highest Posterior Density Interval (HPDI).

    The HPDI is the narrowest interval that contains `prob` of the probability
    mass. For skewed distributions it is more informative than an
    equal-tailed interval.

    Args:
        x: Posterior samples, with the draws along `axis`
        prob: Probability mass to include in the interval (default: 0.9)
        axis: Axis holding the draws (default: 0)

    Returns:
        Array whose `axis` dimension has size 2: index 0 holds the lower
        bounds and index 1 the upper bounds. With the default axis=0 the
        result has shape (2,) + x.shape[1:].

    Usage:
        # Draws must lie along `axis`, so flatten chain-grouped samples first
        theta = samples['theta'].reshape(-1)  # (num_chains * num_samples,)

        # 90% highest posterior density interval
        hpdi_90 = numpyro.diagnostics.hpdi(theta, prob=0.9)
        print(f"90% HPDI: [{hpdi_90[0]:.3f}, {hpdi_90[1]:.3f}]")

        # 95% HPDI for comparison
        hpdi_95 = numpyro.diagnostics.hpdi(theta, prob=0.95)
        print(f"95% HPDI: [{hpdi_95[0]:.3f}, {hpdi_95[1]:.3f}]")

        # For a vector parameter grouped by chain: (num_chains, num_samples, num_parameters)
        weights = samples['weights'].reshape(-1, samples['weights'].shape[-1])
        multivar_hpdi = numpyro.diagnostics.hpdi(weights, prob=0.9)
        # Shape: (2, num_parameters); row 0 = lower bounds, row 1 = upper bounds
    """

176

177

def print_summary(samples: dict, prob: float = 0.9, group_by_chain: bool = True) -> None:
    """
    Print summary statistics for posterior samples as a formatted table.

    For every parameter, prints the mean, standard deviation, median, HPDI
    bounds, effective sample size (n_eff), and R-hat.

    Args:
        samples: Dictionary of posterior samples from MCMC
        prob: Probability mass of the HPDI (default: 0.9). The two interval
            columns are labelled by percentile: 5.0%/95.0% for prob=0.9 and
            2.5%/97.5% for prob=0.95.
        group_by_chain: Whether the leading axis of every value is the chain
            axis, i.e. values have shape (num_chains, num_samples, ...). If
            False, all draws are treated as a single chain.

    Usage:
        # Get samples and print summary
        mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000, num_chains=4)
        mcmc.run(rng_key, data)
        samples = mcmc.get_samples(group_by_chain=True)

        # Print the summary table
        numpyro.diagnostics.print_summary(samples, prob=0.95)

        # Output format (illustrative values):
        #                 mean       std    median      2.5%     97.5%     n_eff     r_hat
        #      theta      1.23      0.45      1.20      0.35      2.11    892.50      1.00
        #      sigma      2.34      0.12      2.33      2.10      2.58   1205.20      1.00
    """

203

204

def summary(samples: dict, prob: float = 0.9, group_by_chain: bool = True) -> dict:
    """
    Compute summary statistics for posterior samples without printing.

    Args:
        samples: Dictionary of posterior samples
        prob: Probability mass of the HPDI (default: 0.9)
        group_by_chain: Whether the leading axis of every value is the chain
            axis, i.e. values have shape (num_chains, num_samples, ...). If
            False, all draws are treated as a single chain.

    Returns:
        Dictionary mapping each parameter name to a dictionary of statistics
        with keys 'mean', 'std', 'median', two HPDI-bound keys labelled by
        percentile ('5.0%' and '95.0%' for prob=0.9, '2.5%' and '97.5%' for
        prob=0.95), 'n_eff', and 'r_hat'.

    Usage:
        # Get the summary as a dictionary for further processing
        summary_stats = numpyro.diagnostics.summary(samples, prob=0.95)

        for param_name, stats in summary_stats.items():
            print(f"{param_name}:")
            print(f"  Mean: {stats['mean']:.3f}")
            print(f"  Std: {stats['std']:.3f}")
            print(f"  95% HPDI: [{stats['2.5%']:.3f}, {stats['97.5%']:.3f}]")
            print(f"  R-hat: {stats['r_hat']:.3f}")
            print(f"  ESS: {stats['n_eff']:.1f}")
    """

227

```

228

229

### Model Diagnostics

230

231

Functions for chain-level diagnostics: per-chain splitting, the potential scale reduction factor, and rank-plot data.

232

233

```python { .api }

234

def split_by_chain(x: NDArray) -> NDArray:
    """
    Split samples by chain for chain-specific analysis.

    NOTE(review): `numpyro.diagnostics` does not export `split_by_chain`.
    For chain-grouped samples the same per-chain view comes from iterating
    over the leading axis (`for chain in x`). Confirm the source of this
    helper before relying on it.

    Args:
        x: Samples with shape (num_chains, num_samples, ...)

    Returns:
        The per-chain sample arrays, one for each index along the chain
        axis, each shaped (num_samples, ...).
        NOTE(review): annotated as NDArray but described as a list of
        arrays; confirm the actual return type.

    Usage:
        # Analyze chains separately
        chain_samples = numpyro.diagnostics.split_by_chain(samples['theta'])

        for i, chain in enumerate(chain_samples):
            mean_i = jnp.mean(chain)
            print(f"Chain {i} mean: {mean_i:.3f}")

        # Equivalent without this helper
        for i, chain in enumerate(samples['theta']):
            print(f"Chain {i} mean: {jnp.mean(chain):.3f}")
    """

252

253

def potential_scale_reduction(x: NDArray, split_chains: bool = True) -> NDArray:
    """
    Compute the potential scale reduction factor (PSRF).

    Also known as R-hat. It is the square root of the ratio of two variance
    estimates. The numerator is the pooled estimate of the posterior
    variance, which combines the within-chain and between-chain variances.
    The denominator is the average within-chain variance. Values near 1.0
    mean the chains agree. Values well above 1.0 mean further sampling could
    still shrink the pooled scale estimate, i.e. the chains have not
    converged.

    NOTE(review): `numpyro.diagnostics` does not export this name. The
    equivalents there are `split_gelman_rubin(x)` (split_chains=True) and
    `gelman_rubin(x)` (split_chains=False). Confirm before relying on it.

    Args:
        x: MCMC samples with shape (num_chains, num_samples, ...)
        split_chains: Whether to split each chain in half first (split
            R-hat). Splitting also detects trends within a chain.

    Returns:
        PSRF values for each parameter

    Usage:
        # Alternative interface to gelman_rubin
        psrf = numpyro.diagnostics.potential_scale_reduction(samples['theta'])
        print(f"PSRF: {psrf}")
    """

272

273

def rank_plot_data(samples: dict, param_names: Optional[list] = None) -> dict:
    """
    Prepare data for rank plots (for external plotting).

    Rank plots visualize chain mixing. All draws of a parameter are ranked
    jointly across chains, and then each chain's ranks are histogrammed. For
    well-mixed chains every histogram is approximately uniform.

    NOTE(review): `numpyro.diagnostics` does not export `rank_plot_data`.
    For one scalar parameter with chain-grouped samples `x`, joint ranks can
    be computed with JAX as
    `jnp.argsort(jnp.argsort(x.reshape(-1))).reshape(x.shape)` (ties broken
    by position). `Optional` must also be imported from `typing` for this
    signature. Confirm before relying on it.

    Args:
        samples: Dictionary of chain-grouped MCMC samples
        param_names: Names of the parameters to include (None presumably
            selects all parameters; confirm)

    Returns:
        Dictionary mapping each parameter name to its per-chain rank arrays

    Usage:
        # Prepare data for rank plots
        rank_data = numpyro.diagnostics.rank_plot_data(samples, ['theta', 'sigma'])

        # Use with external plotting library
        import matplotlib.pyplot as plt
        for param, ranks in rank_data.items():
            plt.figure()
            for chain_ranks in ranks:
                plt.hist(chain_ranks, alpha=0.5, bins=50)
            plt.title(f"Rank plot for {param}")
    """

299

```

300

301

### Diagnostic Utilities

302

303

Helper functions for diagnostic computations.

304

305

```python { .api }

306

def within_chain_variance(x: NDArray) -> NDArray:
    """
    Compute the within-chain variance W used by R-hat.

    In the standard Gelman-Rubin definition, W is the mean over chains of
    each chain's sample variance (ddof=1). An equivalent JAX expression is
    `jnp.mean(jnp.var(x, axis=1, ddof=1), axis=0)`.

    NOTE(review): `numpyro.diagnostics` does not export this name; confirm
    before relying on it.

    Args:
        x: MCMC samples with shape (num_chains, num_samples, ...)

    Returns:
        Within-chain variance for each parameter, shaped x.shape[2:]
    """

316

317

def between_chain_variance(x: NDArray) -> NDArray:
    """
    Compute the between-chain variance B used by R-hat.

    In the standard Gelman-Rubin definition, B is num_samples times the
    sample variance (ddof=1) of the per-chain means. With this scaling, B is
    comparable to the within-chain variance W. An equivalent JAX expression
    is `x.shape[1] * jnp.var(jnp.mean(x, axis=1), axis=0, ddof=1)`.

    NOTE(review): `numpyro.diagnostics` does not export this name; confirm
    before relying on it.

    Args:
        x: MCMC samples with shape (num_chains, num_samples, ...)

    Returns:
        Between-chain variance for each parameter, shaped x.shape[2:]
    """

327

328

def integrated_autocorr_time(x: NDArray, c: float = 5.0,
                             tol: float = 50.0, quiet: bool = False) -> float:
    """
    Compute the integrated autocorrelation time tau.

    tau sums the autocorrelation function over a finite window of lags, since
    the noisy tail would make the full sum unreliable. N correlated draws are
    worth roughly N / tau independent ones, so tau ~= N / ESS.

    NOTE(review): `numpyro.diagnostics` does not export this name. The
    parameters match those of emcee's `autocorr.integrated_time` (c, tol,
    quiet); if this helper follows emcee, the window is the smallest lag M
    with M >= c * tau, and the estimate is flagged unreliable when the series
    is shorter than tol * tau (warning instead of raising when quiet=True).
    Confirm before relying on it.

    Args:
        x: Time series data (a single chain)
        c: Window size multiplier for automatic windowing
        tol: Minimum series length, in units of tau, for a trusted estimate
        quiet: Whether to warn instead of failing on unreliable estimates

    Returns:
        Integrated autocorrelation time

    Usage:
        # Estimate correlation time of one chain (concatenating chains would
        # splice unrelated draws together)
        chain_samples = samples['theta'][0]
        tau = numpyro.diagnostics.integrated_autocorr_time(chain_samples)
        print(f"Autocorrelation time: {tau:.2f}")

        # Rule of thumb: need at least 50*tau samples for reliable estimates
        min_samples = 50 * tau
        actual_samples = len(chain_samples)
        print(f"Recommended samples: {min_samples:.0f}, Actual: {actual_samples}")
    """

356

357

def compute_chain_statistics(x: NDArray) -> dict:
    """
    Compute summary statistics for each individual chain.

    NOTE(review): `numpyro.diagnostics` does not export this name, and the
    per-chain keys used below ('mean', 'var', 'ess') are assumed; confirm.
    Without this helper, the same figures can be computed per chain, e.g.
    `jnp.mean(x[i])`, `jnp.var(x[i])`, and
    `numpyro.diagnostics.effective_sample_size(x[i][None])`.

    Args:
        x: MCMC samples with shape (num_chains, num_samples, ...)

    Returns:
        Dictionary mapping each chain id to that chain's statistics

    Usage:
        # Analyze individual chain performance
        chain_stats = numpyro.diagnostics.compute_chain_statistics(samples['theta'])

        for chain_id, stats in chain_stats.items():
            print(f"Chain {chain_id}:")
            print(f"  Mean: {stats['mean']:.3f}")
            print(f"  Variance: {stats['var']:.3f}")
            print(f"  ESS: {stats['ess']:.1f}")
    """

377

```

378

379

## Usage Examples

380

381

```python

382

import numpyro

383

import numpyro.distributions as dist

384

from numpyro.infer import MCMC, NUTS

385

import numpyro.diagnostics as diagnostics

386

import jax.numpy as jnp

387

from jax import random

388

389

# Comprehensive diagnostic workflow

390

def diagnostic_workflow_example():
    """Fit a small linear regression with NUTS and print a diagnostic report.

    The report covers the summary table, R-hat and split R-hat, effective
    sample size and sampling efficiency, HPD intervals, and autocorrelation
    for every sample site. Only the public ``numpyro.diagnostics`` functions
    are used.

    Returns:
        dict: Posterior samples grouped by chain; each value has shape
        (num_chains, num_samples) because every site here is a scalar.
    """

    # Linear regression: y ~ Normal(alpha + beta * x, sigma)
    def model(x, y=None):
        alpha = numpyro.sample("alpha", dist.Normal(0, 1))
        beta = numpyro.sample("beta", dist.Normal(0, 1))
        sigma = numpyro.sample("sigma", dist.Exponential(1))

        mu = alpha + beta * x
        with numpyro.plate("data", len(x)):
            numpyro.sample("y", dist.Normal(mu, sigma), obs=y)

    # Generate synthetic data from known parameters
    key = random.PRNGKey(0)
    n_data = 100
    x = jnp.linspace(0, 1, n_data)
    true_alpha, true_beta, true_sigma = 1.0, 2.0, 0.1
    y = true_alpha + true_beta * x + true_sigma * random.normal(key, (n_data,))

    # Run MCMC with multiple chains
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4)
    mcmc.run(random.PRNGKey(1), x, y)

    # Chain-grouped samples: R-hat and ESS need the chain axis
    samples = mcmc.get_samples(group_by_chain=True)

    print("=== MCMC Diagnostic Report ===")

    # 1. Print comprehensive summary
    print("\n1. Summary Statistics:")
    diagnostics.print_summary(samples, prob=0.95)

    # 2. Check convergence with R-hat
    print("\n2. Convergence Diagnostics:")
    for param_name, param_samples in samples.items():
        rhat = diagnostics.gelman_rubin(param_samples)
        split_rhat = diagnostics.split_gelman_rubin(param_samples)

        print(f"{param_name}:")
        print(f" R-hat: {rhat:.4f}")
        print(f" Split R-hat: {split_rhat:.4f}")
        print(f" Converged (R-hat < 1.1): {rhat < 1.1}")

    # 3. Assess sampling efficiency; ESS is kept for the tau estimate in step 5
    print("\n3. Sampling Efficiency:")
    ess_by_param = {}
    for param_name, param_samples in samples.items():
        total_samples = param_samples.shape[0] * param_samples.shape[1]
        ess = diagnostics.effective_sample_size(param_samples)
        ess_by_param[param_name] = ess
        efficiency = ess / total_samples

        print(f"{param_name}:")
        print(f" ESS: {ess:.1f}")
        print(f" Efficiency: {efficiency:.2%}")
        print(f" Good ESS (>400): {ess > 400}")

    # 4. Posterior intervals: hpdi expects draws along axis 0, so pool the chains
    print("\n4. Posterior Intervals:")
    flat_samples = {k: v.reshape(-1) for k, v in samples.items()}

    for param_name, param_samples in flat_samples.items():
        hpdi_90 = diagnostics.hpdi(param_samples, prob=0.9)
        hpdi_95 = diagnostics.hpdi(param_samples, prob=0.95)

        print(f"{param_name}:")
        print(f" 90% HPDI: [{hpdi_90[0]:.3f}, {hpdi_90[1]:.3f}]")
        print(f" 95% HPDI: [{hpdi_95[0]:.3f}, {hpdi_95[1]:.3f}]")

    # 5. Autocorrelation analysis. Autocorrelation is computed per chain along
    # the draw axis (pooling chains would splice unrelated draws together).
    # The integrated autocorrelation time follows from ESS = N / tau.
    print("\n5. Autocorrelation Analysis:")
    for param_name, param_samples in samples.items():
        total_samples = param_samples.shape[0] * param_samples.shape[1]
        lag1 = jnp.mean(diagnostics.autocorrelation(param_samples, axis=1)[:, 1])
        tau = total_samples / ess_by_param[param_name]

        print(f"{param_name}:")
        print(f" Lag-1 autocorr (mean over chains): {lag1:.3f}")
        print(f" Autocorr time: {tau:.2f}")
        print(f" Recommended min samples: {50 * tau:.0f}")
        print(f" Actual samples: {total_samples}")

    return samples

470

471

# Chain-specific diagnostics

472

def chain_analysis_example(samples=None):
    """Print per-chain statistics and compare within- and between-chain variance.

    Uses only public ``numpyro.diagnostics`` functions plus ``jax.numpy``.

    Args:
        samples: Chain-grouped posterior samples containing an ``"alpha"``
            entry of shape (num_chains, num_samples), such as the dict
            returned by ``diagnostic_workflow_example``. When omitted, that
            example is run to produce them.
    """
    if samples is None:
        samples = diagnostic_workflow_example()

    # Analyze individual chains
    param_samples = samples["alpha"]  # Shape: (num_chains, num_samples)
    num_draws = param_samples.shape[1]

    print("=== Individual Chain Analysis ===")

    # Iterating over the leading (chain) axis yields one 1-D array per chain
    for i, chain in enumerate(param_samples):
        mean_i = jnp.mean(chain)
        std_i = jnp.std(chain)
        autocorr_i = diagnostics.autocorrelation(chain)

        print(f"\nChain {i}:")
        print(f" Mean: {mean_i:.4f}")
        print(f" Std: {std_i:.4f}")
        print(f" First 5 autocorr values: {autocorr_i[:5]}")

    # W: average of the per-chain sample variances
    within_var = jnp.mean(jnp.var(param_samples, axis=1, ddof=1))
    # B: num_draws * variance of the chain means (Gelman-Rubin convention).
    # With this scaling, B/W is close to 1 for well-mixed chains whose draws
    # are only weakly autocorrelated.
    between_var = num_draws * jnp.var(jnp.mean(param_samples, axis=1), ddof=1)

    print("\nVariance Analysis:")
    print(f" Within-chain variance: {within_var:.6f}")
    print(f" Between-chain variance: {between_var:.6f}")
    print(f" Ratio (should be ~1): {between_var / within_var:.4f}")
    print(f" R-hat: {diagnostics.gelman_rubin(param_samples):.4f}")

502

503

# Diagnostic-driven sampling strategy

504

def adaptive_sampling_example(initial_samples=500, target_ess=400, max_rhat=1.1,
                              max_iterations=5):
    """Example of using diagnostics to determine sampling requirements.

    Repeatedly runs 4 NUTS chains and increases the number of warmup and
    sampling steps per chain after each run. It stops when every sample site
    has R-hat below ``max_rhat`` and effective sample size above
    ``target_ess``, or after ``max_iterations`` runs.

    Args:
        initial_samples: Warmup and sampling steps per chain for the first run.
        target_ess: Minimum effective sample size required for every site.
        max_rhat: R-hat threshold below which a site counts as converged.
        max_iterations: Maximum number of MCMC runs; must be at least 1.

    Returns:
        dict: Chain-grouped samples from the last run, whether or not it met
        the criteria.

    Raises:
        ValueError: If ``max_iterations`` is less than 1.
    """
    if max_iterations < 1:
        raise ValueError("max_iterations must be at least 1")

    def model():
        # Deliberately create a challenging (curved) joint distribution
        x = numpyro.sample("x", dist.Normal(0, 1))
        numpyro.sample("y", dist.Normal(x**2, 0.1))  # Non-linear relationship

    num_samples = initial_samples
    for iteration in range(max_iterations):
        print(f"\n--- Iteration {iteration + 1} ---")

        # Run MCMC
        mcmc = MCMC(NUTS(model),
                    num_warmup=num_samples,
                    num_samples=num_samples,
                    num_chains=4)
        mcmc.run(random.PRNGKey(iteration))
        samples = mcmc.get_samples(group_by_chain=True)

        # Judge the run by its worst site. jnp.max / jnp.min propagate NaN,
        # so a NaN R-hat fails the convergence check instead of being skipped.
        site_rhats = jnp.array([jnp.max(diagnostics.gelman_rubin(v)) for v in samples.values()])
        site_ess = jnp.array([jnp.min(diagnostics.effective_sample_size(v)) for v in samples.values()])
        rhat = float(jnp.max(site_rhats))
        ess = float(jnp.min(site_ess))

        print(f"Current samples per chain: {num_samples}")
        print(f"R-hat: {rhat:.4f}")
        print(f"ESS: {ess:.1f}")

        # Check if we meet convergence criteria
        converged = rhat < max_rhat
        sufficient_ess = ess > target_ess

        if converged and sufficient_ess:
            print("✓ Convergence achieved!")
            break
        if not converged:
            print(f"✗ Poor convergence (R-hat = {rhat:.4f})")
            num_samples = int(num_samples * 1.5)  # Larger increase when chains disagree
        else:
            print(f"✗ Insufficient ESS ({ess:.1f} < {target_ess})")
            num_samples = int(num_samples * 1.2)  # Modest increase
    else:
        print(f"✗ Criteria not met after {max_iterations} iterations")

    return samples

552

```

553

554

## Types

555

556

```python { .api }

557

from typing import Optional, Union, Dict, Any, List

558

from jax import Array

559

import jax.numpy as jnp

560

561

# Array type used throughout the signatures above
NDArray = jnp.ndarray
# Posterior samples keyed by sample-site name
Samples = Dict[str, NDArray]
# Anything accepted where an array is expected, including Python scalars
ArrayLike = Union[Array, NDArray, float, int]

564

565

class DiagnosticResult:
    """Marker base type for objects returned by diagnostic functions."""

568

569

class SummaryStats:
    """Per-parameter posterior summary: location, spread, interval, and quality."""

    # Location and spread
    mean: float
    std: float
    median: float
    mad: float  # median absolute deviation
    # Bounds of the highest posterior density interval
    hpdi_lower: float
    hpdi_upper: float
    # Sampling quality
    n_eff: float  # effective sample size
    r_hat: float  # Gelman-Rubin statistic

579

580

class ConvergenceDiagnostic:
    """Bundle of convergence statistics together with a single verdict flag."""

    # R-hat per parameter
    r_hat: NDArray
    # split R-hat per parameter
    split_r_hat: NDArray
    # single boolean convergence flag
    converged: bool
    # potential scale reduction factor per parameter
    potential_scale_reduction: NDArray

586

587

class EfficiencyDiagnostic:
    """Bundle of statistics describing how efficiently the sampler explored."""

    # ESS per parameter
    effective_sample_size: NDArray
    # integrated autocorrelation time per parameter
    autocorrelation_time: NDArray
    # ESS relative to the number of draws
    efficiency_ratio: NDArray

592

593

class AutocorrelationResult:
    """Output of an autocorrelation analysis of one series of draws."""

    # normalized autocorrelation by lag
    autocorr: NDArray
    # unnormalized autocovariance by lag
    autocov: NDArray
    # integrated autocorrelation time (scalar)
    integrated_time: float

598

599

class ChainStatistics:
    """Summary of a single MCMC chain, identified by its index."""

    # index of the chain
    chain_id: int
    # per-parameter mean of this chain
    mean: NDArray
    # per-parameter variance of this chain
    variance: NDArray
    # per-parameter ESS of this chain
    effective_sample_size: NDArray
    # integrated autocorrelation time of this chain (scalar)
    autocorrelation_time: float

606

607

# Function type signatures

608

# Callable is not among the names imported by this section's `from typing import ...`
# line, so import it here to avoid a NameError.
from typing import Callable

# e.g. gelman_rubin / split_gelman_rubin: samples -> per-parameter statistic
ConvergenceFunction = Callable[[NDArray], NDArray]
# samples -> dictionary of named summary statistics
SummaryFunction = Callable[[NDArray], Dict[str, Any]]
# samples -> structured diagnostic result object
DiagnosticFunction = Callable[[NDArray], DiagnosticResult]

611

```