Exploratory analysis of Bayesian models with comprehensive data manipulation, statistical diagnostics, and visualization capabilities
—
Quality
Pending
Does it follow best practices?
Impact
Pending
No eval scenarios have been run
Comprehensive plotting functions for Bayesian analysis including diagnostic plots, distribution visualizations, model comparison plots, and posterior predictive checks. Support for multiple backends (Matplotlib, Bokeh, Plotly) with publication-quality output.
def plot_trace(data: InferenceData, *, var_names: list = None, coords: dict = None, divergences: str = "auto", figsize: tuple = None, **kwargs):
"""
Plot MCMC trace plots to assess chain mixing and convergence.
Args:
data (InferenceData): MCMC inference data
var_names (list, optional): Variables to plot
coords (dict, optional): Coordinate specifications for multidimensional variables
divergences (str or bool): How to display divergences ("auto", "bottom", "top", or False to hide them)
figsize (tuple, optional): Figure size (width, height)
**kwargs: Additional plotting parameters
Returns:
matplotlib axes or bokeh figure depending on backend
"""
def plot_rank(data: InferenceData, *, var_names: list = None, coords: dict = None, ref_line: bool = True, **kwargs):
"""
Plot rank plots for MCMC diagnostics and convergence assessment.
Args:
data (InferenceData): MCMC inference data with multiple chains
var_names (list, optional): Variables to plot
coords (dict, optional): Coordinate specifications
ref_line (bool): Whether to show reference line for uniform distribution
**kwargs: Additional plotting parameters
Returns:
matplotlib axes or bokeh figure
"""
def plot_autocorr(data: InferenceData, *, var_names: list = None, coords: dict = None, max_lag: int = 100, **kwargs):
"""
Plot autocorrelation function for MCMC chains.
Args:
data (InferenceData): MCMC inference data
var_names (list, optional): Variables to plot
coords (dict, optional): Coordinate specifications
max_lag (int): Maximum lag to display (default 100)
**kwargs: Additional plotting parameters
Returns:
matplotlib axes or bokeh figure
"""
def plot_ess(data: InferenceData, *, var_names: list = None, coords: dict = None, kind: str = "local", **kwargs):
"""
Plot effective sample size diagnostics.
Args:
data (InferenceData): MCMC inference data
var_names (list, optional): Variables to plot
coords (dict, optional): Coordinate specifications
kind (str): Type of ESS plot ("local", "quantile", "evolution")
**kwargs: Additional plotting parameters
Returns:
matplotlib axes or bokeh figure
"""
def plot_mcse(data: InferenceData, *, var_names: list = None, coords: dict = None, kind: str = "local", **kwargs):
"""
Plot Monte Carlo standard error diagnostics.
Args:
data (InferenceData): MCMC inference data
var_names (list, optional): Variables to plot
coords (dict, optional): Coordinate specifications
kind (str): Type of MCSE plot ("local", "quantile", "evolution")
**kwargs: Additional plotting parameters
Returns:
matplotlib axes or bokeh figure
"""
def plot_bpv(data: InferenceData, *, kind: str = None, coords: dict = None, **kwargs):
"""
Plot Bayesian p-value for model checking.
Args:
data (InferenceData): Inference data with observed data and posterior predictive
kind (str, optional): Plot type ("t_stat", "u_value", "scatter")
coords (dict, optional): Coordinate specifications
**kwargs: Additional plotting parameters
Returns:
matplotlib axes or bokeh figure
"""
def plot_bf(bf, *, labels: list = None, ref_val: float = 0, **kwargs):
"""
Plot Bayes factor results.
Args:
bf: Bayes factor values or results dictionary
labels (list, optional): Labels for different comparisons
ref_val (float): Reference value for hypothesis (default 0)
**kwargs: Additional plotting parameters
Returns:
matplotlib axes or bokeh figure
"""
def plot_energy(data: InferenceData, *, figsize: tuple = None, **kwargs):
"""
Plot energy diagnostics for Hamiltonian Monte Carlo.
Args:
data (InferenceData): MCMC inference data with energy information
figsize (tuple, optional): Figure size
**kwargs: Additional plotting parameters
Returns:
matplotlib axes showing energy distribution and transitions
"""
import arviz as az
# Load example data
idata = az.load_arviz_data("centered_eight")
# Basic trace plot
az.plot_trace(idata, var_names=["mu", "tau"])
# Rank plot for convergence diagnostics
az.plot_rank(idata, var_names=["mu", "tau"])
# Autocorrelation plot
az.plot_autocorr(idata, var_names=["mu"], max_lag=50)
# ESS evolution plot
az.plot_ess(idata, kind="evolution")
# Energy diagnostics (if HMC/NUTS data available)
az.plot_energy(idata)
def plot_posterior(data: InferenceData, *, var_names: list = None, coords: dict = None, figsize: tuple = None, kind: str = "kde", **kwargs):
"""
Plot posterior distributions with summary statistics.
Args:
data (InferenceData): Inference data
var_names (list, optional): Variables to plot
coords (dict, optional): Coordinate specifications
figsize (tuple, optional): Figure size
kind (str): Plot type ("kde", "hist")
**kwargs: Additional plotting parameters (hdi_prob, point_estimate, etc.)
Returns:
matplotlib axes or bokeh figure
"""
def plot_density(data: InferenceData, *, var_names: list = None, coords: dict = None, shade: float = 0.0, **kwargs):
"""
Plot kernel density estimation for continuous variables.
Args:
data (InferenceData): Inference data
var_names (list, optional): Variables to plot
coords (dict, optional): Coordinate specifications
shade (float): Shading for density curves (0.0 to 1.0)
**kwargs: Additional plotting parameters
Returns:
matplotlib axes or bokeh figure
"""
def plot_dist(values, *, kind: str = "kde", label: str = None, color: str = None, **kwargs):
"""
Plot generic distribution for array-like data.
Args:
values: Array-like data to plot
kind (str): Plot type ("kde", "hist", "ecdf")
label (str, optional): Label for legend
color (str, optional): Color specification
**kwargs: Additional plotting parameters
Returns:
matplotlib axes or bokeh figure
"""
def plot_kde(values, *, label: str = None, bw: str = "default", adaptive: bool = False, **kwargs):
"""
Plot kernel density estimation for raw values.
Args:
values: Array-like data for KDE
label (str, optional): Label for legend
bw (str): Bandwidth selection ("default", "scott", "silverman")
adaptive (bool): Whether to use adaptive bandwidth
**kwargs: Additional plotting parameters
Returns:
matplotlib axes or bokeh figure
"""
def plot_violin(data: InferenceData, *, var_names: list = None, coords: dict = None, quartiles: bool = True, **kwargs):
"""
Plot violin plots for posterior distributions.
Args:
data (InferenceData): Inference data
var_names (list, optional): Variables to plot
coords (dict, optional): Coordinate specifications
quartiles (bool): Whether to show quartile lines
**kwargs: Additional plotting parameters
Returns:
matplotlib axes or bokeh figure
"""
def plot_dist_comparison(comp_dict: dict, *, kind: str = "kde", figsize: tuple = None, **kwargs):
"""
Compare distributions from multiple sources.
Args:
comp_dict (dict): Dictionary of label -> data mappings
kind (str): Plot type ("kde", "hist", "ecdf")
figsize (tuple, optional): Figure size
**kwargs: Additional plotting parameters
Returns:
matplotlib axes or bokeh figure
"""
# Posterior distribution plots
az.plot_posterior(idata, var_names=["mu", "tau"], hdi_prob=0.89)
# Density plots with shading
az.plot_density(idata, var_names=["mu"], shade=0.1)
# Violin plots
az.plot_violin(idata, var_names=["theta"])
# Compare posterior vs prior
comparison = {
"posterior": idata.posterior["mu"].values.flatten(),
"prior": idata.prior["mu"].values.flatten()
}
az.plot_dist_comparison(comparison, kind="kde")
def plot_forest(data: InferenceData, *, var_names: list = None, coords: dict = None, combined: bool = False, hdi_prob: float = 0.94, **kwargs):
"""
Plot forest plots for parameter estimates with credible intervals.
Args:
data (InferenceData): Inference data
var_names (list, optional): Variables to plot
coords (dict, optional): Coordinate specifications
combined (bool): Whether to combine chains
hdi_prob (float): Probability for HDI intervals
**kwargs: Additional plotting parameters
Returns:
matplotlib axes or bokeh figure
"""
def plot_compare(comp_df: pd.DataFrame, *, insample_dev: bool = True, plot_ic_diff: bool = True, figsize: tuple = None, **kwargs):
"""
Plot model comparison results from az.compare().
Args:
comp_df (pd.DataFrame): Model comparison dataframe from az.compare()
insample_dev (bool): Whether to plot in-sample deviance
plot_ic_diff (bool): Whether to plot IC differences
figsize (tuple, optional): Figure size
**kwargs: Additional plotting parameters
Returns:
matplotlib axes or bokeh figure
"""
def plot_elpd(comp_df: pd.DataFrame, *, xlabels: bool = False, figsize: tuple = None, **kwargs):
"""
Plot expected log pointwise predictive density comparison.
Args:
comp_df (pd.DataFrame): Model comparison dataframe
xlabels (bool): Whether to show x-axis labels
figsize (tuple, optional): Figure size
**kwargs: Additional plotting parameters
Returns:
matplotlib axes or bokeh figure
"""
def plot_khat(khat_data, *, annotate: bool = True, threshold: float = 0.7, figsize: tuple = None, **kwargs):
"""
Plot Pareto k-hat diagnostic values from LOO-CV.
Args:
khat_data: Pareto k values (from az.loo())
annotate (bool): Whether to annotate problematic points
threshold (float): Threshold for problematic k values
figsize (tuple, optional): Figure size
**kwargs: Additional plotting parameters
Returns:
matplotlib axes or bokeh figure
"""
# Forest plot with credible intervals
az.plot_forest(idata, var_names=["mu", "tau"], hdi_prob=0.89)
# Model comparison plot
model_dict = {"model1": idata1, "model2": idata2}
comp_df = az.compare(model_dict)
az.plot_compare(comp_df)
# LOO diagnostics
loo_result = az.loo(idata)
az.plot_khat(loo_result.pareto_k, threshold=0.7)
def plot_pair(data: InferenceData, *, var_names: list = None, coords: dict = None, kind: str = "scatter", divergences: bool = False, **kwargs):
"""
Plot pairwise relationships between variables.
Args:
data (InferenceData): Inference data
var_names (list, optional): Variables to include in pairs plot
coords (dict, optional): Coordinate specifications
kind (str): Plot type ("scatter", "kde", "hexbin")
divergences (bool): Whether to highlight divergent transitions
**kwargs: Additional plotting parameters
Returns:
matplotlib axes grid
"""
def plot_parallel(data: InferenceData, *, var_names: list = None, coords: dict = None, norm_method: str = "normal", **kwargs):
"""
Plot parallel coordinate plot for multidimensional data.
Args:
data (InferenceData): Inference data
var_names (list, optional): Variables to plot
coords (dict, optional): Coordinate specifications
norm_method (str): Normalization method ("normal", "minmax", "rank")
**kwargs: Additional plotting parameters
Returns:
matplotlib axes or bokeh figure
"""
# Pairwise scatter plot
az.plot_pair(idata, var_names=["mu", "tau", "theta"], kind="scatter")
# Parallel coordinates plot
az.plot_parallel(idata, var_names=["mu", "tau"], norm_method="minmax")
def plot_ppc(data: InferenceData, *, var_names: list = None, coords: dict = None, kind: str = "kde", num_pp_samples: int = 100, **kwargs):
"""
Plot posterior predictive checks comparing observed data to predictions.
Args:
data (InferenceData): Inference data with posterior_predictive and observed_data
var_names (list, optional): Variables to plot
coords (dict, optional): Coordinate specifications
kind (str): Plot type ("kde", "cumulative", "scatter")
num_pp_samples (int): Number of posterior predictive samples to show
**kwargs: Additional plotting parameters
Returns:
matplotlib axes or bokeh figure
"""
def plot_loo_pit(data: InferenceData, *, y: str = None, y_hat: str = None, ecdf: bool = False, **kwargs):
"""
Plot Leave-One-Out Probability Integral Transform for model checking.
Args:
data (InferenceData): Inference data
y (str, optional): Observed data variable name
y_hat (str, optional): Posterior predictive variable name
ecdf (bool): Whether to plot empirical CDF
**kwargs: Additional plotting parameters
Returns:
matplotlib axes or bokeh figure
"""
def plot_bpv(data: InferenceData, *, kind: str = "p_value", reference: float = None, **kwargs):
"""
Plot Bayesian p-value for posterior predictive checks.
Args:
data (InferenceData): Inference data
kind (str): Plot type ("p_value", "u_value")
reference (float, optional): Reference value for comparison
**kwargs: Additional plotting parameters
Returns:
matplotlib axes or bokeh figure
"""
# Posterior predictive check
az.plot_ppc(idata, var_names=["y"], kind="kde", num_pp_samples=50)
# LOO-PIT diagnostic
az.plot_loo_pit(idata, y="y_obs", y_hat="y_pred")
# Bayesian p-value
az.plot_bpv(idata, kind="p_value")
def plot_hdi(data, *, hdi_prob: float = 0.94, color: str = "C0", circular: bool = False, **kwargs):
"""
Plot highest density interval regions.
Args:
data: Posterior samples or InferenceData
hdi_prob (float): Probability for HDI region
color (str): Color specification
circular (bool): Whether data is circular
**kwargs: Additional plotting parameters
Returns:
matplotlib axes or bokeh figure
"""
def plot_dot(data: InferenceData, *, var_names: list = None, coords: dict = None, **kwargs):
"""
Plot dot plots for discrete variables.
Args:
data (InferenceData): Inference data
var_names (list, optional): Variables to plot
coords (dict, optional): Coordinate specifications
**kwargs: Additional plotting parameters
Returns:
matplotlib axes or bokeh figure
"""
def plot_ecdf(data: InferenceData, *, var_names: list = None, coords: dict = None, rug: bool = False, **kwargs):
"""
Plot empirical cumulative distribution function.
Args:
data (InferenceData): Inference data
var_names (list, optional): Variables to plot
coords (dict, optional): Coordinate specifications
rug (bool): Whether to show rug plot
**kwargs: Additional plotting parameters
Returns:
matplotlib axes or bokeh figure
"""
def plot_ts(data: InferenceData, *, var_names: list = None, coords: dict = None, **kwargs):
"""
Plot time series data.
Args:
data (InferenceData): Inference data with time dimension
var_names (list, optional): Variables to plot
coords (dict, optional): Coordinate specifications
**kwargs: Additional plotting parameters
Returns:
matplotlib axes or bokeh figure
"""
def plot_lm(y: np.ndarray, x: np.ndarray, *, y_hat: np.ndarray = None, y_mean: bool = False, **kwargs):
"""
Plot linear model diagnostics and predictions.
Args:
y (np.ndarray): Observed response values
x (np.ndarray): Predictor values
y_hat (np.ndarray, optional): Predicted values
y_mean (bool): Whether to show mean predictions
**kwargs: Additional plotting parameters
Returns:
matplotlib axes or bokeh figure
"""
def plot_separation(y: np.ndarray, y_hat: np.ndarray, *, threshold: float = 0.5, **kwargs):
"""
Plot separation plot for binary classification models.
Args:
y (np.ndarray): True binary outcomes
y_hat (np.ndarray): Predicted probabilities
threshold (float): Classification threshold
**kwargs: Additional plotting parameters
Returns:
matplotlib axes or bokeh figure
"""
def plot_bf(bf_data, *, ref_val: float = 1.0, labels: list = None, **kwargs):
"""
Plot Bayes factor comparison.
Args:
bf_data: Bayes factor values
ref_val (float): Reference value for comparison
labels (list, optional): Labels for comparison
**kwargs: Additional plotting parameters
Returns:
matplotlib axes or bokeh figure
"""
# HDI region plot
az.plot_hdi(idata.posterior["mu"], hdi_prob=0.89)
# ECDF plot
az.plot_ecdf(idata, var_names=["mu"], rug=True)
# Linear model plot
az.plot_lm(y_observed, x_data, y_hat=predictions)
# Separation plot for classification
az.plot_separation(binary_outcomes, predicted_probs)

ArviZ provides specialized functions for Bokeh backend integration, enabling interactive plots in Jupyter environments and web applications.
def to_cds(data, *, var_names: list = None, coords: dict = None, **kwargs):
"""
Convert ArviZ data to Bokeh ColumnDataSource format.
Transforms InferenceData or array-like data into Bokeh's
ColumnDataSource format for interactive plotting.
Args:
data: ArviZ InferenceData or array-like data
var_names (list, optional): Variables to include in conversion
coords (dict, optional): Coordinate specifications
**kwargs: Additional conversion parameters
Returns:
bokeh.models.ColumnDataSource: Bokeh-compatible data source
"""
def output_notebook(*args, **kwargs):
"""
Configure Bokeh output for Jupyter notebook display.
Sets up Bokeh to display interactive plots inline in
Jupyter notebooks with proper widget support.
Args:
*args: Arguments passed to bokeh.io.output_notebook
**kwargs: Keyword arguments for notebook configuration
"""
def output_file(*args, **kwargs):
"""
Configure Bokeh output for HTML file generation.
Sets up Bokeh to save interactive plots as standalone
HTML files that can be shared or embedded.
Args:
*args: Arguments passed to bokeh.io.output_file
**kwargs: Keyword arguments for file output configuration
"""
def ColumnDataSource(*args, **kwargs):
"""
Create Bokeh ColumnDataSource for plotting.
Wrapper around Bokeh's ColumnDataSource that integrates
with ArviZ data structures and coordinate systems.
Args:
*args: Arguments passed to bokeh.models.ColumnDataSource
**kwargs: Keyword arguments for data source creation
Returns:
bokeh.models.ColumnDataSource: Data source for Bokeh plots
"""
def create_layout(ax, force_layout: bool = False):
"""
Create Bokeh layout from plot axes.
Converts Bokeh plot figures into proper layout structures
for multi-panel displays and dashboard creation.
Args:
ax: Bokeh figure or list of figures
force_layout (bool): Whether to force specific layout structure
Returns:
bokeh.layouts Layout object (row, column, or gridplot)
"""
def show_layout(ax, show: bool = True, force_layout: bool = False):
"""
Display Bokeh layout with proper formatting.
Handles the display of single figures or complex layouts
with appropriate sizing and interactive features.
Args:
ax: Bokeh figure or layout to display
show (bool): Whether to immediately show the plot
force_layout (bool): Whether to force layout structure
Returns:
Displayed Bokeh plot or layout
"""
# Convert data for Bokeh plotting
cds = az.to_cds(idata, var_names=["mu", "tau"])
print(f"Data columns: {cds.column_names}")
# Set up notebook output
az.output_notebook()
# Create interactive plots with Bokeh backend
with az.rc_context({"plot.backend": "bokeh"}):
# Trace plot with interactive features
trace_fig = az.plot_trace(idata, var_names=["mu"])
# Posterior plot with hover tooltips
post_fig = az.plot_posterior(idata, var_names=["tau"])
# Create custom layout
layout = az.create_layout([trace_fig, post_fig])
az.show_layout(layout)
# Save interactive plot to file
az.output_file("interactive_analysis.html")
with az.rc_context({"plot.backend": "bokeh"}):
az.plot_pair(idata, var_names=["mu", "tau"])
# Create custom ColumnDataSource for advanced plotting
import bokeh.plotting as bp
custom_cds = az.ColumnDataSource({
'x': idata.posterior["mu"].values.flatten(),
'y': idata.posterior["tau"].values.flatten()
})
# Use with native Bokeh plotting
p = bp.figure(title="Custom Scatter Plot")
p.circle('x', 'y', source=custom_cds, alpha=0.6)
az.show_layout(p)
# Combine ArviZ and native Bokeh functionality
from bokeh.models import HoverTool
from bokeh.plotting import figure
# Create enhanced interactive plot
def create_enhanced_posterior_plot(idata, var_name="mu"):
# Convert to Bokeh data source
cds = az.to_cds(idata, var_names=[var_name])
# Create figure with tools
p = figure(
title=f"Posterior Distribution: {var_name}",
tools="pan,wheel_zoom,box_zoom,reset,save"
)
# Add hover tool
hover = HoverTool(tooltips=[
("Chain", "@chain"),
("Draw", "@draw"),
("Value", f"@{var_name}")
])
p.add_tools(hover)
# Plot data
p.circle(f'{var_name}_x', f'{var_name}_y', source=cds, alpha=0.6)
return p
# Use enhanced plotting
enhanced_plot = create_enhanced_posterior_plot(idata, "mu")
az.show_layout(enhanced_plot)

ArviZ supports multiple plotting backends for different use cases:
# Set backend for current session
az.rcParams["plot.backend"] = "matplotlib" # Default
az.rcParams["plot.backend"] = "bokeh" # Interactive
az.rcParams["plot.backend"] = "plotly" # Web-based interactive
# Use context manager for temporary backend changes
with az.rc_context({"plot.backend": "bokeh"}):
az.plot_trace(idata) # Uses Bokeh backend

Common plotting parameters available across functions:
- figsize: Figure size (width, height)
- ax: Matplotlib axes to plot on
- backend_kwargs: Backend-specific parameters
- show: Whether to display plot immediately

Install with Tessl CLI
npx tessl i tessl/pypi-arviz