Python interface to CmdStan that provides comprehensive access to the Stan compiler and all Bayesian inference algorithms.
—
Quality
Pending
Does it follow best practices?
Impact
Pending
No eval scenarios have been run
Container for Markov Chain Monte Carlo sampling results with comprehensive diagnostics and multiple data access formats. The CmdStanMCMC class provides access to posterior draws, diagnostics, and summary statistics from NUTS-HMC sampling.
Access posterior draws in multiple formats for integration with different analysis workflows.
def draws(self, inc_warmup=False, concat_chains=False, vars=None):
    """
    Get the raw draws as a NumPy array.

    Parameters:
    - inc_warmup (bool): Include warmup draws
    - concat_chains (bool): Concatenate chains into a single array
    - vars (list of str, optional): Specific variables to include
      NOTE(review): upstream CmdStanMCMC.draws() may not accept ``vars``
      (draws_pd/draws_xr do) -- confirm against the installed version.

    Returns:
    np.ndarray: Array of shape (draws, chains, columns), or
    (draws * chains, columns) when concat_chains=True, with the chains
    stacked in chain order. The last axis spans every output column
    (see ``column_names``), so it holds sampler diagnostics such as
    ``divergent__`` and ``energy__`` as well as model parameters.
    """
def draws_pd(self, vars=None, inc_warmup=False):
    """
    Return the draws as a pandas DataFrame.

    Parameters:
    - vars (list of str, optional): Restrict the output to these variables
    - inc_warmup (bool): When True, warmup draws are included as well

    Returns:
    pd.DataFrame: The draws, with parameter names as column labels
    """
def draws_xr(self, vars=None, inc_warmup=False):
    """
    Return the draws as an xarray Dataset.

    Parameters:
    - vars (list of str, optional): Restrict the output to these variables
    - inc_warmup (bool): When True, warmup draws are included as well

    Returns:
    xr.Dataset: The draws, labelled with coordinates and carrying metadata
    """
# Usage Examples:
# Get all draws as NumPy array
# Shape: (1000, 4, 15) for 1000 draws, 4 chains and 15 output columns
# (the columns include sampler diagnostics, not just model parameters)
draws_array = fit.draws()
# Get specific variables as DataFrame
theta_phi_df = fit.draws_pd(vars=["theta", "phi"])
# Get xarray Dataset with metadata
draws_xr = fit.draws_xr()
print(draws_xr.coords)  # Shows the chain and draw coordinates
# Access individual Stan variables by name with automatic array reshaping.
def stan_variable(self, var, inc_warmup=False):
    """
    Get the draws for a single Stan variable.

    Parameters:
    - var (str): Name of a whole Stan variable, e.g. "theta". Element
      names such as "theta[1]" are not variable names; index the returned
      array to select individual elements instead.
    - inc_warmup (bool): Include warmup draws

    Returns:
    np.ndarray: Draws with the chains concatenated along the first axis
    (in chain order), followed by the variable's Stan dimensions, i.e.
    shape (draws * chains, *dims). A scalar variable yields a 1-D array.
    """
def stan_variables(self):
    """
    Collect the draws of every Stan variable into one dictionary.

    Returns:
    dict: Each variable name mapped to its array of draws
    """
# Usage Examples:
# Access individual variables
theta = fit.stan_variable("theta")  # Shape: (draws * chains, *theta's Stan dims)
# stan_variable() expects a whole variable name ("mu[1]" is rejected), so
# select elements by indexing the returned array (Stan indices are 1-based)
mu_1 = fit.stan_variable("mu")[:, 0]  # Draws of Stan's mu[1]
# Get all variables
all_vars = fit.stan_variables()
for name, draws in all_vars.items():
    print(f"{name}: {draws.shape}")
# Access MCMC diagnostic information including divergences, tree depth, and energy statistics.
def method_variables(self):
    """
    Get the sampler's per-draw diagnostic variables.

    Returns:
    dict: Mapping from diagnostic column name (names ending in "__", e.g.
    "divergent__", "treedepth__", "energy__") to an array of per-draw
    values whose leading axes are (draws, chains). Aggregate over the
    draws axis to obtain per-chain figures.
    """
# Usage Examples:
diagnostics = fit.method_variables()
# Check for sampling issues. Each entry holds per-draw values (draws first,
# then chains), so aggregate over axis 0 to get per-chain figures.
divergent = diagnostics.get("divergent__")
if divergent is not None:
    print("Divergences per chain:", divergent.sum(axis=0))
treedepth = diagnostics.get("treedepth__")
if treedepth is not None:
    print("Largest tree depth per chain:", treedepth.max(axis=0))
# Per-chain counts of draws that hit the max treedepth are precomputed
print("Max treedepth hits:", fit.max_treedepths)
energy = diagnostics.get("energy__")
if energy is not None:
    print("Mean energy per chain:", energy.mean(axis=0))
# Generate comprehensive summary statistics including posterior means, quantiles, and convergence diagnostics.
def summary(self, percentiles=None, sig_figs=None):
    """
    Summarize the posterior of every parameter.

    Parameters:
    - percentiles (list of float, optional): Percentiles to report; when
      omitted, 5, 50 and 95 are used
    - sig_figs (int, optional): Number of significant figures in the output

    Returns:
    pd.DataFrame: Per-parameter statistics, including quantiles, R-hat and ESS
    """
# Usage Examples:
# Default summary: 5%, 50% and 95% quantiles
default_summary = fit.summary()
print(default_summary)
# Quantiles of your choosing
custom_percentiles = [2.5, 25, 50, 75, 97.5]
summary_custom = fit.summary(percentiles=custom_percentiles)
# More significant figures in the printed numbers
summary_precise = fit.summary(sig_figs=10)
# Run comprehensive convergence diagnostics using CmdStan's built-in diagnostic tools.
def diagnose(self):
    """
    Check the chains with CmdStan's diagnostic tool.

    Returns:
    str or None: The diagnostic report when issues are found, or None when
    the chains look good
    """
# Usage Example:
# Check for convergence issues
report = fit.diagnose()
if not report:
    print("No issues detected")
else:
    print("Diagnostic issues found:")
    print(report)
# Access metadata and sampling configuration information.
# --- Chain layout ---
fit.chains  # number of chains (int)
fit.chain_ids  # chain identifiers (List[int])
fit.num_draws_warmup  # warmup iterations per chain (int)
fit.num_draws_sampling  # sampling iterations per chain (int)
fit.thin  # thinning interval (int)
# --- Output columns and run metadata ---
fit.column_names  # every output column name (Tuple[str, ...])
fit.metadata  # run configuration and timing (InferenceMetadata)
# --- Adaptation results ---
fit.metric_type  # mass matrix type: "diag_e", "dense_e" or "unit_e" (str or None)
fit.metric  # mass matrix values per chain (np.ndarray or None)
fit.step_size  # final step size per chain (np.ndarray or None)
# --- Diagnostic counts ---
fit.divergences  # divergent transitions per chain (np.ndarray or None)
fit.max_treedepths  # max treedepth hits per chain (np.ndarray or None)
# Usage Examples:
print(f"Ran {fit.chains} chains with {fit.num_draws_sampling} samples each")
print(f"Final step sizes: {fit.step_size}")
has_divergences = fit.divergences is not None and fit.divergences.sum() > 0
if has_divergences:
    print(f"Warning: {fit.divergences.sum()} total divergent transitions")
# Save and manage CSV output files for reproducibility and external analysis.
def save_csvfiles(self, dir=None):
    """
    Move the CSV output files to a directory.

    The files are moved, not copied. Files that were written to the
    temporary session directory also get their filenames cleaned up.

    Parameters:
    - dir (str or PathLike, optional): Target directory. If None, the
      current working directory is used.

    Returns:
    None
    """
# Usage Example:
# Save to specific directory
fit.save_csvfiles(dir="./mcmc_results")
# Alternatively, save to the current working directory (the default when
# dir is None); no timestamped directory is created
fit.save_csvfiles()
# Basic convergence check
summary = fit.summary()
print("R-hat range:", summary["R_hat"].min(), "to", summary["R_hat"].max())
# Detailed diagnostics
diagnostics = fit.diagnose()
if diagnostics:
    print("Sampling issues detected")
# Extract key parameters for analysis
theta = fit.stan_variable("theta")
# stan_variable() stacks every chain along axis 0 (shape: draws * chains,
# *dims), so reduce over that axis only; axis=(0, 1) would also collapse
# theta's own first dimension (or fail for a scalar theta).
posterior_mean = theta.mean(axis=0)
posterior_std = theta.std(axis=0)
print(f"Posterior mean: {posterior_mean}")
print(f"Posterior std: {posterior_std}")
# For Stan model with parameters:
# parameters {
#   matrix[N, K] beta;
#   vector[J] alpha;
#   real<lower=0> sigma;
# }
# Access full parameter arrays. stan_variable() concatenates the chains along
# the first axis, so with 4 chains x 1000 draws the leading dimension is 4000.
beta = fit.stan_variable("beta")  # Shape: (4000, N, K)
alpha = fit.stan_variable("alpha")  # Shape: (4000, J)
sigma = fit.stan_variable("sigma")  # Shape: (4000,)
# Work with specific elements: stan_variable() takes whole variable names
# only, so index the returned array (Stan is 1-based, NumPy is 0-based)
beta_1_2 = beta[:, 0, 1]  # Draws of Stan's beta[1,2]
alpha_mean = alpha.mean(axis=0)  # Posterior means for each element
import arviz as az
import matplotlib.pyplot as plt
# Convert to ArviZ InferenceData for advanced diagnostics
inference_data = az.from_cmdstanpy(fit)
# Diagnostic plots
az.plot_trace(inference_data, var_names=["theta", "sigma"])
plt.show()
# Posterior density plots: plot_posterior shows the marginal posteriors of
# the parameters; it is not a posterior predictive check
az.plot_posterior(inference_data, var_names=["theta"])
plt.show()
# Effective sample size and R-hat
print(az.summary(inference_data, round_to=3))
# For large models, access specific variables to save memory
important_params = ["theta", "sigma"]
subset_draws = fit.draws_pd(vars=important_params)
# Free memory by dropping references you hold to large arrays once you are
# done with them (e.g. `del subset_draws`), then collect. Do not delete
# private attributes of `fit`: they are implementation details, and accessor
# calls such as draws_xr() below still rely on the fit's internal state.
import gc
gc.collect()
# Use xarray for efficient slicing of large datasets
draws_xr = fit.draws_xr()
# Select by position: the "chain" coordinate holds chain IDs (see
# fit.chain_ids), so label-based sel(chain=0) need not match the first chain
chain_1_only = draws_xr.isel(chain=0)  # Select first chain
recent_draws = draws_xr.isel(draw=slice(-100, None))  # Last 100 draws
# Check sampling efficiency
summary = fit.summary()
# The effective-sample-size column is "N_Eff" in older CmdStan/CmdStanPy
# releases and "ESS_bulk" in newer ones; use whichever is present.
ess_col = "ESS_bulk" if "ESS_bulk" in summary.columns else "N_Eff"
low_ess_params = summary[summary[ess_col] < 100].index.tolist()
high_rhat_params = summary[summary["R_hat"] > 1.1].index.tolist()
if low_ess_params:
    print(f"Low ESS parameters: {low_ess_params}")
if high_rhat_params:
    print(f"High R-hat parameters: {high_rhat_params}")
# Energy diagnostics
method_vars = fit.method_variables()
energy = method_vars.get("energy__")
if energy is not None:
    # Per-draw values: axis 0 is draws, axis 1 is chains
    print(f"Energy statistics available for {energy.shape[1]} chains")
# Install with Tessl CLI
npx tessl i tessl/pypi-cmdstanpy