
# Framework Integrations

Convert inference results from various probabilistic programming frameworks to ArviZ's unified InferenceData format. Supports Stan (CmdStan, PyStan, CmdStanPy), PyMC, Pyro, NumPyro, JAX, emcee, and more.
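Whichever converter you use, the payoff is the same: once the results are an `InferenceData` object, ArviZ's diagnostics and plots accept it directly. A minimal sketch — the `posterior_dict` values here are placeholder arrays standing in for real sampler output:

```python
import numpy as np
import arviz as az

# Placeholder posterior draws shaped (chains, draws); any converter in this
# document produces an equivalent InferenceData object from real sampler output.
posterior_dict = {"mu": np.random.normal(0, 1, (4, 1000))}
idata = az.from_dict(posterior=posterior_dict)

print(az.summary(idata))   # posterior summaries, ESS, r_hat
az.plot_trace(idata)       # same call regardless of the originating framework
```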

## Stan Ecosystem

### CmdStan and CmdStanPy

```python { .api }
def from_cmdstan(posterior: str = None, *, posterior_predictive: str = None, observed_data: dict = None, constant_data: dict = None, predictions: dict = None, **kwargs) -> InferenceData:
    """
    Convert CmdStan output files to InferenceData.

    Args:
        posterior (str, optional): Path to posterior samples CSV file
        posterior_predictive (str, optional): Path to posterior predictive CSV
        observed_data (dict, optional): Dictionary of observed data
        constant_data (dict, optional): Dictionary of constant/fixed data
        predictions (dict, optional): Dictionary of out-of-sample predictions
        **kwargs: Additional conversion parameters (coords, dims, etc.)

    Returns:
        InferenceData: Converted inference data object
    """

def from_cmdstanpy(fit, *, posterior_predictive: str = None, observed_data: dict = None, constant_data: dict = None, **kwargs) -> InferenceData:
    """
    Convert CmdStanPy fit results to InferenceData.

    Args:
        fit: CmdStanPy fit object (CmdStanMCMC, CmdStanMLE, CmdStanVB)
        posterior_predictive (str, optional): Variable name for posterior predictive
        observed_data (dict, optional): Dictionary of observed data
        constant_data (dict, optional): Dictionary of constant data
        **kwargs: Additional conversion parameters

    Returns:
        InferenceData: Converted inference data object
    """
```

### PyStan

```python { .api }
def from_pystan(fit, *, posterior_predictive: str = None, observed_data: dict = None, constant_data: dict = None, **kwargs) -> InferenceData:
    """
    Convert PyStan fit results to InferenceData.

    Args:
        fit: PyStan fit object (StanFit4Model)
        posterior_predictive (str, optional): Variable name for posterior predictive
        observed_data (dict, optional): Dictionary of observed data
        constant_data (dict, optional): Dictionary of constant data
        **kwargs: Additional conversion parameters (coords, dims, etc.)

    Returns:
        InferenceData: Converted inference data object
    """
```

### Usage Examples

```python
import arviz as az
import cmdstanpy

# CmdStanPy example
model = cmdstanpy.CmdStanModel(stan_file="model.stan")
fit = model.sample(data=data_dict)
idata = az.from_cmdstanpy(fit, observed_data={"y": y_obs})

# CmdStan CSV files
idata = az.from_cmdstan(
    posterior="output.csv",
    posterior_predictive="predictions.csv",
    observed_data={"y": y_obs}
)

# PyStan example (legacy PyStan 2.x interface)
import pystan

model = pystan.StanModel(file="model.stan")
fit = model.sampling(data=data_dict)
idata = az.from_pystan(fit, observed_data={"y": y_obs})
```

## PyTorch/JAX Ecosystem

### Pyro

```python { .api }
def from_pyro(posterior, *, prior: dict = None, posterior_predictive: dict = None, observed_data: dict = None, **kwargs) -> InferenceData:
    """
    Convert Pyro MCMC results to InferenceData.

    Args:
        posterior: Fitted Pyro MCMC object (pyro.infer.MCMC)
        prior (dict, optional): Dictionary of prior samples
        posterior_predictive (dict, optional): Dictionary of posterior predictive samples
        observed_data (dict, optional): Dictionary of observed data
        **kwargs: Additional conversion parameters (coords, dims, etc.)

    Returns:
        InferenceData: Converted inference data object
    """
```
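A complete Pyro round trip, for comparison with the NumPyro example below. This is a sketch: the model, `y_obs`, and the sampler settings are illustrative, and only the `posterior` argument of `from_pyro` is exercised.

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS
import arviz as az

def model(y=None):
    mu = pyro.sample("mu", dist.Normal(0.0, 1.0))
    sigma = pyro.sample("sigma", dist.HalfNormal(1.0))
    with pyro.plate("obs", len(y)):
        pyro.sample("y", dist.Normal(mu, sigma), obs=y)

y_obs = torch.randn(100)  # illustrative data

# Run NUTS, then hand the fitted MCMC object straight to the converter
mcmc = MCMC(NUTS(model), num_samples=1000, warmup_steps=500)
mcmc.run(y_obs)
idata = az.from_pyro(mcmc)
```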

### NumPyro

```python { .api }
def from_numpyro(posterior, *, prior: dict = None, posterior_predictive: dict = None, observed_data: dict = None, **kwargs) -> InferenceData:
    """
    Convert NumPyro MCMC results to InferenceData.

    Args:
        posterior: Fitted NumPyro MCMC object (numpyro.infer.MCMC)
        prior (dict, optional): Dictionary of prior samples
        posterior_predictive (dict, optional): Dictionary of posterior predictive samples
        observed_data (dict, optional): Dictionary of observed data
        **kwargs: Additional conversion parameters (coords, dims, etc.)

    Returns:
        InferenceData: Converted inference data object
    """
```

### Usage Examples

```python
import arviz as az
import jax
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

# NumPyro example
def model(y):
    mu = numpyro.sample("mu", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1))
    numpyro.sample("y", dist.Normal(mu, sigma), obs=y)

# Run MCMC
mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=2000)
mcmc.run(jax.random.PRNGKey(0), y=data)

# Convert to ArviZ
idata = az.from_numpyro(
    mcmc,
    observed_data={"y": data},
    coords={"obs": range(len(data))},
    dims={"y": ["obs"]}
)

# Pyro example (similar pattern): pass the fitted pyro.infer.MCMC object
import pyro
import torch

idata = az.from_pyro(
    pyro_mcmc,  # fitted pyro.infer.MCMC object
    observed_data={"y": data}
)
```

## Other Frameworks

### emcee

```python { .api }
def from_emcee(sampler, *, var_names: list = None, slices: slice = None, **kwargs) -> InferenceData:
    """
    Convert emcee ensemble sampler results to InferenceData.

    Args:
        sampler: emcee EnsembleSampler object
        var_names (list, optional): Variable names for parameters
        slices (slice, optional): Slice object for chain selection
        **kwargs: Additional conversion parameters (coords, dims, etc.)

    Returns:
        InferenceData: Converted inference data object
    """
```

### PyJAGS

```python { .api }
def from_pyjags(fit, *, var_names: list = None, **kwargs) -> InferenceData:
    """
    Convert PyJAGS fit results to InferenceData.

    Args:
        fit: PyJAGS fit object
        var_names (list, optional): Variable names to extract
        **kwargs: Additional conversion parameters

    Returns:
        InferenceData: Converted inference data object
    """
```
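PyJAGS has no dedicated usage section below, so here is a minimal sketch. The JAGS model string, variable names, and data are illustrative, and the dictionary of draws returned by `pyjags.Model.sample` is what gets handed to the converter as the fit results.

```python
import numpy as np
import pyjags
import arviz as az

y_obs = np.random.normal(size=50)  # illustrative data

# Illustrative JAGS model: normal likelihood with normal/gamma priors
code = """
model {
    for (i in 1:N) { y[i] ~ dnorm(mu, tau) }
    mu ~ dnorm(0, 1)
    tau ~ dgamma(1, 1)
}
"""

jags_model = pyjags.Model(code, data={"y": y_obs, "N": len(y_obs)}, chains=4)
samples = jags_model.sample(1000, vars=["mu", "tau"])  # dict of draw arrays

# Convert the PyJAGS results to InferenceData
idata = az.from_pyjags(samples)
```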

### Bean Machine

```python { .api }
def from_beanmachine(beanmachine_model, *, observed_data: dict = None, **kwargs) -> InferenceData:
    """
    Convert Bean Machine model results to InferenceData.

    Args:
        beanmachine_model: Bean Machine model object with samples
        observed_data (dict, optional): Dictionary of observed data
        **kwargs: Additional conversion parameters

    Returns:
        InferenceData: Converted inference data object
    """
```
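A brief, hedged sketch of the Bean Machine path: the random variables, observation value, and sampler settings are illustrative, and the object returned by `infer` is passed to the converter as the model results.

```python
import torch
import torch.distributions as dist
import beanmachine.ppl as bm
import arviz as az

@bm.random_variable
def mu():
    return dist.Normal(0.0, 1.0)

@bm.random_variable
def y():
    return dist.Normal(mu(), 1.0)

# Run inference (illustrative observation and settings)
samples = bm.SingleSiteNoUTurnSampler().infer(
    queries=[mu()],
    observations={y(): torch.tensor(1.2)},
    num_samples=1000,
    num_chains=4,
)

idata = az.from_beanmachine(samples)
```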

### Usage Examples

```python
import arviz as az
import emcee
import numpy as np

# emcee example: standard normal log-probability in 5 dimensions
def log_prob(theta):
    return -0.5 * np.sum(theta**2)

# Run emcee sampler
nwalkers, ndim = 32, 5
sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob)
sampler.run_mcmc(np.random.randn(nwalkers, ndim), 1000)

# Convert to ArviZ
idata = az.from_emcee(
    sampler,
    var_names=["param_1", "param_2", "param_3", "param_4", "param_5"]
)
```

## Generic Conversions

### Dictionary-based Conversion

```python { .api }
def from_dict(posterior: dict, *, prior: dict = None, posterior_predictive: dict = None, sample_stats: dict = None, observed_data: dict = None, constant_data: dict = None, predictions: dict = None, log_likelihood: dict = None, **kwargs) -> InferenceData:
    """
    Convert dictionary of arrays to InferenceData.

    Args:
        posterior (dict): Dictionary of posterior samples (var_name -> array)
        prior (dict, optional): Dictionary of prior samples
        posterior_predictive (dict, optional): Dictionary of posterior predictive samples
        sample_stats (dict, optional): Dictionary of MCMC diagnostics
        observed_data (dict, optional): Dictionary of observed data
        constant_data (dict, optional): Dictionary of constant data
        predictions (dict, optional): Dictionary of out-of-sample predictions
        log_likelihood (dict, optional): Dictionary of log likelihood values
        **kwargs: Additional conversion parameters (coords, dims, etc.)

    Returns:
        InferenceData: Converted inference data object
    """
```

### PyTree Conversion

```python { .api }
def from_pytree(posterior, *, prior=None, posterior_predictive=None, **kwargs) -> InferenceData:
    """
    Convert pytree structure to InferenceData.

    Args:
        posterior: Pytree structure with posterior samples (JAX, PyTorch, etc.)
        prior (optional): Pytree structure with prior samples
        posterior_predictive (optional): Pytree structure with posterior predictive samples
        **kwargs: Additional conversion parameters

    Returns:
        InferenceData: Converted inference data object
    """
```

### Usage Examples

```python
import numpy as np
import arviz as az

# Dictionary conversion
posterior_dict = {
    "mu": np.random.normal(0, 1, (4, 1000)),  # 4 chains, 1000 draws
    "sigma": np.random.lognormal(0, 0.5, (4, 1000))
}

sample_stats_dict = {
    "diverging": np.random.binomial(1, 0.01, (4, 1000)),
    "energy": np.random.normal(0, 1, (4, 1000))
}

idata = az.from_dict(
    posterior=posterior_dict,
    sample_stats=sample_stats_dict,
    observed_data={"y": y_observed},
    coords={"chain": range(4), "draw": range(1000)}
)

# PyTree conversion (JAX example)
import jax.numpy as jnp

pytree_posterior = {
    "mu": jnp.array(np.random.normal(0, 1, (4, 1000))),
    "nested": {
        "sigma": jnp.array(np.random.lognormal(0, 0.5, (4, 1000)))
    }
}

idata = az.from_pytree(pytree_posterior, coords={"chain": range(4)})
```


## Conversion Best Practices

### Coordinate and Dimension Specifications

```python
# Specify coordinates for better data organization
coords = {
    "school": ["A", "B", "C", "D", "E", "F", "G", "H"],
    "obs": range(len(observations))
}

# Specify dimensions for proper array broadcasting
dims = {
    "theta": ["school"],
    "y": ["obs"]
}

idata = az.from_dict(
    posterior=posterior_dict,
    observed_data=observed_dict,
    coords=coords,
    dims=dims
)
```

360

361

### Handling Multiple Data Groups

362

363

```python

364

# Complete data conversion with all groups

365

idata = az.from_dict(

366

posterior=posterior_samples, # Required

367

prior=prior_samples, # Optional

368

posterior_predictive=pp_samples, # Optional

369

sample_stats=diagnostics, # Optional (divergences, energy, etc.)

370

observed_data={"y": y_obs}, # Optional but recommended

371

constant_data={"N": len(y_obs)}, # Optional

372

predictions=out_of_sample_preds, # Optional

373

log_likelihood=ll_values, # Optional (for model comparison)

374

coords=coords,

375

dims=dims

376

)

377

```

378

379

### Framework-Specific Tips

380

381

- **Stan**: Always include `observed_data` for posterior predictive checks

382

- **Pyro/NumPyro**: Use `coords` and `dims` for multi-dimensional parameters

383

- **emcee**: Provide meaningful `var_names` for parameter identification

384

- **Custom frameworks**: Use `from_dict()` with proper coordinate specifications

385

386

## Sampling Wrappers

ArviZ provides sampling wrapper classes that standardize the interface across different probabilistic programming frameworks for consistent model fitting and data conversion.

### Base Wrapper Class

```python { .api }
class SamplingWrapper:
    """
    Base class for probabilistic programming framework sampling wrappers.

    Provides a unified interface for model compilation, sampling,
    and automatic conversion to ArviZ InferenceData format across
    different Bayesian inference libraries.

    This abstract base class defines the common interface that all
    framework-specific wrappers should implement.
    """

    def __init__(self, model, **kwargs):
        """Initialize sampling wrapper with model."""

    def sample(self, **sample_kwargs):
        """Run MCMC sampling and return InferenceData."""

    def compile_model(self, **compile_kwargs):
        """Compile model for sampling (if required by framework)."""

    def to_inference_data(self, **conversion_kwargs):
        """Convert sampling results to InferenceData format."""
```
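A hedged sketch of what adding a new framework might look like: a hypothetical subclass that fills in the interface above. The `MyFancySampler`-style backend and its `run`/`draws` methods are invented for illustration, not a real dependency.

```python
import arviz as az


class MySamplerWrapper(az.SamplingWrapper):
    """Hypothetical wrapper for an imaginary sampling backend."""

    def __init__(self, model, **kwargs):
        super().__init__(model, **kwargs)
        self.model = model
        self.results = None

    def compile_model(self, **compile_kwargs):
        # This imaginary backend needs no compilation step
        return self.model

    def sample(self, **sample_kwargs):
        # run() is a placeholder for the backend's own sampling call
        self.results = self.model.run(**sample_kwargs)
        return self.to_inference_data()

    def to_inference_data(self, **conversion_kwargs):
        # draws() is a placeholder returning {var_name: array}; hand it
        # to the generic dictionary converter
        return az.from_dict(posterior=self.results.draws(), **conversion_kwargs)
```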

### Stan Ecosystem Wrappers

```python { .api }
class PyStanSamplingWrapper(SamplingWrapper):
    """
    Sampling wrapper for PyStan 3.x (current version).

    Provides unified interface for PyStan model compilation,
    MCMC sampling, and automatic conversion to InferenceData.

    Handles Stan model compilation, data preparation, sampling
    configuration, and result extraction with proper error handling.
    """

    def __init__(self, model_code: str = None, model_file: str = None, **kwargs):
        """
        Initialize PyStan wrapper.

        Args:
            model_code (str, optional): Stan model code as string
            model_file (str, optional): Path to .stan model file
            **kwargs: Additional PyStan compilation parameters
        """

    def sample(self, data: dict, *, num_chains: int = 4, num_samples: int = 1000, **kwargs):
        """
        Run MCMC sampling with PyStan.

        Args:
            data (dict): Data dictionary for Stan model
            num_chains (int): Number of MCMC chains (default 4)
            num_samples (int): Number of samples per chain (default 1000)
            **kwargs: Additional sampling parameters

        Returns:
            InferenceData: ArviZ inference data object
        """

class PyStan2SamplingWrapper(SamplingWrapper):
    """
    Sampling wrapper for PyStan 2.x (legacy version).

    Maintains compatibility with older PyStan 2.x installations
    while providing the same unified sampling interface.

    Note: PyStan 2.x is legacy. Consider upgrading to PyStan 3.x or CmdStanPy.
    """

    def __init__(self, model_code: str = None, model_file: str = None, **kwargs):
        """Initialize PyStan 2.x wrapper."""

    def sample(self, data: dict = None, **kwargs):
        """Run MCMC sampling with PyStan 2.x."""

class CmdStanPySamplingWrapper(SamplingWrapper):
    """
    Sampling wrapper for CmdStanPy (recommended Stan interface).

    Provides interface for CmdStanPy, the official Python interface
    to CmdStan. Offers better performance and more features than PyStan.

    Supports MCMC sampling, variational inference, and optimization
    with automatic conversion to ArviZ format.
    """

    def __init__(self, stan_file: str, **kwargs):
        """
        Initialize CmdStanPy wrapper.

        Args:
            stan_file (str): Path to .stan model file
            **kwargs: CmdStanModel compilation parameters
        """

    def sample(self, data: dict = None, *, chains: int = 4, iter_sampling: int = 1000, **kwargs):
        """
        Run MCMC sampling with CmdStanPy.

        Args:
            data (dict, optional): Data dictionary for Stan model
            chains (int): Number of MCMC chains (default 4)
            iter_sampling (int): Number of sampling iterations (default 1000)
            **kwargs: Additional CmdStanPy sampling parameters

        Returns:
            InferenceData: ArviZ inference data object
        """

    def variational(self, data: dict = None, **kwargs):
        """Run variational inference with CmdStanPy."""

    def optimize(self, data: dict = None, **kwargs):
        """Run optimization with CmdStanPy."""
```

### PyMC Wrapper

```python { .api }
class PyMCSamplingWrapper(SamplingWrapper):
    """
    Sampling wrapper for PyMC (formerly PyMC3).

    Provides unified interface for PyMC model context management,
    MCMC sampling with NUTS, and automatic conversion to ArviZ.

    Handles PyMC model contexts, prior predictive sampling,
    posterior predictive sampling, and comprehensive diagnostics.
    """

    def __init__(self, model_context, **kwargs):
        """
        Initialize PyMC wrapper.

        Args:
            model_context: PyMC model context or model object
            **kwargs: Additional PyMC configuration parameters
        """

    def sample(self, *, draws: int = 1000, tune: int = 1000, chains: int = 4, **kwargs):
        """
        Run MCMC sampling with PyMC.

        Args:
            draws (int): Number of samples to draw (default 1000)
            tune (int): Number of tuning samples (default 1000)
            chains (int): Number of MCMC chains (default 4)
            **kwargs: Additional PyMC sampling parameters (nuts_sampler, etc.)

        Returns:
            InferenceData: ArviZ inference data object with all groups
        """

    def sample_prior_predictive(self, samples: int = 500, **kwargs):
        """Sample from prior predictive distribution."""

    def sample_posterior_predictive(self, trace, samples: int = 500, **kwargs):
        """Sample from posterior predictive distribution."""
```

### Usage Examples

```python
import arviz as az

# CmdStanPy wrapper usage
wrapper = az.CmdStanPySamplingWrapper("my_model.stan")

# Prepare data
data = {
    "N": len(y_obs),
    "y": y_obs,
    "x": x_data
}

# Run sampling with automatic conversion
idata = wrapper.sample(
    data=data,
    chains=4,
    iter_sampling=2000,
    iter_warmup=1000
)

# Data is automatically converted to InferenceData
print(f"Posterior samples: {idata.posterior.dims}")
print(f"Sample stats: {list(idata.sample_stats.data_vars)}")

# PyMC wrapper usage
import pymc as pm

with pm.Model() as model:
    mu = pm.Normal("mu", mu=0, sigma=1)
    sigma = pm.HalfNormal("sigma", sigma=1)
    y = pm.Normal("y", mu=mu, sigma=sigma, observed=y_obs)

wrapper = az.PyMCSamplingWrapper(model)
idata = wrapper.sample(draws=1000, tune=1000, chains=4)

# Includes prior and posterior predictive samples automatically
print(f"Groups: {list(idata.groups())}")

# PyStan wrapper usage
model_code = """
data {
    int<lower=0> N;
    vector[N] y;
}
parameters {
    real mu;
    real<lower=0> sigma;
}
model {
    mu ~ normal(0, 1);
    sigma ~ normal(0, 1);  // half-normal via the <lower=0> constraint
    y ~ normal(mu, sigma);
}
"""

wrapper = az.PyStanSamplingWrapper(model_code=model_code)
idata = wrapper.sample(
    data={"N": len(y_obs), "y": y_obs},
    num_chains=4,
    num_samples=1000
)
```

### Wrapper Configuration

```python
# Common configuration patterns across wrappers
config = {
    "chains": 4,
    "cores": 4,  # Parallel chain execution
    "progress_bar": True,
    "return_inferencedata": True,  # Default for all wrappers
}

# Framework-specific configurations
cmdstanpy_config = {
    **config,
    "iter_sampling": 1000,
    "iter_warmup": 1000,
    "adapt_delta": 0.8,  # NUTS tuning parameter
    "max_treedepth": 10
}

pymc_config = {
    **config,
    "draws": 1000,
    "tune": 1000,
    "target_accept": 0.8,
    "nuts_sampler": "nutpie"  # Alternative sampler
}

# Use with wrappers
cmdstan_wrapper = az.CmdStanPySamplingWrapper("model.stan")
idata = cmdstan_wrapper.sample(data=data, **cmdstanpy_config)

pymc_wrapper = az.PyMCSamplingWrapper(pymc_model)
idata = pymc_wrapper.sample(**pymc_config)
```

### Wrapper Benefits

1. **Unified Interface**: Same API across different frameworks
2. **Automatic Conversion**: Results always returned as InferenceData
3. **Error Handling**: Consistent error messages and troubleshooting
4. **Best Practices**: Built-in recommendations for sampling parameters
5. **Extensibility**: Easy to add support for new frameworks

```python
# Compare results across frameworks easily
frameworks = {
    "cmdstanpy": az.CmdStanPySamplingWrapper("model.stan"),
    "pymc": az.PyMCSamplingWrapper(pymc_model),
    "pystan": az.PyStanSamplingWrapper(model_code=stan_code)
}

results = {}
for name, wrapper in frameworks.items():
    results[name] = wrapper.sample(data=data, chains=4)

# All results are InferenceData objects - easy comparison
comparison = az.compare(results)
print(comparison)
```