0
# Framework Integrations
1
2
Convert inference results from various probabilistic programming frameworks to ArviZ's unified InferenceData format. Supports Stan (CmdStan, PyStan, CmdStanPy), PyMC, Pyro, NumPyro, JAX, emcee, and more.
3
4
## Stan Ecosystem
5
6
### CmdStan and CmdStanPy
7
8
```python { .api }
9
def from_cmdstan(posterior: str = None, *, posterior_predictive: str = None, observed_data: dict = None, constant_data: dict = None, predictions: dict = None, **kwargs) -> InferenceData:
10
"""
11
Convert CmdStan output files to InferenceData.
12
13
Args:
14
posterior (str, optional): Path to posterior samples CSV file
15
posterior_predictive (str, optional): Path to posterior predictive CSV
16
observed_data (dict, optional): Dictionary of observed data
17
constant_data (dict, optional): Dictionary of constant/fixed data
18
predictions (dict, optional): Dictionary of out-of-sample predictions
19
**kwargs: Additional conversion parameters (coords, dims, etc.)
20
21
Returns:
22
InferenceData: Converted inference data object
23
"""
24
25
def from_cmdstanpy(fit, *, posterior_predictive: str = None, observed_data: dict = None, constant_data: dict = None, **kwargs) -> InferenceData:
26
"""
27
Convert CmdStanPy fit results to InferenceData.
28
29
Args:
30
fit: CmdStanPy fit object (CmdStanMCMC, CmdStanMLE, CmdStanVB)
31
posterior_predictive (str, optional): Variable name for posterior predictive
32
observed_data (dict, optional): Dictionary of observed data
33
constant_data (dict, optional): Dictionary of constant data
34
**kwargs: Additional conversion parameters
35
36
Returns:
37
InferenceData: Converted inference data object
38
"""
39
```
40
41
### PyStan
42
43
```python { .api }
44
def from_pystan(fit, *, posterior_predictive: str = None, observed_data: dict = None, constant_data: dict = None, **kwargs) -> InferenceData:
45
"""
46
Convert PyStan fit results to InferenceData.
47
48
Args:
49
fit: PyStan fit object (StanFit4Model)
50
posterior_predictive (str, optional): Variable name for posterior predictive
51
observed_data (dict, optional): Dictionary of observed data
52
constant_data (dict, optional): Dictionary of constant data
53
**kwargs: Additional conversion parameters (coords, dims, etc.)
54
55
Returns:
56
InferenceData: Converted inference data object
57
"""
58
```
59
60
### Usage Examples
61
62
```python
63
import arviz as az
64
import cmdstanpy
65
66
# CmdStanPy example
67
model = cmdstanpy.CmdStanModel(stan_file="model.stan")
68
fit = model.sample(data=data_dict)
69
idata = az.from_cmdstanpy(fit, observed_data={"y": y_obs})
70
71
# CmdStan CSV files
72
idata = az.from_cmdstan(
73
posterior="output.csv",
74
posterior_predictive="predictions.csv",
75
observed_data={"y": y_obs}
76
)
77
78
# PyStan 2.x example (legacy API; PyStan 3 is imported as `stan` and uses `stan.build(...).sample(...)`)
79
import pystan
80
model = pystan.StanModel(file="model.stan")
81
fit = model.sampling(data=data_dict)
82
idata = az.from_pystan(fit, observed_data={"y": y_obs})
83
```
84
85
## PyTorch/JAX Ecosystem
86
87
### Pyro
88
89
```python { .api }
90
def from_pyro(posterior=None, *, prior: dict = None, posterior_predictive: dict = None, observed_data: dict = None, **kwargs) -> InferenceData:
91
"""
92
Convert Pyro MCMC results to InferenceData.
93
94
Args:
95
posterior (pyro.infer.MCMC): Fitted Pyro MCMC object (pass the MCMC object itself, not the dict returned by `mcmc.get_samples()`)
96
prior (dict, optional): Dictionary of prior samples
97
posterior_predictive (dict, optional): Dictionary of posterior predictive samples
98
observed_data (dict, optional): Dictionary of observed data
99
**kwargs: Additional conversion parameters (coords, dims, etc.)
100
101
Returns:
102
InferenceData: Converted inference data object
103
"""
104
```
105
106
### NumPyro
107
108
```python { .api }
109
def from_numpyro(posterior=None, *, prior: dict = None, posterior_predictive: dict = None, observed_data: dict = None, **kwargs) -> InferenceData:
109
"""
110
Convert NumPyro MCMC results to InferenceData.
111
112
Args:
113
posterior (numpyro.infer.MCMC): Fitted NumPyro MCMC object (pass the MCMC object itself, not the dict returned by `mcmc.get_samples()`)
115
prior (dict, optional): Dictionary of prior samples
116
posterior_predictive (dict, optional): Dictionary of posterior predictive samples
117
observed_data (dict, optional): Dictionary of observed data
118
**kwargs: Additional conversion parameters (coords, dims, etc.)
119
120
Returns:
121
InferenceData: Converted inference data object
122
"""
123
```
124
125
### Usage Examples
126
127
```python
128
import jax
129
import numpyro
130
import numpyro.distributions as dist
131
from numpyro.infer import MCMC, NUTS
132
133
# NumPyro example
134
def model(y):
135
mu = numpyro.sample("mu", dist.Normal(0, 1))
136
sigma = numpyro.sample("sigma", dist.HalfNormal(1))
137
numpyro.sample("y", dist.Normal(mu, sigma), obs=y)
138
139
# Run MCMC
140
mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=2000)
141
mcmc.run(jax.random.PRNGKey(0), y=data)
142
143
# Convert to ArviZ
144
idata = az.from_numpyro(
145
mcmc,
146
dims={"y": ["obs"]},  # observed "y" is read from the obs= site; dims attaches it to the "obs" coordinate
147
coords={"obs": range(len(data))}
148
)
149
150
# Pyro example (similar pattern)
151
import pyro
152
import torch
153
154
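# After running Pyro MCMC, e.g.
155
#   pyro_mcmc = pyro.infer.MCMC(pyro.infer.NUTS(pyro_model), num_samples=2000); pyro_mcmc.run(data)
156
idata = az.from_pyro(
157
pyro_mcmc,  # pass the fitted MCMC object itself, not pyro_mcmc.get_samples()
158
# observed data is read from the model's obs= sites, so no observed_data argument is needed
159
)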
160
```
161
162
## Other Frameworks
163
164
### emcee
165
166
```python { .api }
167
def from_emcee(sampler, *, var_names: list = None, slices: list = None, **kwargs) -> InferenceData:
168
"""
169
Convert emcee ensemble sampler results to InferenceData.
170
171
Args:
172
sampler: emcee EnsembleSampler object
173
var_names (list, optional): Variable names for parameters
174
slices (list of int/slice, optional): One entry per name in var_names selecting which parameter columns form that variable, e.g. [0, 1, slice(2, None)]; by default each column is its own variable
175
**kwargs: Additional conversion parameters (coords, dims, etc.)
176
177
Returns:
178
InferenceData: Converted inference data object
179
"""
180
```
181
182
### PyJAGS
183
184
```python { .api }
185
def from_pyjags(fit, *, var_names: list = None, **kwargs) -> InferenceData:
186
"""
187
Convert PyJAGS fit results to InferenceData.
188
189
Args:
190
fit: PyJAGS fit object
191
var_names (list, optional): Variable names to extract
192
**kwargs: Additional conversion parameters
193
194
Returns:
195
InferenceData: Converted inference data object
196
"""
197
```
198
199
### Bean Machine
200
201
```python { .api }
202
def from_beanmachine(beanmachine_model, *, observed_data: dict = None, **kwargs) -> InferenceData:
203
"""
204
Convert Bean Machine model results to InferenceData.
205
206
Args:
207
beanmachine_model: Bean Machine model object with samples
208
observed_data (dict, optional): Dictionary of observed data
209
**kwargs: Additional conversion parameters
210
211
Returns:
212
InferenceData: Converted inference data object
213
"""
214
```
215
216
### Usage Examples
217
218
```python
219
import emcee
220
import numpy as np
221
222
# emcee example
223
def log_prob(theta):
224
return -0.5 * np.sum(theta**2)
225
226
# Run emcee sampler
227
nwalkers, ndim = 32, 5
228
sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob)
229
sampler.run_mcmc(np.random.randn(nwalkers, ndim), 1000)
230
231
# Convert to ArviZ
232
idata = az.from_emcee(
233
sampler,
234
var_names=["param_1", "param_2", "param_3", "param_4", "param_5"]
235
)
236
```
237
238
## Generic Conversions
239
240
### Dictionary-based Conversion
241
242
```python { .api }
243
def from_dict(posterior: dict, *, prior: dict = None, posterior_predictive: dict = None, sample_stats: dict = None, observed_data: dict = None, constant_data: dict = None, predictions: dict = None, log_likelihood: dict = None, **kwargs) -> InferenceData:
244
"""
245
Convert dictionary of arrays to InferenceData.
246
247
Args:
248
posterior (dict): Dictionary of posterior samples (var_name -> array)
249
prior (dict, optional): Dictionary of prior samples
250
posterior_predictive (dict, optional): Dictionary of posterior predictive samples
251
sample_stats (dict, optional): Dictionary of MCMC diagnostics
252
observed_data (dict, optional): Dictionary of observed data
253
constant_data (dict, optional): Dictionary of constant data
254
predictions (dict, optional): Dictionary of out-of-sample predictions
255
log_likelihood (dict, optional): Dictionary of log likelihood values
256
**kwargs: Additional conversion parameters (coords, dims, etc.)
257
258
Returns:
259
InferenceData: Converted inference data object
260
"""
261
```
262
263
### PyTree Conversion
264
265
```python { .api }
266
def from_pytree(posterior, *, prior = None, posterior_predictive = None, **kwargs) -> InferenceData:
267
"""
268
Convert pytree structure to InferenceData.
269
270
Args:
271
posterior: Pytree structure with posterior samples (JAX, PyTorch, etc.)
272
prior (optional): Pytree structure with prior samples
273
posterior_predictive (optional): Pytree structure with posterior predictive samples
274
**kwargs: Additional conversion parameters
275
276
Returns:
277
InferenceData: Converted inference data object
278
"""
279
```
280
281
### Usage Examples
282
283
```python
284
# Dictionary conversion
285
posterior_dict = {
286
"mu": np.random.normal(0, 1, (4, 1000)), # 4 chains, 1000 draws
287
"sigma": np.random.lognormal(0, 0.5, (4, 1000))
288
}
289
290
sample_stats_dict = {
291
"diverging": np.random.binomial(1, 0.01, (4, 1000)),
292
"energy": np.random.normal(0, 1, (4, 1000))
293
}
294
295
idata = az.from_dict(
296
posterior=posterior_dict,
297
sample_stats=sample_stats_dict,
298
observed_data={"y": y_observed},
299
coords={"chain": range(4), "draw": range(1000)}
300
)
301
302
# PyTree conversion (JAX example)
303
import jax.numpy as jnp
304
305
pytree_posterior = {
306
"mu": jnp.array(np.random.normal(0, 1, (4, 1000))),
307
"nested": {
308
"sigma": jnp.array(np.random.lognormal(0, 0.5, (4, 1000)))
309
}
310
}
311
312
idata = az.from_pytree(pytree_posterior, coords={"chain": range(4)})
313
```
314
315
## Sampling Wrappers
316
317
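ArviZ provides sampling wrapper classes that give refitting-based algorithms, such as `az.reloo` (exact leave-one-out refits), one common interface for re-running inference in each framework: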
318
319
```python { .api }
320
class SamplingWrapper:
321
"""Base class for sampling wrappers."""
322
323
class PyStanSamplingWrapper(SamplingWrapper):
324
"""Sampling wrapper for PyStan 3.x."""
325
326
class PyStan2SamplingWrapper(SamplingWrapper):
327
"""Sampling wrapper for PyStan 2.x."""
328
329
class CmdStanPySamplingWrapper(SamplingWrapper):
330
"""Sampling wrapper for CmdStanPy."""
331
332
class PyMCSamplingWrapper(SamplingWrapper):
333
"""Sampling wrapper for PyMC."""
334
```
335
336
## Conversion Best Practices
337
338
### Coordinate and Dimension Specifications
339
340
```python
341
# Specify coordinates for better data organization
342
coords = {
343
"school": ["A", "B", "C", "D", "E", "F", "G", "H"],
344
"obs": range(len(observations))
345
}
346
347
# Specify dimensions for proper array broadcasting
348
dims = {
349
"theta": ["school"],
350
"y": ["obs"]
351
}
352
353
idata = az.from_dict(
354
posterior=posterior_dict,
355
observed_data=observed_dict,
356
coords=coords,
357
dims=dims
358
)
359
```
360
361
### Handling Multiple Data Groups
362
363
```python
364
# Complete data conversion with all groups
365
idata = az.from_dict(
366
posterior=posterior_samples, # Required
367
prior=prior_samples, # Optional
368
posterior_predictive=pp_samples, # Optional
369
sample_stats=diagnostics, # Optional (divergences, energy, etc.)
370
observed_data={"y": y_obs}, # Optional but recommended
371
constant_data={"N": len(y_obs)}, # Optional
372
predictions=out_of_sample_preds, # Optional
373
log_likelihood=ll_values, # Optional (for model comparison)
374
coords=coords,
375
dims=dims
376
)
377
```
378
379
### Framework-Specific Tips
380
381
- **Stan**: Always include `observed_data` for posterior predictive checks
382
- **Pyro/NumPyro**: Use `coords` and `dims` for multi-dimensional parameters
383
- **emcee**: Provide meaningful `var_names` for parameter identification
384
- **Custom frameworks**: Use `from_dict()` with proper coordinate specifications
385
386
## Sampling Wrappers in Detail
387
388
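ArviZ provides sampling wrapper classes that standardize how each probabilistic programming framework is refit and its results converted. Algorithms that must re-run inference, such as `az.reloo`, can therefore work with any backend. In ArviZ's API the core wrapper methods are `sample`, `get_inference_data`, `sel_observations` and `log_likelihood__i`; verify any other method shown below against the ArviZ API reference for your installed version.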
389
390
### Base Wrapper Class
391
392
```python { .api }
393
class SamplingWrapper:
394
"""
395
Base class for probabilistic programming framework sampling wrappers.
396
397
Provides a unified interface for model compilation, sampling,
398
and automatic conversion to ArviZ InferenceData format across
399
different Bayesian inference libraries.
400
401
This abstract base class defines the common interface that all
402
framework-specific wrappers should implement.
403
"""
404
405
def __init__(self, model, **kwargs):
406
"""Initialize sampling wrapper with model."""
407
408
def sample(self, **sample_kwargs):
409
"""Run MCMC sampling and return InferenceData."""
410
411
def compile_model(self, **compile_kwargs):
412
"""Compile model for sampling (if required by framework)."""
413
414
def to_inference_data(self, **conversion_kwargs):
415
"""Convert sampling results to InferenceData format."""
416
```
417
418
### Stan Ecosystem Wrappers
419
420
```python { .api }
421
class PyStanSamplingWrapper(SamplingWrapper):
422
"""
423
Sampling wrapper for PyStan 3.x (current version).
424
425
Provides unified interface for PyStan model compilation,
426
MCMC sampling, and automatic conversion to InferenceData.
427
428
Handles Stan model compilation, data preparation, sampling
429
configuration, and result extraction with proper error handling.
430
"""
431
432
def __init__(self, model_code: str = None, model_file: str = None, **kwargs):
433
"""
434
Initialize PyStan wrapper.
435
436
Args:
437
model_code (str, optional): Stan model code as string
438
model_file (str, optional): Path to .stan model file
439
**kwargs: Additional PyStan compilation parameters
440
"""
441
442
def sample(self, data: dict, *, num_chains: int = 4, num_samples: int = 1000, **kwargs):
443
"""
444
Run MCMC sampling with PyStan.
445
446
Args:
447
data (dict): Data dictionary for Stan model
448
num_chains (int): Number of MCMC chains (default 4)
449
num_samples (int): Number of samples per chain (default 1000)
450
**kwargs: Additional sampling parameters
451
452
Returns:
453
InferenceData: ArviZ inference data object
454
"""
455
456
class PyStan2SamplingWrapper(SamplingWrapper):
457
"""
458
Sampling wrapper for PyStan 2.x (legacy version).
459
460
Maintains compatibility with older PyStan 2.x installations
461
while providing the same unified sampling interface.
462
463
Note: PyStan 2.x is legacy. Consider upgrading to PyStan 3.x or CmdStanPy.
464
"""
465
466
def __init__(self, model_code: str = None, model_file: str = None, **kwargs):
467
"""Initialize PyStan 2.x wrapper."""
468
469
def sample(self, data: dict = None, **kwargs):
470
"""Run MCMC sampling with PyStan 2.x."""
471
472
class CmdStanPySamplingWrapper(SamplingWrapper):
473
"""
474
Sampling wrapper for CmdStanPy (recommended Stan interface).
475
476
Provides interface for CmdStanPy, the official Python interface
477
to CmdStan. Offers better performance and more features than PyStan.
478
479
Supports MCMC sampling, variational inference, and optimization
480
with automatic conversion to ArviZ format.
481
"""
482
483
def __init__(self, stan_file: str, **kwargs):
484
"""
485
Initialize CmdStanPy wrapper.
486
487
Args:
488
stan_file (str): Path to .stan model file
489
**kwargs: CmdStanModel compilation parameters
490
"""
491
492
def sample(self, data: dict = None, *, chains: int = 4, iter_sampling: int = 1000, **kwargs):
493
"""
494
Run MCMC sampling with CmdStanPy.
495
496
Args:
497
data (dict, optional): Data dictionary for Stan model
498
chains (int): Number of MCMC chains (default 4)
499
iter_sampling (int): Number of sampling iterations (default 1000)
500
**kwargs: Additional CmdStanPy sampling parameters
501
502
Returns:
503
InferenceData: ArviZ inference data object
504
"""
505
506
def variational(self, data: dict = None, **kwargs):
507
"""Run variational inference with CmdStanPy."""
508
509
def optimize(self, data: dict = None, **kwargs):
510
"""Run optimization with CmdStanPy."""
511
```
512
513
### PyMC Wrapper
514
515
```python { .api }
516
class PyMCSamplingWrapper(SamplingWrapper):
517
"""
518
Sampling wrapper for PyMC (formerly PyMC3).
519
520
Provides unified interface for PyMC model context management,
521
MCMC sampling with NUTS, and automatic conversion to ArviZ.
522
523
Handles PyMC model contexts, prior predictive sampling,
524
posterior predictive sampling, and comprehensive diagnostics.
525
"""
526
527
def __init__(self, model_context, **kwargs):
528
"""
529
Initialize PyMC wrapper.
530
531
Args:
532
model_context: PyMC model context or model object
533
**kwargs: Additional PyMC configuration parameters
534
"""
535
536
def sample(self, *, draws: int = 1000, tune: int = 1000, chains: int = 4, **kwargs):
537
"""
538
Run MCMC sampling with PyMC.
539
540
Args:
541
draws (int): Number of samples to draw (default 1000)
542
tune (int): Number of tuning samples (default 1000)
543
chains (int): Number of MCMC chains (default 4)
544
**kwargs: Additional PyMC sampling parameters (nuts_sampler, etc.)
545
546
Returns:
547
InferenceData: ArviZ inference data object with all groups
548
"""
549
550
def sample_prior_predictive(self, samples: int = 500, **kwargs):
551
"""Sample from prior predictive distribution."""
552
553
def sample_posterior_predictive(self, trace, samples: int = 500, **kwargs):
554
"""Sample from posterior predictive distribution."""
555
```
556
557
### Usage Examples
558
559
```python
560
# CmdStanPy wrapper usage
561
wrapper = az.CmdStanPySamplingWrapper("my_model.stan")
562
563
# Prepare data
564
data = {
565
"N": len(y_obs),
566
"y": y_obs,
567
"x": x_data
568
}
569
570
# Run sampling with automatic conversion
571
idata = wrapper.sample(
572
data=data,
573
chains=4,
574
iter_sampling=2000,
575
iter_warmup=1000
576
)
577
578
# Data is automatically converted to InferenceData
579
print(f"Posterior samples: {idata.posterior.dims}")
580
print(f"Sample stats: {list(idata.sample_stats.data_vars)}")
581
582
# PyMC wrapper usage
583
import pymc as pm
584
585
with pm.Model() as model:
586
mu = pm.Normal("mu", mu=0, sigma=1)
587
sigma = pm.HalfNormal("sigma", sigma=1)
588
y = pm.Normal("y", mu=mu, sigma=sigma, observed=y_obs)
589
590
wrapper = az.PyMCSamplingWrapper(model)
591
idata = wrapper.sample(draws=1000, tune=1000, chains=4)
592
593
# Inspect which groups were returned (prior/posterior predictive groups may require the separate sample_prior_predictive / sample_posterior_predictive methods)
594
print(f"Groups: {list(idata.groups())}")
595
596
# PyStan wrapper usage
597
model_code = """
598
data {
599
int<lower=0> N;
600
vector[N] y;
601
}
602
parameters {
603
real mu;
604
real<lower=0> sigma;
605
}
606
model {
607
mu ~ normal(0, 1);
608
sigma ~ normal(0, 1);  // half-normal prior: Stan has no half_normal, the <lower=0> bound truncates normal
609
y ~ normal(mu, sigma);
610
}
611
"""
612
613
wrapper = az.PyStanSamplingWrapper(model_code=model_code)
614
idata = wrapper.sample(
615
data={"N": len(y_obs), "y": y_obs},
616
num_chains=4,
617
num_samples=1000
618
)
619
```
620
621
### Wrapper Configuration
622
623
```python
624
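# Common configuration patterns across wrappers. Keyword names differ by backend: PyMC uses cores/progressbar, CmdStanPy uses parallel_chains/show_progress. Check each sampler's signature before sharing keys.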
625
config = {
626
"chains": 4,
627
"cores": 4, # Parallel chain execution
628
"progress_bar": True,
629
"return_inferencedata": True, # Default for all wrappers
630
}
631
632
# Framework-specific configurations
633
cmdstanpy_config = {
634
**config,
635
"iter_sampling": 1000,
636
"iter_warmup": 1000,
637
"adapt_delta": 0.8, # NUTS tuning parameter
638
"max_treedepth": 10
639
}
640
641
pymc_config = {
642
**config,
643
"draws": 1000,
644
"tune": 1000,
645
"target_accept": 0.8,
646
"nuts_sampler": "nutpie" # Alternative sampler
647
}
648
649
# Use with wrappers
650
cmdstan_wrapper = az.CmdStanPySamplingWrapper("model.stan")
651
idata = cmdstan_wrapper.sample(data=data, **cmdstanpy_config)
652
653
pymc_wrapper = az.PyMCSamplingWrapper(pymc_model)
654
idata = pymc_wrapper.sample(**pymc_config)
655
```
656
657
### Wrapper Benefits
658
659
1. **Unified Interface**: Same API across different frameworks
660
2. **Automatic Conversion**: Results always returned as InferenceData
661
3. **Error Handling**: Consistent error messages and troubleshooting
662
4. **Best Practices**: Built-in recommendations for sampling parameters
663
5. **Extensibility**: Easy to add support for new frameworks
664
665
```python
666
# Compare results across frameworks easily
667
frameworks = {
668
"cmdstanpy": az.CmdStanPySamplingWrapper("model.stan"),
669
"pymc": az.PyMCSamplingWrapper(pymc_model),
670
"pystan": az.PyStanSamplingWrapper(model_code=stan_code)
671
}
672
673
results = {}
674
for name, wrapper in frameworks.items():
675
results[name] = wrapper.sample(data=data, chains=4)
676
677
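# All results are InferenceData objects; az.compare ranks them by estimated out-of-sample predictive accuracy (ELPD), so each one must contain a log_likelihood group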
678
comparison = az.compare(results)
679
print(comparison)
680
```