CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-arviz

Exploratory analysis of Bayesian models with comprehensive data manipulation, statistical diagnostics, and visualization capabilities

Pending

Quality

Pending

Does it follow best practices?

Impact

Pending

No eval scenarios have been run

Overview
Eval results
Files

docs/framework-integrations.md

Framework Integrations

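Convert inference results from various probabilistic programming frameworks to ArviZ's unified InferenceData format. Covers the Stan ecosystem (CmdStan, CmdStanPy, PyStan), Pyro, NumPyro, emcee, PyJAGS and Bean Machine. It also covers generic conversion from dictionaries and from JAX/PyTorch pytrees. PyMC is supported through a sampling wrapper.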

Stan Ecosystem

CmdStan and CmdStanPy

def from_cmdstan(posterior: str = None, *, posterior_predictive: str = None, observed_data: dict = None, constant_data: dict = None, predictions: dict = None, **kwargs) -> InferenceData:
    """
    Build an InferenceData object from CmdStan output.

    Args:
        posterior (str, optional): Location of the CSV file holding posterior draws
        posterior_predictive (str, optional): Location of the CSV file holding posterior predictive draws
        observed_data (dict, optional): Observed data, keyed by variable name
        constant_data (dict, optional): Constant (fixed) data, keyed by variable name
        predictions (dict, optional): Out-of-sample predictions, keyed by variable name
        **kwargs: Further conversion options such as ``coords`` and ``dims``

    Returns:
        InferenceData: The converted results
    """

def from_cmdstanpy(fit, *, posterior_predictive: str = None, observed_data: dict = None, constant_data: dict = None, **kwargs) -> InferenceData:
    """
    Build an InferenceData object from a CmdStanPy fit.

    Args:
        fit: Result object returned by CmdStanPy (CmdStanMCMC, CmdStanMLE or CmdStanVB)
        posterior_predictive (str, optional): Name of the variable that holds posterior predictive draws
        observed_data (dict, optional): Observed data, keyed by variable name
        constant_data (dict, optional): Constant data, keyed by variable name
        **kwargs: Further conversion options

    Returns:
        InferenceData: The converted results
    """

PyStan

def from_pystan(fit, *, posterior_predictive: str = None, observed_data: dict = None, constant_data: dict = None, **kwargs) -> InferenceData:
    """
    Build an InferenceData object from a PyStan fit.

    Args:
        fit: Fit object produced by PyStan (StanFit4Model)
        posterior_predictive (str, optional): Name of the variable that holds posterior predictive draws
        observed_data (dict, optional): Observed data, keyed by variable name
        constant_data (dict, optional): Constant data, keyed by variable name
        **kwargs: Further conversion options such as ``coords`` and ``dims``

    Returns:
        InferenceData: The converted results
    """

Usage Examples

import arviz as az
import cmdstanpy

# CmdStanPy: compile the program, sample, then convert the fit
stan_model = cmdstanpy.CmdStanModel(stan_file="model.stan")
fit = stan_model.sample(data=data_dict)
idata = az.from_cmdstanpy(fit, observed_data={"y": y_obs})

# CmdStan: convert CSV files written by a CmdStan run
idata = az.from_cmdstan(
    posterior="output.csv",
    posterior_predictive="predictions.csv",
    observed_data={"y": y_obs},
)

# PyStan 2.x (legacy interface)
import pystan

legacy_model = pystan.StanModel(file="model.stan")
fit = legacy_model.sampling(data=data_dict)
idata = az.from_pystan(fit, observed_data={"y": y_obs})

PyTorch/JAX Ecosystem

Pyro

def from_pyro(posterior=None, *, prior: dict = None, posterior_predictive: dict = None, observed_data: dict = None, **kwargs) -> InferenceData:
    """
    Convert Pyro MCMC results to InferenceData.

    Args:
        posterior (pyro.infer.MCMC, optional): Fitted Pyro ``MCMC`` object, passed
            after ``mcmc.run(...)``. Pass the MCMC object itself, not the dict
            returned by ``mcmc.get_samples()``. The converter uses the sampler to
            recover per-chain draws, sampler statistics and the model's observed
            sites.
        prior (dict, optional): Dictionary of prior samples, e.g. from
            ``pyro.infer.Predictive``
        posterior_predictive (dict, optional): Dictionary of posterior predictive samples
        observed_data (dict, optional): Dictionary of observed data. Usually not
            needed, because observed sites (``obs=``) are read from the model.
        **kwargs: Additional conversion parameters (coords, dims, etc.)

    Returns:
        InferenceData: Converted inference data object
    """

NumPyro

def from_numpyro(posterior=None, *, prior: dict = None, posterior_predictive: dict = None, observed_data: dict = None, **kwargs) -> InferenceData:
    """
    Convert NumPyro MCMC results to InferenceData.

    Args:
        posterior (numpyro.infer.MCMC, optional): Fitted NumPyro ``MCMC`` object,
            passed after ``mcmc.run(...)``, as in ``az.from_numpyro(mcmc)``. Pass
            the MCMC object itself, not the dict returned by ``mcmc.get_samples()``.
            The converter uses the sampler to recover per-chain draws, sampler
            statistics and the model's observed sites.
        prior (dict, optional): Dictionary of prior samples, e.g. from
            ``numpyro.infer.Predictive``
        posterior_predictive (dict, optional): Dictionary of posterior predictive samples
        observed_data (dict, optional): Dictionary of observed data. Usually not
            needed, because observed sites (``obs=``) are read from the model.
        **kwargs: Additional conversion parameters (coords, dims, etc.)

    Returns:
        InferenceData: Converted inference data object
    """

Usage Examples

import arviz as az
import jax
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

# NumPyro example
def model(y):
    mu = numpyro.sample("mu", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1))
    numpyro.sample("y", dist.Normal(mu, sigma), obs=y)

# Run MCMC
mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=2000)
mcmc.run(jax.random.PRNGKey(0), y=data)

# Convert to ArviZ: pass the MCMC object itself. Observed values come from
# the model's `obs=` site, so no observed_data argument is needed.
idata = az.from_numpyro(
    mcmc,
    coords={"obs": range(len(data))},
    dims={"y": ["obs"]},  # attach the "obs" coordinate to the observed variable
)

# Pyro example (same pattern, but with Pyro's own model and MCMC)
import pyro
import pyro.distributions as pdist
import torch
from pyro.infer import MCMC as PyroMCMC, NUTS as PyroNUTS

def pyro_model(y):
    mu = pyro.sample("mu", pdist.Normal(0.0, 1.0))
    sigma = pyro.sample("sigma", pdist.HalfNormal(1.0))
    with pyro.plate("obs", len(y)):
        pyro.sample("y", pdist.Normal(mu, sigma), obs=y)

pyro_mcmc = PyroMCMC(PyroNUTS(pyro_model), num_samples=2000, warmup_steps=1000)
pyro_mcmc.run(torch.as_tensor(data, dtype=torch.float32))

# from_pyro also takes the fitted MCMC object, not the dict from get_samples()
idata = az.from_pyro(pyro_mcmc)

Other Frameworks

emcee

def from_emcee(sampler, *, var_names: list = None, slices: list = None, **kwargs) -> InferenceData:
    """
    Convert emcee ensemble sampler results to InferenceData.

    Each walker of the ensemble becomes one chain in the result.

    Args:
        sampler: emcee ``EnsembleSampler`` object, after ``run_mcmc``
        var_names (list, optional): Names for the variables in the sampler's
            parameter vector, in order
        slices (list, optional): One index per entry of ``var_names`` (an int,
            a list of ints or a ``slice``). Each index picks that variable's
            positions in the parameter vector. Only needed when some variables
            are multidimensional, e.g.
            ``var_names=["mu", "tau", "eta"], slices=[0, 1, slice(2, None)]``.
            It does not select chains.
        **kwargs: Additional conversion parameters (coords, dims, etc.)

    Returns:
        InferenceData: Converted inference data object
    """

PyJAGS

def from_pyjags(fit, *, var_names: list = None, **kwargs) -> InferenceData:
    """
    Build an InferenceData object from a PyJAGS fit.

    Args:
        fit: Fit object produced by PyJAGS
        var_names (list, optional): Names of the variables to pull out of the fit
        **kwargs: Further conversion options

    Returns:
        InferenceData: The converted results
    """

Bean Machine

def from_beanmachine(beanmachine_model, *, observed_data: dict = None, **kwargs) -> InferenceData:
    """
    Build an InferenceData object from Bean Machine results.

    Args:
        beanmachine_model: Bean Machine model object that carries samples
        observed_data (dict, optional): Observed data, keyed by variable name
        **kwargs: Further conversion options

    Returns:
        InferenceData: The converted results
    """

Usage Examples

import emcee
import numpy as np

# Log-density of a standard normal, up to an additive constant
def log_prob(theta):
    return -0.5 * np.sum(theta**2)

# Sample with an ensemble of walkers
n_walkers, n_dim = 32, 5
sampler = emcee.EnsembleSampler(n_walkers, n_dim, log_prob)
start_positions = np.random.randn(n_walkers, n_dim)
sampler.run_mcmc(start_positions, 1000)

# Convert, giving each parameter a readable name (param_1 ... param_5)
param_names = [f"param_{k}" for k in range(1, n_dim + 1)]
idata = az.from_emcee(sampler, var_names=param_names)

Generic Conversions

Dictionary-based Conversion

def from_dict(posterior: dict, *, prior: dict = None, posterior_predictive: dict = None, sample_stats: dict = None, observed_data: dict = None, constant_data: dict = None, predictions: dict = None, log_likelihood: dict = None, **kwargs) -> InferenceData:
    """
    Build an InferenceData object from plain dictionaries of arrays.

    Each dictionary maps a variable name to its array and becomes one group.

    Args:
        posterior (dict): Posterior draws, keyed by variable name
        prior (dict, optional): Prior draws
        posterior_predictive (dict, optional): Posterior predictive draws
        sample_stats (dict, optional): MCMC diagnostics
        observed_data (dict, optional): Observed data
        constant_data (dict, optional): Constant data
        predictions (dict, optional): Out-of-sample predictions
        log_likelihood (dict, optional): Pointwise log-likelihood values
        **kwargs: Further conversion options such as ``coords`` and ``dims``

    Returns:
        InferenceData: The converted results
    """

PyTree Conversion

def from_pytree(posterior, *, prior = None, posterior_predictive = None, **kwargs) -> InferenceData:
    """
    Build an InferenceData object from pytree-structured samples.

    Args:
        posterior: Pytree (JAX, PyTorch, etc.) holding posterior draws
        prior (optional): Pytree holding prior draws
        posterior_predictive (optional): Pytree holding posterior predictive draws
        **kwargs: Further conversion options

    Returns:
        InferenceData: The converted results
    """

Usage Examples

# Dictionary conversion: every array is shaped (chain, draw, *variable_shape)
n_chains, n_draws = 4, 1000

posterior_dict = {
    "mu": np.random.normal(0, 1, (n_chains, n_draws)),
    "sigma": np.random.lognormal(0, 0.5, (n_chains, n_draws)),
}

sample_stats_dict = {
    # ArviZ's sample_stats convention stores divergence flags as booleans
    "diverging": np.random.binomial(1, 0.01, (n_chains, n_draws)).astype(bool),
    "energy": np.random.normal(0, 1, (n_chains, n_draws)),
}

idata = az.from_dict(
    posterior=posterior_dict,
    sample_stats=sample_stats_dict,
    observed_data={"y": y_observed},
    coords={"chain": range(n_chains), "draw": range(n_draws)},
)

# PyTree conversion (JAX example)
import jax.numpy as jnp

pytree_posterior = {
    "mu": jnp.array(np.random.normal(0, 1, (n_chains, n_draws))),
    "nested": {
        "sigma": jnp.array(np.random.lognormal(0, 0.5, (n_chains, n_draws)))
    },
}

idata = az.from_pytree(pytree_posterior, coords={"chain": range(n_chains)})

Sampling Wrappers Overview

ArviZ provides sampling wrapper classes that give a consistent interface across frameworks. Each class is described in detail in the Sampling Wrappers section below.

class SamplingWrapper:
    """Common base for the framework-specific sampling wrappers."""


class PyStanSamplingWrapper(SamplingWrapper):
    """Wrapper around PyStan 3.x sampling."""


class PyStan2SamplingWrapper(SamplingWrapper):
    """Wrapper around legacy PyStan 2.x sampling."""


class CmdStanPySamplingWrapper(SamplingWrapper):
    """Wrapper around CmdStanPy sampling."""


class PyMCSamplingWrapper(SamplingWrapper):
    """Wrapper around PyMC sampling."""

Conversion Best Practices

Coordinate and Dimension Specifications

# Coordinates label the positions along each named dimension
coords = {
    "school": list("ABCDEFGH"),
    "obs": range(len(observations)),
}

# Dims name the non-chain/draw axes of each variable so they line up with coords
dims = {"theta": ["school"], "y": ["obs"]}

idata = az.from_dict(
    posterior=posterior_dict,
    observed_data=observed_dict,
    coords=coords,
    dims=dims,
)

Handling Multiple Data Groups

# Convert every available group in one call; only `posterior` is required
idata = az.from_dict(
    posterior=posterior_samples,
    prior=prior_samples,
    posterior_predictive=pp_samples,
    sample_stats=diagnostics,          # e.g. divergences, energy
    observed_data={"y": y_obs},        # recommended
    constant_data={"N": len(y_obs)},
    predictions=out_of_sample_preds,
    log_likelihood=ll_values,          # used for model comparison
    coords=coords,
    dims=dims,
)

Framework-Specific Tips

  • Stan: Always include observed_data for posterior predictive checks
  • Pyro/NumPyro: Use coords and dims for multi-dimensional parameters
  • emcee: Provide meaningful var_names for parameter identification
  • Custom frameworks: Use from_dict() with proper coordinate specifications

Sampling Wrappers

ArviZ provides sampling wrapper classes that standardize the interface across different probabilistic programming frameworks for consistent model fitting and data conversion.

Base Wrapper Class

class SamplingWrapper:
    """
    Shared interface for the framework-specific sampling wrappers.

    Subclasses wrap one probabilistic programming library and expose the
    same steps: construction around a model, optional model compilation,
    sampling, and conversion of the results to ArviZ's InferenceData
    format. The methods defined here only declare that interface.
    """

    def __init__(self, model, **kwargs):
        """Set up the wrapper around ``model``."""

    def sample(self, **sample_kwargs):
        """Run MCMC and hand back the draws as InferenceData."""

    def compile_model(self, **compile_kwargs):
        """Compile the model, for frameworks that need a compilation step."""

    def to_inference_data(self, **conversion_kwargs):
        """Turn the sampling results into an InferenceData object."""

Stan Ecosystem Wrappers

class PyStanSamplingWrapper(SamplingWrapper):
    """
    Sampling wrapper for the current PyStan 3.x interface.

    Takes a Stan program, given as source text or as a ``.stan`` file,
    runs MCMC with PyStan and returns the draws as InferenceData.
    """

    def __init__(self, model_code: str = None, model_file: str = None, **kwargs):
        """
        Set up the PyStan wrapper.

        Args:
            model_code (str, optional): Stan program source text
            model_file (str, optional): Path to a ``.stan`` program file
            **kwargs: Extra options for PyStan model compilation
        """

    def sample(self, data: dict, *, num_chains: int = 4, num_samples: int = 1000, **kwargs):
        """
        Draw MCMC samples with PyStan.

        Args:
            data (dict): Data passed to the Stan program
            num_chains (int): Number of chains (default 4)
            num_samples (int): Draws kept per chain (default 1000)
            **kwargs: Extra PyStan sampling options

        Returns:
            InferenceData: The draws in ArviZ format
        """

class PyStan2SamplingWrapper(SamplingWrapper):
    """
    Sampling wrapper for the legacy PyStan 2.x interface.

    Gives installations still on PyStan 2.x the same sampling interface
    as the other wrappers.

    Note: PyStan 2.x is legacy; prefer PyStan 3.x or CmdStanPy.
    """

    def __init__(self, model_code: str = None, model_file: str = None, **kwargs):
        """Set up the PyStan 2.x wrapper from Stan source text or a ``.stan`` file."""

    def sample(self, data: dict = None, **kwargs):
        """Draw MCMC samples with PyStan 2.x."""

class CmdStanPySamplingWrapper(SamplingWrapper):
    """
    Sampling wrapper for CmdStanPy, the official Python interface to CmdStan.

    Besides MCMC sampling it exposes variational inference and
    optimization. Results are converted to ArviZ format.
    """

    def __init__(self, stan_file: str, **kwargs):
        """
        Set up the CmdStanPy wrapper.

        Args:
            stan_file (str): Path to the ``.stan`` program file
            **kwargs: Options for compiling the CmdStanModel
        """

    def sample(self, data: dict = None, *, chains: int = 4, iter_sampling: int = 1000, **kwargs):
        """
        Draw MCMC samples with CmdStanPy.

        Args:
            data (dict, optional): Data passed to the Stan program
            chains (int): Number of chains (default 4)
            iter_sampling (int): Post-warmup iterations per chain (default 1000)
            **kwargs: Extra CmdStanPy sampling options

        Returns:
            InferenceData: The draws in ArviZ format
        """

    def variational(self, data: dict = None, **kwargs):
        """Fit the model with CmdStanPy's variational inference."""

    def optimize(self, data: dict = None, **kwargs):
        """Compute a point estimate with CmdStanPy's optimizer."""

PyMC Wrapper

class PyMCSamplingWrapper(SamplingWrapper):
    """
    Sampling wrapper for PyMC (the library formerly named PyMC3).

    Works with a PyMC model or model context. It covers NUTS sampling as
    well as prior and posterior predictive sampling, and returns ArviZ
    InferenceData.
    """

    def __init__(self, model_context, **kwargs):
        """
        Set up the PyMC wrapper.

        Args:
            model_context: PyMC model object or model context
            **kwargs: Further PyMC configuration options
        """

    def sample(self, *, draws: int = 1000, tune: int = 1000, chains: int = 4, **kwargs):
        """
        Draw MCMC samples with PyMC.

        Args:
            draws (int): Draws kept per chain (default 1000)
            tune (int): Tuning iterations per chain (default 1000)
            chains (int): Number of chains (default 4)
            **kwargs: Extra PyMC sampling options, e.g. ``nuts_sampler``

        Returns:
            InferenceData: The draws in ArviZ format, with all available groups
        """

    def sample_prior_predictive(self, samples: int = 500, **kwargs):
        """Draw samples from the prior predictive distribution."""

    def sample_posterior_predictive(self, trace, samples: int = 500, **kwargs):
        """Draw samples from the posterior predictive distribution given ``trace``."""

Usage Examples

# CmdStanPy wrapper usage
wrapper = az.CmdStanPySamplingWrapper("my_model.stan")

# Prepare data
data = {
    "N": len(y_obs),
    "y": y_obs,
    "x": x_data
}

# Run sampling with automatic conversion
idata = wrapper.sample(
    data=data,
    chains=4,
    iter_sampling=2000,
    iter_warmup=1000
)

# Data is automatically converted to InferenceData
print(f"Posterior samples: {idata.posterior.dims}")
print(f"Sample stats: {list(idata.sample_stats.data_vars)}")

# PyMC wrapper usage
import pymc as pm

with pm.Model() as model:
    mu = pm.Normal("mu", mu=0, sigma=1)
    sigma = pm.HalfNormal("sigma", sigma=1)
    y = pm.Normal("y", mu=mu, sigma=sigma, observed=y_obs)

wrapper = az.PyMCSamplingWrapper(model)
idata = wrapper.sample(draws=1000, tune=1000, chains=4)

# Includes prior and posterior predictive samples automatically
print(f"Groups: {list(idata.groups())}")

# PyStan wrapper usage (same model as the PyMC example above)
model_code = """
data {
    int<lower=0> N;
    vector[N] y;
}
parameters {
    real mu;
    real<lower=0> sigma;
}
model {
    mu ~ normal(0, 1);
    // Stan has no half_normal distribution: a normal(0, 1) prior on a
    // <lower=0> parameter is a half-normal prior.
    sigma ~ normal(0, 1);
    y ~ normal(mu, sigma);
}
"""

wrapper = az.PyStanSamplingWrapper(model_code=model_code)
idata = wrapper.sample(
    data={"N": len(y_obs), "y": y_obs},
    num_chains=4,
    num_samples=1000
)

Wrapper Configuration

# Each framework names its sampler options differently and may reject
# keywords it does not know, so build one config per framework instead of
# merging a shared dict. For example, CmdStanPy uses `parallel_chains` and
# `show_progress`, while PyMC uses `cores` and `progressbar`.
n_chains = 4

cmdstanpy_config = {
    "chains": n_chains,
    "parallel_chains": n_chains,  # run chains in parallel
    "show_progress": True,
    "iter_sampling": 1000,
    "iter_warmup": 1000,
    "adapt_delta": 0.8,  # NUTS target acceptance rate
    "max_treedepth": 10,
}

pymc_config = {
    "chains": n_chains,
    "cores": n_chains,  # run chains in parallel
    "progressbar": True,
    "draws": 1000,
    "tune": 1000,
    "target_accept": 0.8,  # NUTS target acceptance rate
    "nuts_sampler": "nutpie",  # alternative NUTS implementation; requires nutpie
}

# Use with wrappers
cmdstan_wrapper = az.CmdStanPySamplingWrapper("model.stan")
idata = cmdstan_wrapper.sample(data=data, **cmdstanpy_config)

pymc_wrapper = az.PyMCSamplingWrapper(pymc_model)
idata = pymc_wrapper.sample(**pymc_config)

Wrapper Benefits

  1. Unified Interface: Same API across different frameworks
  2. Automatic Conversion: Results always returned as InferenceData
  3. Error Handling: Consistent error messages and troubleshooting
  4. Best Practices: Built-in recommendations for sampling parameters
  5. Extensibility: Easy to add support for new frameworks

Example: fitting the same model with several frameworks and comparing the results:

# Compare results across frameworks.
# Each wrapper's sample() has its own signature:
# - the Stan wrappers take the data dict;
# - PyStan names the chain count `num_chains`;
# - the PyMC model already holds its data, so it gets no `data` argument.
results = {
    "cmdstanpy": az.CmdStanPySamplingWrapper("model.stan").sample(data=data, chains=4),
    "pymc": az.PyMCSamplingWrapper(pymc_model).sample(chains=4),
    "pystan": az.PyStanSamplingWrapper(model_code=stan_code).sample(data=data, num_chains=4),
}

# All results are InferenceData objects, so the same ArviZ tools apply to each.
# az.compare ranks *different models* by predictive accuracy and requires a
# log_likelihood group. To check that backends agree on *one* model, compare
# the posterior summaries instead.
for name, idata in results.items():
    print(f"== {name} ==")
    print(az.summary(idata))

Install with Tessl CLI

npx tessl i tessl/pypi-arviz

docs

configuration-management.md

data-operations.md

framework-integrations.md

index.md

performance-utilities.md

statistical-analysis.md

visualization-plotting.md

tile.json