Exploratory analysis of Bayesian models with comprehensive data manipulation, statistical diagnostics, and visualization capabilities
Quality: Pending — Does it follow best practices?
Impact: Pending — No eval scenarios have been run.
Convert inference results from various probabilistic programming frameworks to ArviZ's unified InferenceData format. Supports Stan (CmdStan, PyStan, CmdStanPy), PyMC, Pyro, NumPyro, JAX, emcee, and more.
def from_cmdstan(posterior: str = None, *, posterior_predictive: str = None, observed_data: dict = None, constant_data: dict = None, predictions: dict = None, **kwargs) -> InferenceData:
"""
Convert CmdStan output files to InferenceData.
Args:
posterior (str, optional): Path to posterior samples CSV file
posterior_predictive (str, optional): Path to posterior predictive CSV
observed_data (dict, optional): Dictionary of observed data
constant_data (dict, optional): Dictionary of constant/fixed data
predictions (dict, optional): Dictionary of out-of-sample predictions
**kwargs: Additional conversion parameters (coords, dims, etc.)
Returns:
InferenceData: Converted inference data object
"""
def from_cmdstanpy(fit, *, posterior_predictive: str = None, observed_data: dict = None, constant_data: dict = None, **kwargs) -> InferenceData:
"""
Convert CmdStanPy fit results to InferenceData.
Args:
fit: CmdStanPy fit object (CmdStanMCMC, CmdStanMLE, CmdStanVB)
posterior_predictive (str, optional): Variable name for posterior predictive
observed_data (dict, optional): Dictionary of observed data
constant_data (dict, optional): Dictionary of constant data
**kwargs: Additional conversion parameters
Returns:
InferenceData: Converted inference data object
"""def from_pystan(fit, *, posterior_predictive: str = None, observed_data: dict = None, constant_data: dict = None, **kwargs) -> InferenceData:
"""
Convert PyStan fit results to InferenceData.
Args:
fit: PyStan fit object (StanFit4Model)
posterior_predictive (str, optional): Variable name for posterior predictive
observed_data (dict, optional): Dictionary of observed data
constant_data (dict, optional): Dictionary of constant data
**kwargs: Additional conversion parameters (coords, dims, etc.)
Returns:
InferenceData: Converted inference data object
"""import arviz as az
import cmdstanpy
# CmdStanPy example
model = cmdstanpy.CmdStanModel(stan_file="model.stan")
fit = model.sample(data=data_dict)
idata = az.from_cmdstanpy(fit, observed_data={"y": y_obs})
# CmdStan CSV files
idata = az.from_cmdstan(
posterior="output.csv",
posterior_predictive="predictions.csv",
observed_data={"y": y_obs}
)
# PyStan example (legacy)
import pystan
model = pystan.StanModel(file="model.stan")
fit = model.sampling(data=data_dict)
idata = az.from_pystan(fit, observed_data={"y": y_obs})

def from_pyro(posterior: dict, *, prior: dict = None, posterior_predictive: dict = None, observed_data: dict = None, **kwargs) -> InferenceData:
"""
Convert Pyro MCMC results to InferenceData.
Args:
posterior (dict): Dictionary of posterior samples from Pyro MCMC
prior (dict, optional): Dictionary of prior samples
posterior_predictive (dict, optional): Dictionary of posterior predictive samples
observed_data (dict, optional): Dictionary of observed data
**kwargs: Additional conversion parameters (coords, dims, etc.)
Returns:
InferenceData: Converted inference data object
"""def from_numpyro(posterior: dict, *, prior: dict = None, posterior_predictive: dict = None, observed_data: dict = None, **kwargs) -> InferenceData:
"""
Convert NumPyro MCMC results to InferenceData.
Args:
posterior (dict): Dictionary of posterior samples from NumPyro MCMC
prior (dict, optional): Dictionary of prior samples
posterior_predictive (dict, optional): Dictionary of posterior predictive samples
observed_data (dict, optional): Dictionary of observed data
**kwargs: Additional conversion parameters (coords, dims, etc.)
Returns:
InferenceData: Converted inference data object
"""import jax
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
# NumPyro example
def model(y):
mu = numpyro.sample("mu", dist.Normal(0, 1))
sigma = numpyro.sample("sigma", dist.HalfNormal(1))
numpyro.sample("y", dist.Normal(mu, sigma), obs=y)
# Run MCMC
mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=2000)
mcmc.run(jax.random.PRNGKey(0), y=data)
# Convert to ArviZ
idata = az.from_numpyro(
mcmc,
observed_data={"y": data},
coords={"obs": range(len(data))}
)
# Pyro example (similar pattern)
import pyro
import torch
# After running Pyro MCMC
posterior_samples = mcmc.get_samples()
idata = az.from_pyro(
posterior_samples,
observed_data={"y": data}
)

def from_emcee(sampler, *, var_names: list = None, slices: slice = None, **kwargs) -> InferenceData:
"""
Convert emcee ensemble sampler results to InferenceData.
Args:
sampler: emcee EnsembleSampler object
var_names (list, optional): Variable names for parameters
slices (slice, optional): Slice object for chain selection
**kwargs: Additional conversion parameters (coords, dims, etc.)
Returns:
InferenceData: Converted inference data object
"""def from_pyjags(fit, *, var_names: list = None, **kwargs) -> InferenceData:
"""
Convert PyJAGS fit results to InferenceData.
Args:
fit: PyJAGS fit object
var_names (list, optional): Variable names to extract
**kwargs: Additional conversion parameters
Returns:
InferenceData: Converted inference data object
"""def from_beanmachine(beanmachine_model, *, observed_data: dict = None, **kwargs) -> InferenceData:
"""
Convert Bean Machine model results to InferenceData.
Args:
beanmachine_model: Bean Machine model object with samples
observed_data (dict, optional): Dictionary of observed data
**kwargs: Additional conversion parameters
Returns:
InferenceData: Converted inference data object
"""import emcee
import numpy as np
# emcee example
def log_prob(theta):
return -0.5 * np.sum(theta**2)
# Run emcee sampler
nwalkers, ndim = 32, 5
sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob)
sampler.run_mcmc(np.random.randn(nwalkers, ndim), 1000)
# Convert to ArviZ
idata = az.from_emcee(
sampler,
var_names=["param_1", "param_2", "param_3", "param_4", "param_5"]
)

def from_dict(posterior: dict, *, prior: dict = None, posterior_predictive: dict = None, sample_stats: dict = None, observed_data: dict = None, constant_data: dict = None, predictions: dict = None, log_likelihood: dict = None, **kwargs) -> InferenceData:
"""
Convert dictionary of arrays to InferenceData.
Args:
posterior (dict): Dictionary of posterior samples (var_name -> array)
prior (dict, optional): Dictionary of prior samples
posterior_predictive (dict, optional): Dictionary of posterior predictive samples
sample_stats (dict, optional): Dictionary of MCMC diagnostics
observed_data (dict, optional): Dictionary of observed data
constant_data (dict, optional): Dictionary of constant data
predictions (dict, optional): Dictionary of out-of-sample predictions
log_likelihood (dict, optional): Dictionary of log likelihood values
**kwargs: Additional conversion parameters (coords, dims, etc.)
Returns:
InferenceData: Converted inference data object
"""def from_pytree(posterior, *, prior = None, posterior_predictive = None, **kwargs) -> InferenceData:
"""
Convert pytree structure to InferenceData.
Args:
posterior: Pytree structure with posterior samples (JAX, PyTorch, etc.)
prior (optional): Pytree structure with prior samples
posterior_predictive (optional): Pytree structure with posterior predictive samples
**kwargs: Additional conversion parameters
Returns:
InferenceData: Converted inference data object
"""# Dictionary conversion
posterior_dict = {
"mu": np.random.normal(0, 1, (4, 1000)), # 4 chains, 1000 draws
"sigma": np.random.lognormal(0, 0.5, (4, 1000))
}
sample_stats_dict = {
"diverging": np.random.binomial(1, 0.01, (4, 1000)),
"energy": np.random.normal(0, 1, (4, 1000))
}
idata = az.from_dict(
posterior=posterior_dict,
sample_stats=sample_stats_dict,
observed_data={"y": y_observed},
coords={"chain": range(4), "draw": range(1000)}
)
# PyTree conversion (JAX example)
import jax.numpy as jnp
pytree_posterior = {
"mu": jnp.array(np.random.normal(0, 1, (4, 1000))),
"nested": {
"sigma": jnp.array(np.random.lognormal(0, 0.5, (4, 1000)))
}
}
idata = az.from_pytree(pytree_posterior, coords={"chain": range(4)})

ArviZ provides sampling wrapper classes for consistent interfaces across frameworks:
class SamplingWrapper:
"""Base class for sampling wrappers."""
class PyStanSamplingWrapper(SamplingWrapper):
"""Sampling wrapper for PyStan 3.x."""
class PyStan2SamplingWrapper(SamplingWrapper):
"""Sampling wrapper for PyStan 2.x."""
class CmdStanPySamplingWrapper(SamplingWrapper):
"""Sampling wrapper for CmdStanPy."""
class PyMCSamplingWrapper(SamplingWrapper):
"""Sampling wrapper for PyMC."""# Specify coordinates for better data organization
coords = {
"school": ["A", "B", "C", "D", "E", "F", "G", "H"],
"obs": range(len(observations))
}
# Specify dimensions for proper array broadcasting
dims = {
"theta": ["school"],
"y": ["obs"]
}
idata = az.from_dict(
posterior=posterior_dict,
observed_data=observed_dict,
coords=coords,
dims=dims
)

# Complete data conversion with all groups
idata = az.from_dict(
posterior=posterior_samples, # Required
prior=prior_samples, # Optional
posterior_predictive=pp_samples, # Optional
sample_stats=diagnostics, # Optional (divergences, energy, etc.)
observed_data={"y": y_obs}, # Optional but recommended
constant_data={"N": len(y_obs)}, # Optional
predictions=out_of_sample_preds, # Optional
log_likelihood=ll_values, # Optional (for model comparison)
coords=coords,
dims=dims
)

Best practices:
- Provide observed_data for posterior predictive checks
- Use coords and dims for multi-dimensional parameters
- Supply var_names for parameter identification
- Use from_dict() with proper coordinate specifications

ArviZ provides sampling wrapper classes that standardize the interface across different probabilistic programming frameworks for consistent model fitting and data conversion.
class SamplingWrapper:
"""
Base class for probabilistic programming framework sampling wrappers.
Provides a unified interface for model compilation, sampling,
and automatic conversion to ArviZ InferenceData format across
different Bayesian inference libraries.
This abstract base class defines the common interface that all
framework-specific wrappers should implement.
"""
def __init__(self, model, **kwargs):
"""Initialize sampling wrapper with model."""
def sample(self, **sample_kwargs):
"""Run MCMC sampling and return InferenceData."""
def compile_model(self, **compile_kwargs):
"""Compile model for sampling (if required by framework)."""
def to_inference_data(self, **conversion_kwargs):
"""Convert sampling results to InferenceData format."""class PyStanSamplingWrapper(SamplingWrapper):
"""
Sampling wrapper for PyStan 3.x (current version).
Provides unified interface for PyStan model compilation,
MCMC sampling, and automatic conversion to InferenceData.
Handles Stan model compilation, data preparation, sampling
configuration, and result extraction with proper error handling.
"""
def __init__(self, model_code: str = None, model_file: str = None, **kwargs):
"""
Initialize PyStan wrapper.
Args:
model_code (str, optional): Stan model code as string
model_file (str, optional): Path to .stan model file
**kwargs: Additional PyStan compilation parameters
"""
def sample(self, data: dict, *, num_chains: int = 4, num_samples: int = 1000, **kwargs):
"""
Run MCMC sampling with PyStan.
Args:
data (dict): Data dictionary for Stan model
num_chains (int): Number of MCMC chains (default 4)
num_samples (int): Number of samples per chain (default 1000)
**kwargs: Additional sampling parameters
Returns:
InferenceData: ArviZ inference data object
"""
class PyStan2SamplingWrapper(SamplingWrapper):
"""
Sampling wrapper for PyStan 2.x (legacy version).
Maintains compatibility with older PyStan 2.x installations
while providing the same unified sampling interface.
Note: PyStan 2.x is legacy. Consider upgrading to PyStan 3.x or CmdStanPy.
"""
def __init__(self, model_code: str = None, model_file: str = None, **kwargs):
"""Initialize PyStan 2.x wrapper."""
def sample(self, data: dict = None, **kwargs):
"""Run MCMC sampling with PyStan 2.x."""
class CmdStanPySamplingWrapper(SamplingWrapper):
"""
Sampling wrapper for CmdStanPy (recommended Stan interface).
Provides interface for CmdStanPy, the official Python interface
to CmdStan. Offers better performance and more features than PyStan.
Supports MCMC sampling, variational inference, and optimization
with automatic conversion to ArviZ format.
"""
def __init__(self, stan_file: str, **kwargs):
"""
Initialize CmdStanPy wrapper.
Args:
stan_file (str): Path to .stan model file
**kwargs: CmdStanModel compilation parameters
"""
def sample(self, data: dict = None, *, chains: int = 4, iter_sampling: int = 1000, **kwargs):
"""
Run MCMC sampling with CmdStanPy.
Args:
data (dict, optional): Data dictionary for Stan model
chains (int): Number of MCMC chains (default 4)
iter_sampling (int): Number of sampling iterations (default 1000)
**kwargs: Additional CmdStanPy sampling parameters
Returns:
InferenceData: ArviZ inference data object
"""
def variational(self, data: dict = None, **kwargs):
"""Run variational inference with CmdStanPy."""
def optimize(self, data: dict = None, **kwargs):
"""Run optimization with CmdStanPy."""class PyMCSamplingWrapper(SamplingWrapper):
"""
Sampling wrapper for PyMC (formerly PyMC3).
Provides unified interface for PyMC model context management,
MCMC sampling with NUTS, and automatic conversion to ArviZ.
Handles PyMC model contexts, prior predictive sampling,
posterior predictive sampling, and comprehensive diagnostics.
"""
def __init__(self, model_context, **kwargs):
"""
Initialize PyMC wrapper.
Args:
model_context: PyMC model context or model object
**kwargs: Additional PyMC configuration parameters
"""
def sample(self, *, draws: int = 1000, tune: int = 1000, chains: int = 4, **kwargs):
"""
Run MCMC sampling with PyMC.
Args:
draws (int): Number of samples to draw (default 1000)
tune (int): Number of tuning samples (default 1000)
chains (int): Number of MCMC chains (default 4)
**kwargs: Additional PyMC sampling parameters (nuts_sampler, etc.)
Returns:
InferenceData: ArviZ inference data object with all groups
"""
def sample_prior_predictive(self, samples: int = 500, **kwargs):
"""Sample from prior predictive distribution."""
def sample_posterior_predictive(self, trace, samples: int = 500, **kwargs):
"""Sample from posterior predictive distribution."""# CmdStanPy wrapper usage
wrapper = az.CmdStanPySamplingWrapper("my_model.stan")
# Prepare data
data = {
"N": len(y_obs),
"y": y_obs,
"x": x_data
}
# Run sampling with automatic conversion
idata = wrapper.sample(
data=data,
chains=4,
iter_sampling=2000,
iter_warmup=1000
)
# Data is automatically converted to InferenceData
print(f"Posterior samples: {idata.posterior.dims}")
print(f"Sample stats: {list(idata.sample_stats.data_vars)}")
# PyMC wrapper usage
import pymc as pm
with pm.Model() as model:
mu = pm.Normal("mu", mu=0, sigma=1)
sigma = pm.HalfNormal("sigma", sigma=1)
y = pm.Normal("y", mu=mu, sigma=sigma, observed=y_obs)
wrapper = az.PyMCSamplingWrapper(model)
idata = wrapper.sample(draws=1000, tune=1000, chains=4)
# Includes prior and posterior predictive samples automatically
print(f"Groups: {list(idata.groups())}")
# PyStan wrapper usage
model_code = """
data {
int<lower=0> N;
vector[N] y;
}
parameters {
real mu;
real<lower=0> sigma;
}
model {
mu ~ normal(0, 1);
sigma ~ half_normal(1);
y ~ normal(mu, sigma);
}
"""
wrapper = az.PyStanSamplingWrapper(model_code=model_code)
idata = wrapper.sample(
data={"N": len(y_obs), "y": y_obs},
num_chains=4,
num_samples=1000
)

# Common configuration patterns across wrappers
config = {
"chains": 4,
"cores": 4, # Parallel chain execution
"progress_bar": True,
"return_inferencedata": True, # Default for all wrappers
}
# Framework-specific configurations
cmdstanpy_config = {
**config,
"iter_sampling": 1000,
"iter_warmup": 1000,
"adapt_delta": 0.8, # NUTS tuning parameter
"max_treedepth": 10
}
pymc_config = {
**config,
"draws": 1000,
"tune": 1000,
"target_accept": 0.8,
"nuts_sampler": "nutpie" # Alternative sampler
}
# Use with wrappers
cmdstan_wrapper = az.CmdStanPySamplingWrapper("model.stan")
idata = cmdstan_wrapper.sample(data=data, **cmdstanpy_config)
pymc_wrapper = az.PyMCSamplingWrapper(pymc_model)
idata = pymc_wrapper.sample(**pymc_config)

# Compare results across frameworks easily
frameworks = {
"cmdstanpy": az.CmdStanPySamplingWrapper("model.stan"),
"pymc": az.PyMCSamplingWrapper(pymc_model),
"pystan": az.PyStanSamplingWrapper(model_code=stan_code)
}
results = {}
for name, wrapper in frameworks.items():
results[name] = wrapper.sample(data=data, chains=4)
# All results are InferenceData objects - easy comparison
comparison = az.compare(results)
print(comparison)

Install with Tessl CLI
npx tessl i tessl/pypi-arviz