Python interface to Stan, a package for Bayesian inference.

Tools for accessing, analyzing, and transforming MCMC samples into usable
formats. The Fit class provides a dictionary-like interface for working with
posterior samples and includes utilities for data export and analysis.

Access posterior samples through a dictionary-like interface.
class Fit:
    """
    Stores draws from one or more chains.

    A Fit instance works like a Python dictionary. A user-friendly view of
    draws is available via to_frame().

    Attributes:
        stan_outputs: Raw Stan output for each chain
        num_chains: Number of chains
        param_names: Parameter names from model
        constrained_param_names: All constrained parameter names
        dims: Parameter dimensions
        num_warmup: Number of warmup iterations
        num_samples: Number of sampling iterations
        num_thin: Thinning interval
        save_warmup: Whether warmup samples were saved
        sample_and_sampler_param_names: Names of sample and sampler parameters
    """

    def __getitem__(self, param: str):
        """
        Access parameter draws by name.

        Args:
            param: Parameter name

        Returns:
            numpy.ndarray: Parameter draws with shape (num_draws, num_chains)

        Raises:
            KeyError: If parameter name not found
        """

    def __contains__(self, key: str) -> bool:
        """
        Check if parameter exists in the fit.

        Args:
            key: Parameter name

        Returns:
            bool: True if parameter exists
        """

    def __iter__(self):
        """
        Iterate over parameter names.

        Yields:
            str: Parameter names in order
        """

    def __len__(self) -> int:
        """
        Number of parameters in the fit.

        Returns:
            int: Total number of parameters
        """

    # Convert samples to user-friendly formats.
    def to_frame(self):
        """
        Return view of draws as a pandas DataFrame.

        The DataFrame contains all parameters and diagnostic information,
        flattened across all chains with draws as rows and parameters as
        columns.

        Returns:
            pandas.DataFrame: DataFrame with num_draws rows and
                num_flat_params columns

        Raises:
            RuntimeError: If pandas is not installed

        Notes:
            - Requires pandas to be installed
            - Includes both model parameters and sampler diagnostics
            - All chains are combined into a single DataFrame
            - Column names match sample_and_sampler_param_names +
              constrained_param_names
        """

    # Human-readable summary of fit contents.
    def __repr__(self) -> str:
        """
        String representation of the Fit showing parameter summaries.

        Provides a concise overview of all parameters including:
        - Parameter names and dimensions
        - Number of chains and draws
        - Brief statistical summary

        Returns:
            str: Formatted summary of fit contents
        """


import stan
program_code = """
parameters {
real mu;
real<lower=0> sigma;
vector[3] theta;
}
model {
mu ~ normal(0, 1);
sigma ~ exponential(1);
theta ~ normal(0, 1);
}
"""
model = stan.build(program_code)
fit = model.sample(num_chains=4, num_samples=1000)
# Access individual parameters
mu_samples = fit['mu']
sigma_samples = fit['sigma']
theta_samples = fit['theta']
print(f"mu samples shape: {mu_samples.shape}")
print(f"sigma samples shape: {sigma_samples.shape}")
print(f"theta samples shape: {theta_samples.shape}")
# Check parameter existence
print(f"'mu' in fit: {'mu' in fit}")
print(f"'nonexistent' in fit: {'nonexistent' in fit}")
# Iterate over parameters
print("All parameters:")
for param_name in fit:
print(f" {param_name}: {fit[param_name].shape}")
print(f"Total parameters: {len(fit)}")import stan
import stan
import numpy as np
import pandas as pd

# Example: exporting draws with Fit.to_frame() — a simple linear regression.
program_code = """
data {
int<lower=0> N;
vector[N] x;
vector[N] y;
}
parameters {
real alpha;
real beta;
real<lower=0> sigma;
}
model {
alpha ~ normal(0, 10);
beta ~ normal(0, 10);
sigma ~ exponential(1);
y ~ normal(alpha + beta * x, sigma);
}
"""

# Generate synthetic data: y = 2 + 3x + noise.
N = 50
x = np.random.normal(0, 1, N)
y = 2 + 3 * x + np.random.normal(0, 1, N)
data = {'N': N, 'x': x.tolist(), 'y': y.tolist()}

model = stan.build(program_code, data=data)
fit = model.sample(num_chains=4, num_samples=1000)

# Convert to DataFrame
df = fit.to_frame()
print(f"DataFrame shape: {df.shape}")
print(f"DataFrame columns: {list(df.columns)}")

# Basic statistics
print("\nParameter summaries:")
print(df.describe())

# Chain-specific analysis. to_frame() returns flat columns with all chains
# combined, so per-chain summaries come from the (num_draws, num_chains)
# arrays returned by fit[param] rather than from a column MultiIndex.
# (The previous df.groupby(level=1, axis=1) also relied on the groupby
# `axis` keyword, which was removed in pandas 2.x.)
parameters = ['alpha', 'beta', 'sigma']
print("\nMean by chain:")
chain_means = pd.DataFrame({p: np.mean(fit[p], axis=0) for p in parameters})
chain_means.index.name = 'chain'
print(chain_means)

# Plot traces (requires matplotlib)
import matplotlib.pyplot as plt
fig, axes = plt.subplots(3, 1, figsize=(10, 8))
for i, param in enumerate(parameters):
    draws = fit[param]  # shape (num_draws, num_chains)
    for chain in range(4):
        axes[i].plot(draws[:, chain], alpha=0.7, label=f'Chain {chain}')
    axes[i].set_title(f'{param} trace')
    axes[i].legend()
plt.tight_layout()
plt.show()

import stan
import stan
import numpy as np

# Example: inspecting sampler diagnostics stored alongside model parameters.
program_code = """
parameters {
real mu;
real<lower=0> sigma;
}
model {
mu ~ normal(0, 1);
sigma ~ exponential(1);
}
"""

model = stan.build(program_code)
fit = model.sample(num_chains=4, num_samples=1000, num_warmup=1000)

# Access diagnostic parameters
print("Available parameters:")
for name in fit:
    print(f" {name}")

# Check sampler diagnostics (names ending in "__" are sampler quantities).
if 'accept_stat__' in fit:
    accept = fit['accept_stat__']
    print(f"\nAcceptance rate statistics:")
    print(f" Mean: {np.mean(accept):.3f}")
    print(f" Min: {np.min(accept):.3f}")
    print(f" Max: {np.max(accept):.3f}")

if 'treedepth__' in fit:
    depth = fit['treedepth__']
    print(f"\nTree depth statistics:")
    print(f" Mean: {np.mean(depth):.1f}")
    print(f" Max: {np.max(depth)}")

if 'stepsize__' in fit:
    step = fit['stepsize__']
    print(f"\nStep size by chain:")
    for c in range(fit.num_chains):
        # Column c holds the draws for chain c.
        print(f" Chain {c}: {np.mean(step[:, c]):.4f}")

# Print fit summary
print(f"\nFit summary:")
print(fit)

import stan
import stan
import numpy as np
from scipy import stats

# Example: posterior summaries and credible intervals.
program_code = """
parameters {
real mu;
real<lower=0> sigma;
}
model {
mu ~ normal(0, 1);
sigma ~ exponential(1);
}
"""

model = stan.build(program_code)
fit = model.sample(num_chains=4, num_samples=2000)

# Posterior analysis: pool the draws from all chains into one array.
mu_draws = fit['mu'].flatten()
sigma_draws = fit['sigma'].flatten()

print("Posterior summaries:")
print(f"mu: mean={np.mean(mu_draws):.3f}, std={np.std(mu_draws):.3f}")
print(f"sigma: mean={np.mean(sigma_draws):.3f}, std={np.std(sigma_draws):.3f}")

# Credible intervals from the 2.5% and 97.5% posterior quantiles.
bounds = [2.5, 97.5]
mu_ci = np.percentile(mu_draws, bounds)
sigma_ci = np.percentile(sigma_draws, bounds)
print(f"\n95% Credible Intervals:")
print(f"mu: [{mu_ci[0]:.3f}, {mu_ci[1]:.3f}]")
print(f"sigma: [{sigma_ci[0]:.3f}, {sigma_ci[1]:.3f}]")
# Chain convergence (R-hat approximation)
def split_rhat(chains):
    """Simple split-chain R-hat calculation for demonstration.

    Each chain is split in half and the potential scale reduction factor is
    computed over the resulting split chains.

    Args:
        chains: Array of draws with shape (num_draws, num_chains).

    Returns:
        float: R-hat estimate; values near 1 indicate convergence.
    """
    n_draws = chains.shape[0]
    half = n_draws // 2
    # Split each chain in half. When n_draws is odd, drop the middle draw so
    # both halves have equal length (otherwise np.concatenate would fail on
    # mismatched shapes).
    first_half = chains[:half, :]
    second_half = chains[n_draws - half:, :]
    # Treat the halves as independent chains, doubling the chain count.
    all_chains = np.concatenate([first_half, second_half], axis=1)
    # Between-chain (B) and within-chain (W) variance estimates.
    chain_means = np.mean(all_chains, axis=0)
    B = half * np.var(chain_means, ddof=1)
    W = np.mean(np.var(all_chains, axis=0, ddof=1))
    # Pooled posterior-variance estimate, then R-hat.
    var_est = ((half - 1) / half) * W + B / half
    return np.sqrt(var_est / W)
# Report the split-R-hat diagnostic for each model parameter.
print(f"\nR-hat diagnostics:")
for name in ('mu', 'sigma'):
    print(f"{name}: {split_rhat(fit[name]):.3f}")
print("(Values < 1.1 indicate good convergence)")
# Install with Tessl CLI
npx tessl i tessl/pypi-pystan