# Diagnostics

NumPyro provides comprehensive diagnostic utilities for assessing MCMC convergence, computing effective sample sizes, and summarizing posterior distributions. These tools are essential for validating the quality of Bayesian inference results and ensuring reliable posterior estimates.

## Capabilities

### Convergence Diagnostics

Functions for assessing MCMC chain convergence and mixing.

```python { .api }
def gelman_rubin(x: NDArray) -> NDArray:
    """
    Compute Gelman-Rubin convergence diagnostic (R-hat statistic).

    Assesses convergence by comparing within-chain and between-chain variances.
    Values close to 1.0 indicate convergence; values > 1.1 suggest lack of convergence.

    Args:
        x: MCMC samples with shape (num_chains, num_samples, ...); a plain
            (num_chains, num_samples) array is also accepted

    Returns:
        R-hat statistic for each parameter. Values near 1.0 indicate convergence.

    Usage:
        # Get samples from MCMC
        mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000, num_chains=4)
        mcmc.run(rng_key, data)
        samples = mcmc.get_samples(group_by_chain=True)

        # Compute R-hat for each parameter
        rhat = numpyro.diagnostics.gelman_rubin(samples['theta'])
        print(f"R-hat for theta: {rhat}")

        # Check convergence (should be < 1.1)
        converged = jnp.all(rhat < 1.1)
    """

def split_gelman_rubin(x: NDArray) -> NDArray:
    """
    Compute split Gelman-Rubin diagnostic (split R-hat).

    More robust version of R-hat that splits each chain in half to increase
    the number of chains for better convergence assessment.

    Args:
        x: MCMC samples with shape (num_chains, num_samples, ...)

    Returns:
        Split R-hat statistic for each parameter

    Usage:
        # More robust convergence assessment
        split_rhat = numpyro.diagnostics.split_gelman_rubin(samples['theta'])
        print(f"Split R-hat: {split_rhat}")

        # This is generally more reliable than regular R-hat
        converged = jnp.all(split_rhat < 1.1)
    """

def effective_sample_size(x: NDArray) -> NDArray:
    """
    Compute effective sample size (ESS) for MCMC chains.

    ESS estimates the number of independent samples that would provide
    the same statistical power as the correlated MCMC samples.

    Args:
        x: MCMC samples with shape (num_chains, num_samples, ...)

    Returns:
        Effective sample size for each parameter

    Usage:
        # Assess sampling efficiency
        ess = numpyro.diagnostics.effective_sample_size(samples['theta'])
        print(f"Effective sample size: {ess}")

        # Rule of thumb: ESS should be > 100 for reliable estimates
        # ESS > 400 is generally considered good
        total_samples = samples['theta'].shape[0] * samples['theta'].shape[1]
        efficiency = ess / total_samples
        # The ".2%" format below requires a scalar, i.e. a scalar parameter
        print(f"Sampling efficiency: {efficiency:.2%}")
    """
```

### Autocorrelation Analysis

Functions for analyzing temporal correlations in MCMC samples.

```python { .api }
def autocorrelation(x: NDArray, axis: int = 0) -> NDArray:
    """
    Compute autocorrelation function for MCMC chains.

    Measures how correlated a time series is with lagged versions of itself.
    Useful for understanding the temporal structure of MCMC samples.

    Args:
        x: MCMC samples with shape (num_samples,) or (num_samples, num_features)
        axis: Axis that indexes successive draws (default: 0, matching the
            shapes above)

    Returns:
        Autocorrelation function values for different lags

    Usage:
        # Analyze autocorrelation structure
        # First flatten chains if multiple chains
        flat_samples = samples['theta'].reshape(-1)  # (total_samples,)
        autocorr = numpyro.diagnostics.autocorrelation(flat_samples)

        # Plot autocorrelation to assess mixing
        import matplotlib.pyplot as plt
        plt.plot(autocorr[:100])  # First 100 lags
        plt.xlabel('Lag')
        plt.ylabel('Autocorrelation')
        plt.title('MCMC Autocorrelation')
    """

def autocovariance(x: NDArray, axis: int = 0) -> NDArray:
    """
    Compute autocovariance function for MCMC chains.

    Similar to autocorrelation but without normalization, preserving
    the actual variance scale of the correlations.

    Args:
        x: MCMC samples with shape (num_samples,) or (num_samples, num_features)
        axis: Axis that indexes successive draws (default: 0, matching the
            shapes above)

    Returns:
        Autocovariance function values for different lags

    Usage:
        # Compute autocovariance for variance analysis
        flat_samples = samples['theta'].reshape(-1)
        autocov = numpyro.diagnostics.autocovariance(flat_samples)

        # First value (lag 0) is the variance
        variance = autocov[0]
        print(f"Sample variance: {variance}")
    """
```

### Posterior Summary Statistics

Functions for summarizing posterior distributions.

```python { .api }
def hpdi(x: NDArray, prob: float = 0.9, axis: int = 0) -> NDArray:
    """
    Compute Highest Posterior Density Interval (HPDI).

    HPDI is the shortest interval that contains the specified probability mass.
    More informative than equal-tailed intervals for skewed distributions.

    Args:
        x: Posterior samples
        prob: Probability mass to include in interval (default: 0.9)
        axis: Axis along which to compute intervals (default: 0)

    Returns:
        Array in which `axis` has size 2: the lower bound at index 0 and the
        upper bound at index 1. With the default axis=0 the shape is (2, ...).

    Usage:
        # If samples were collected with group_by_chain=True, merge the chain
        # and draw dimensions first; otherwise axis=0 runs over chains
        theta = samples['theta'].reshape(-1)

        # 90% highest posterior density interval
        hpdi_90 = numpyro.diagnostics.hpdi(theta, prob=0.9)
        print(f"90% HPDI: [{hpdi_90[0]:.3f}, {hpdi_90[1]:.3f}]")

        # 95% HPDI for comparison
        hpdi_95 = numpyro.diagnostics.hpdi(theta, prob=0.95)
        print(f"95% HPDI: [{hpdi_95[0]:.3f}, {hpdi_95[1]:.3f}]")

        # For multivariate parameters: merge chains, keep the parameter axis
        weights = samples['weights'].reshape(-1, samples['weights'].shape[-1])
        multivar_hpdi = numpyro.diagnostics.hpdi(weights, prob=0.9)
        # Shape: (2, num_parameters)
    """

def print_summary(samples: dict, prob: float = 0.9, group_by_chain: bool = True) -> None:
    """
    Print comprehensive summary statistics for posterior samples.

    Provides mean, standard deviation, median, HPDI, effective sample size,
    and R-hat for all parameters in a formatted table.

    Args:
        samples: Dictionary of posterior samples from MCMC
        prob: Probability for HPDI computation (default: 0.9)
        group_by_chain: Whether samples are grouped by chain

    Usage:
        # Get samples and print summary
        mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000, num_chains=4)
        mcmc.run(rng_key, data)
        samples = mcmc.get_samples(group_by_chain=True)

        # Print comprehensive summary
        numpyro.diagnostics.print_summary(samples, prob=0.95)

        # Output format (illustrative values). The two interval columns are
        # labelled with the lower/upper percentages implied by `prob`:
        #              mean       std    median      2.5%     97.5%     n_eff     r_hat
        #   theta      1.23      0.45      1.20      0.56      1.91    892.50      1.00
        #   sigma      2.34      0.12      2.33      2.14      2.56   1205.20      1.00
    """

def summary(samples: dict, prob: float = 0.9, group_by_chain: bool = True) -> dict:
    """
    Compute summary statistics for posterior samples without printing.

    Args:
        samples: Dictionary of posterior samples
        prob: Probability for HPDI computation
        group_by_chain: Whether samples are grouped by chain

    Returns:
        Dictionary containing summary statistics for each parameter. The
        usage below reads the 'mean', 'std', 'r_hat' and 'n_eff' entries;
        the HPDI bounds are presumably stored under percentage-labelled keys
        (e.g. '5.0%' / '95.0%' for prob=0.9) - confirm against the
        installed version.

    Usage:
        # Get summary as dictionary for further processing
        summary_stats = numpyro.diagnostics.summary(samples, prob=0.95)

        for param_name, stats in summary_stats.items():
            print(f"{param_name}:")
            print(f" Mean: {stats['mean']:.3f}")
            print(f" Std: {stats['std']:.3f}")
            print(f" R-hat: {stats['r_hat']:.3f}")
            print(f" ESS: {stats['n_eff']:.1f}")
    """
```

### Model Diagnostics

Functions for diagnosing model-specific issues.

```python { .api }
def split_by_chain(x: NDArray) -> NDArray:
    """
    Split samples by chain for chain-specific analysis.

    NOTE(review): this name is not listed in the upstream
    `numpyro.diagnostics` reference - confirm it exists in the targeted
    NumPyro version before relying on it.

    Args:
        x: Samples with shape (num_chains, num_samples, ...)

    Returns:
        List of arrays, one per chain
        (NOTE(review): the signature is annotated `NDArray`; either way the
        usage below only iterates over the result chain by chain.)

    Usage:
        # Analyze chains separately
        chain_samples = numpyro.diagnostics.split_by_chain(samples['theta'])

        for i, chain in enumerate(chain_samples):
            mean_i = jnp.mean(chain)
            print(f"Chain {i} mean: {mean_i:.3f}")
    """

def potential_scale_reduction(x: NDArray, split_chains: bool = True) -> NDArray:
    """
    Compute potential scale reduction factor (PSRF).

    Also known as R-hat: the square root of the ratio of the pooled variance
    estimate (combining between-chain and within-chain variance) to the
    average within-chain variance. It approaches 1 from above as the chains
    converge to the same distribution.

    NOTE(review): this name is not listed in the upstream
    `numpyro.diagnostics` reference - confirm it exists in the targeted
    NumPyro version before relying on it.

    Args:
        x: MCMC samples with shape (num_chains, num_samples, ...)
        split_chains: Whether to split chains for more robust estimates

    Returns:
        PSRF values for each parameter

    Usage:
        # Alternative interface to gelman_rubin
        psrf = numpyro.diagnostics.potential_scale_reduction(samples['theta'])
        print(f"PSRF: {psrf}")
    """

def rank_plot_data(samples: dict, param_names: Optional[list] = None) -> dict:
    """
    Prepare data for rank plots (for external plotting).

    Rank plots help visualize chain mixing by showing the distribution
    of ranks of samples from different chains.

    NOTE(review): this name is not listed in the upstream
    `numpyro.diagnostics` reference - confirm it exists in the targeted
    NumPyro version before relying on it.

    Args:
        samples: Dictionary of MCMC samples
        param_names: List of parameter names to include

    Returns:
        Dictionary with rank data for plotting

    Usage:
        # Prepare data for rank plots
        rank_data = numpyro.diagnostics.rank_plot_data(samples, ['theta', 'sigma'])

        # Use with external plotting library
        import matplotlib.pyplot as plt
        for param, ranks in rank_data.items():
            plt.figure()
            for chain_ranks in ranks:
                plt.hist(chain_ranks, alpha=0.5, bins=50)
            plt.title(f"Rank plot for {param}")
    """
```

### Diagnostic Utilities

Helper functions for diagnostic computations.

```python { .api }
def within_chain_variance(x: NDArray) -> NDArray:
    """
    Compute within-chain variance for R-hat calculation.

    NOTE(review): this name is not listed in the upstream
    `numpyro.diagnostics` reference - confirm it exists in the targeted
    NumPyro version before relying on it.

    Args:
        x: MCMC samples with shape (num_chains, num_samples, ...)

    Returns:
        Within-chain variance for each parameter
    """

def between_chain_variance(x: NDArray) -> NDArray:
    """
    Compute between-chain variance for R-hat calculation.

    NOTE(review): this name is not listed in the upstream
    `numpyro.diagnostics` reference - confirm it exists in the targeted
    NumPyro version before relying on it.

    Args:
        x: MCMC samples with shape (num_chains, num_samples, ...)

    Returns:
        Between-chain variance for each parameter
    """

def integrated_autocorr_time(x: NDArray, c: float = 5.0,
                             tol: float = 50.0, quiet: bool = False) -> float:
    """
    Compute integrated autocorrelation time.

    Estimates the correlation time by integrating the autocorrelation function
    until it becomes unreliable.

    NOTE(review): this name is not listed in the upstream
    `numpyro.diagnostics` reference - confirm it exists in the targeted
    NumPyro version before relying on it.

    Args:
        x: Time series data
        c: Window size multiplier for automatic windowing
        tol: Tolerance for unreliable estimates; presumably the minimum
            number of autocorrelation times the series must span (compare
            the 50*tau rule of thumb below) - confirm
        quiet: Whether to suppress warnings

    Returns:
        Integrated autocorrelation time

    Usage:
        # Estimate correlation time
        flat_samples = samples['theta'].reshape(-1)
        tau = numpyro.diagnostics.integrated_autocorr_time(flat_samples)
        print(f"Autocorrelation time: {tau:.2f}")

        # Rule of thumb: need at least 50*tau samples for reliable estimates
        min_samples = 50 * tau
        actual_samples = len(flat_samples)
        print(f"Recommended samples: {min_samples:.0f}, Actual: {actual_samples}")
    """

def compute_chain_statistics(x: NDArray) -> dict:
    """
    Compute comprehensive statistics for individual chains.

    NOTE(review): this name is not listed in the upstream
    `numpyro.diagnostics` reference - confirm it exists in the targeted
    NumPyro version before relying on it.

    Args:
        x: MCMC samples with shape (num_chains, num_samples, ...)

    Returns:
        Dictionary with statistics for each chain; the usage below reads the
        'mean', 'var' and 'ess' entries of each chain's record

    Usage:
        # Analyze individual chain performance
        chain_stats = numpyro.diagnostics.compute_chain_statistics(samples['theta'])

        for chain_id, stats in chain_stats.items():
            print(f"Chain {chain_id}:")
            print(f" Mean: {stats['mean']:.3f}")
            print(f" Variance: {stats['var']:.3f}")
            print(f" ESS: {stats['ess']:.1f}")
    """
```

## Usage Examples

```python
# Dependencies shared by the usage examples below.
import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.diagnostics as diagnostics
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

# Comprehensive diagnostic workflow
def diagnostic_workflow_example():
    """Fit a small linear regression with NUTS and print a diagnostic report.

    The report covers summary statistics, R-hat / split R-hat, effective
    sample size, HPD intervals and autocorrelation for every sampled site.

    Returns:
        The posterior samples grouped by chain, as returned by
        ``mcmc.get_samples(group_by_chain=True)``.
    """
    # Define a simple model
    def model(x, y=None):
        alpha = numpyro.sample("alpha", dist.Normal(0, 1))
        beta = numpyro.sample("beta", dist.Normal(0, 1))
        sigma = numpyro.sample("sigma", dist.Exponential(1))

        mu = alpha + beta * x
        with numpyro.plate("data", len(x)):
            numpyro.sample("y", dist.Normal(mu, sigma), obs=y)

    # Generate synthetic data
    key = random.PRNGKey(0)
    n_data = 100
    x = jnp.linspace(0, 1, n_data)
    true_alpha, true_beta, true_sigma = 1.0, 2.0, 0.1
    y = true_alpha + true_beta * x + true_sigma * random.normal(key, (n_data,))

    # Run MCMC with multiple chains
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4)
    mcmc.run(random.PRNGKey(1), x, y)

    # Get samples grouped by chain for diagnostics
    samples = mcmc.get_samples(group_by_chain=True)

    print("=== MCMC Diagnostic Report ===")

    # 1. Print comprehensive summary
    print("\n1. Summary Statistics:")
    diagnostics.print_summary(samples, prob=0.95)

    # 2. Check convergence with R-hat
    print("\n2. Convergence Diagnostics:")
    for param_name, param_samples in samples.items():
        rhat = diagnostics.gelman_rubin(param_samples)
        split_rhat = diagnostics.split_gelman_rubin(param_samples)

        print(f"{param_name}:")
        print(f" R-hat: {rhat:.4f}")
        print(f" Split R-hat: {split_rhat:.4f}")
        print(f" Converged (R-hat < 1.1): {rhat < 1.1}")

    # 3. Assess sampling efficiency
    print("\n3. Sampling Efficiency:")
    # Every site shares the same (num_chains, num_samples) leading shape.
    total_samples = samples['alpha'].shape[0] * samples['alpha'].shape[1]

    for param_name, param_samples in samples.items():
        ess = diagnostics.effective_sample_size(param_samples)
        efficiency = ess / total_samples

        print(f"{param_name}:")
        print(f" ESS: {ess:.1f}")
        print(f" Efficiency: {efficiency:.2%}")
        print(f" Good ESS (>400): {ess > 400}")

    # 4. Posterior intervals
    print("\n4. Posterior Intervals:")
    # hpdi works along axis 0, so merge the chain and draw dimensions first.
    flat_samples = {k: v.reshape(-1) for k, v in samples.items()}

    for param_name, param_samples in flat_samples.items():
        hpdi_90 = diagnostics.hpdi(param_samples, prob=0.9)
        hpdi_95 = diagnostics.hpdi(param_samples, prob=0.95)

        print(f"{param_name}:")
        print(f" 90% HPDI: [{hpdi_90[0]:.3f}, {hpdi_90[1]:.3f}]")
        print(f" 95% HPDI: [{hpdi_95[0]:.3f}, {hpdi_95[1]:.3f}]")

    # 5. Autocorrelation analysis
    print("\n5. Autocorrelation Analysis:")
    for param_name, param_samples in flat_samples.items():
        autocorr = diagnostics.autocorrelation(param_samples)
        # NOTE(review): integrated_autocorr_time is documented in this file
        # but not listed in the upstream numpyro.diagnostics reference.
        tau = diagnostics.integrated_autocorr_time(param_samples, quiet=True)

        print(f"{param_name}:")
        print(f" Lag-1 autocorrelation: {autocorr[1]:.3f}")
        print(f" Autocorr time: {tau:.2f}")
        print(f" Recommended min samples: {50 * tau:.0f}")
        print(f" Actual samples: {len(param_samples)}")

    return samples

# Chain-specific diagnostics
def chain_analysis_example(samples=None):
    """Inspect each chain of the ``alpha`` parameter on its own.

    Args:
        samples: Posterior samples grouped by chain, e.g. the return value of
            ``diagnostic_workflow_example()``. When omitted, that workflow is
            run to produce them (the snippet previously read an undefined
            ``samples`` name and raised ``NameError``).
    """
    if samples is None:
        samples = diagnostic_workflow_example()

    # Analyze individual chains
    param_samples = samples['alpha']  # Shape: (num_chains, num_samples)

    print("=== Individual Chain Analysis ===")

    # Split by chain and analyze separately
    chain_samples = diagnostics.split_by_chain(param_samples)

    for i, chain in enumerate(chain_samples):
        mean_i = jnp.mean(chain)
        std_i = jnp.std(chain)
        autocorr_i = diagnostics.autocorrelation(chain)

        print(f"\nChain {i}:")
        print(f" Mean: {mean_i:.4f}")
        print(f" Std: {std_i:.4f}")
        print(f" First 5 autocorr values: {autocorr_i[:5]}")

    # Compare within vs between chain variance
    within_var = diagnostics.within_chain_variance(param_samples)
    between_var = diagnostics.between_chain_variance(param_samples)

    print("\nVariance Analysis:")
    print(f" Within-chain variance: {within_var:.6f}")
    print(f" Between-chain variance: {between_var:.6f}")
    print(f" Ratio (should be ~1): {between_var / within_var:.4f}")

# Diagnostic-driven sampling strategy
def adaptive_sampling_example():
    """Example of using diagnostics to determine sampling requirements.

    Re-runs NUTS with a growing number of warmup and sampling iterations
    until R-hat < 1.1 and ESS exceeds the target, or the iteration budget is
    used up.

    Returns:
        The chain-grouped samples from the last MCMC run.
    """

    def model():
        # Deliberately create a challenging posterior
        x = numpyro.sample("x", dist.Normal(0, 1))
        numpyro.sample("y", dist.Normal(x**2, 0.1))  # Non-linear relationship

    # Start with small number of samples
    initial_samples = 500
    target_ess = 400
    max_iterations = 5

    for iteration in range(max_iterations):
        print(f"\n--- Iteration {iteration + 1} ---")

        # Run MCMC
        mcmc = MCMC(NUTS(model),
                    num_warmup=initial_samples,
                    num_samples=initial_samples,
                    num_chains=4)
        mcmc.run(random.PRNGKey(iteration))

        samples = mcmc.get_samples(group_by_chain=True)

        # Check diagnostics
        rhat = diagnostics.gelman_rubin(samples['x'])
        ess = diagnostics.effective_sample_size(samples['x'])

        print(f"Current samples per chain: {initial_samples}")
        print(f"R-hat: {rhat:.4f}")
        print(f"ESS: {ess:.1f}")

        # Check if we meet convergence criteria
        converged = rhat < 1.1
        sufficient_ess = ess > target_ess

        if converged and sufficient_ess:
            print("✓ Convergence achieved!")
            break
        if not converged:
            print(f"✗ Poor convergence (R-hat = {rhat:.4f})")
            initial_samples = int(initial_samples * 1.5)  # Increase samples
        else:
            # Converged, so the only unmet criterion is the ESS target.
            print(f"✗ Insufficient ESS ({ess:.1f} < {target_ess})")
            initial_samples = int(initial_samples * 1.2)  # Modest increase
    else:
        # The loop ran out of iterations without hitting `break`.
        print(f"✗ Criteria not met after {max_iterations} iterations")

    return samples
```

## Types

```python { .api }
from typing import Any, Callable, Dict, List, Optional, Union

import jax.numpy as jnp
from jax import Array

# Array of samples or diagnostics, as produced by JAX.
NDArray = jnp.ndarray
# Anything accepted where an array is expected, including Python scalars.
ArrayLike = Union[Array, NDArray, float, int]
# Posterior samples keyed by site name.
Samples = Dict[str, NDArray]

class DiagnosticResult:
    """Base class for diagnostic results."""

class SummaryStats:
    """Summary statistics for a parameter."""

    mean: float
    std: float
    median: float
    mad: float  # Median absolute deviation
    hpdi_lower: float  # Lower bound of the HPD interval
    hpdi_upper: float  # Upper bound of the HPD interval
    n_eff: float  # Effective sample size
    r_hat: float  # R-hat statistic

class ConvergenceDiagnostic:
    """Convergence diagnostic results."""

    r_hat: NDArray
    split_r_hat: NDArray
    converged: bool
    potential_scale_reduction: NDArray

class EfficiencyDiagnostic:
    """Sampling efficiency diagnostic results."""

    effective_sample_size: NDArray
    autocorrelation_time: NDArray
    efficiency_ratio: NDArray

class AutocorrelationResult:
    """Autocorrelation analysis results."""

    autocorr: NDArray
    autocov: NDArray
    integrated_time: float

class ChainStatistics:
    """Statistics for individual MCMC chains."""

    chain_id: int
    mean: NDArray
    variance: NDArray
    effective_sample_size: NDArray
    autocorrelation_time: float

# Function type signatures (Callable comes from the typing import at the top
# of this snippet)
# Maps samples to a per-parameter convergence statistic.
ConvergenceFunction = Callable[[NDArray], NDArray]
# Maps samples to a dictionary of named summary values.
SummaryFunction = Callable[[NDArray], Dict[str, Any]]
# Maps samples to a structured diagnostic result.
DiagnosticFunction = Callable[[NDArray], DiagnosticResult]
```