The Python ensemble sampling toolkit for affine-invariant MCMC
—
The State class in emcee provides a unified interface for handling walker ensemble states during MCMC sampling. It encapsulates walker positions, log probabilities, metadata blobs, and random number generator states, enabling checkpointing, state manipulation, and backward compatibility.
Container for ensemble state information with support for iteration and indexing.
class State:
    """Container for the state of a walker ensemble.

    Encapsulates walker positions, log probabilities, metadata blobs, and
    the random-number-generator state, and supports tuple-style iteration
    and indexing for backward compatibility with older emcee return values.
    """

    # Attributes (set in __init__):
    #   coords       -- walker positions, expected shape [nwalkers, ndim]
    #   log_prob     -- log probabilities, expected shape [nwalkers], or None
    #   blobs        -- metadata blobs, or None
    #   random_state -- random number generator state, or None
    __slots__ = ("coords", "log_prob", "blobs", "random_state")

    def __init__(self, coords, log_prob=None, blobs=None, random_state=None,
                 copy: bool = False):
        """Initialize ensemble state.

        Args:
            coords: Walker positions [nwalkers, ndim], or an existing
                State object whose components are taken over.
            log_prob: Log probabilities [nwalkers] (optional).
            blobs: Metadata blobs (optional).
            random_state: Random number generator state (optional).
            copy: Whether to deep copy input data (default: False).
        """
        # Local import keeps the class importable without extra top-level deps.
        from copy import deepcopy

        dc = deepcopy if copy else (lambda x: x)

        # Promote an existing State: take all four components from it and
        # ignore the remaining keyword arguments.
        if hasattr(coords, "coords"):
            self.coords = dc(coords.coords)
            self.log_prob = dc(coords.log_prob)
            self.blobs = dc(coords.blobs)
            self.random_state = dc(coords.random_state)
            return

        self.coords = dc(coords)
        self.log_prob = dc(log_prob)
        self.blobs = dc(blobs)
        self.random_state = dc(random_state)

    # Methods for accessing and manipulating state information.

    def __len__(self):
        """Get length of state tuple for unpacking.

        Returns:
            int: 3 if no blobs, 4 if blobs present.
        """
        return 3 if self.blobs is None else 4

    def __repr__(self):
        """String representation of state."""
        return "State({0}, log_prob={1}, blobs={2}, random_state={3})".format(
            self.coords, self.log_prob, self.blobs, self.random_state
        )

    def __iter__(self):
        """Iterate over state components for backward compatibility.

        Yields:
            coords, log_prob, random_state[, blobs]
        """
        if self.blobs is None:
            return iter((self.coords, self.log_prob, self.random_state))
        return iter((self.coords, self.log_prob, self.random_state, self.blobs))

    def __getitem__(self, index: int):
        """Access state components by index.

        Args:
            index: Component index (0=coords, 1=log_prob, 2=random_state,
                3=blobs). Negative indices count from the end.

        Returns:
            State component at given index.

        Raises:
            IndexError: If index is out of range for this state.
        """
        if index < 0:
            index += len(self)
        components = tuple(self)
        if not 0 <= index < len(components):
            raise IndexError("state index out of range")
        return components[index]


# Example dependency; guarded so this module stays importable when the
# emcee package is not installed (the examples below require it).
try:
    import emcee
except ImportError:
    pass
import numpy as np

# Create state from coordinates only
coords = np.random.randn(32, 2)
state = emcee.State(coords)
print(f"Coords shape: {state.coords.shape}")
print(f"Log prob: {state.log_prob}")  # None initially
print(f"Blobs: {state.blobs}")  # None initially

# Create state with log probabilities
log_prob = np.random.randn(32)
state = emcee.State(coords, log_prob=log_prob)
print(f"Log prob shape: {state.log_prob.shape}")


def log_prob(theta):
    """Standard-normal log density (up to an additive constant)."""
    return -0.5 * np.sum(theta**2)


# Run sampling
sampler = emcee.EnsembleSampler(32, 2, log_prob)
pos = np.random.randn(32, 2)
final_state = sampler.run_mcmc(pos, 100)
print(f"Final state type: {type(final_state)}")
print(f"Final coords shape: {final_state.coords.shape}")
print(f"Final log_prob shape: {final_state.log_prob.shape}")

# Get last sampled state
last_state = sampler.get_last_sample()
print(f"Same as final state: {np.array_equal(final_state.coords, last_state.coords)}")

# State supports tuple unpacking for backward compatibility
# NOTE: `log_prob` now names the function defined above (it shadowed the
# earlier array), so pass a fresh array of per-walker log probabilities —
# otherwise `lp.shape` below would raise AttributeError.
log_prob_vals = np.random.randn(32)
state = emcee.State(coords, log_prob=log_prob_vals)

# Unpack without blobs
pos, lp, rstate = state
print(f"Unpacked coords shape: {pos.shape}")
print(f"Unpacked log_prob shape: {lp.shape}")


# With blobs
def log_prob_with_blobs(theta):
    """Log density that also returns a per-walker metadata blob."""
    log_p = -0.5 * np.sum(theta**2)
    return log_p, {"energy": np.sum(theta**2)}


sampler_blobs = emcee.EnsembleSampler(32, 2, log_prob_with_blobs)
final_state_blobs = sampler_blobs.run_mcmc(pos, 10)

# Unpack with blobs
pos, lp, rstate, blobs = final_state_blobs
print(f"Blobs type: {type(blobs)}")

state = emcee.State(coords, log_prob=log_prob_vals)
# Access by index
print(f"Index 0 (coords): {state[0].shape}")
print(f"Index 1 (log_prob): {state[1].shape}")
print(f"Index 2 (random_state): {state[2]}")

# Negative indexing
print(f"Index -1 (last element): {state[-1] is state[2]}")

# Length of state
print(f"State length: {len(state)}")  # 3 without blobs, 4 with blobs

# Create state with copy=True for safety
original_coords = np.random.randn(32, 2)
state_copy = emcee.State(original_coords, copy=True)

# Modify original - copied state unchanged
original_coords[0, 0] = 999
print(f"Original modified: {original_coords[0, 0]}")
print(f"Copy unchanged: {state_copy.coords[0, 0]}")

# Create state from another state
state2 = emcee.State(state_copy, copy=True)
print(f"State from state: {np.array_equal(state2.coords, state_copy.coords)}")

# Save state for resuming
sampler = emcee.EnsembleSampler(32, 2, log_prob)
pos = np.random.randn(32, 2)

# Initial sampling
intermediate_state = sampler.run_mcmc(pos, 500)
print(f"Completed {sampler.iteration} steps")

# Resume from intermediate state
final_state = sampler.run_mcmc(intermediate_state, 500)
print(f"Total completed: {sampler.iteration} steps")

# State preserves random state for reproducibility
print(f"Random state preserved: {final_state.random_state is not None}")


def log_prob_detailed(theta):
    """Log density that returns per-walker diagnostic metadata blobs."""
    log_p = -0.5 * np.sum(theta**2)
    # Return detailed metadata
    blobs = {
        "energy": np.sum(theta**2),
        "grad_norm": np.linalg.norm(theta),
        "param_sum": np.sum(theta),
    }
    return log_p, blobs


sampler = emcee.EnsembleSampler(32, 2, log_prob_detailed)
final_state = sampler.run_mcmc(pos, 100)
print(f"State has blobs: {final_state.blobs is not None}")
print(f"Blob keys: {final_state.blobs.dtype.names if final_state.blobs is not None else 'None'}")

# Access specific blob data
if final_state.blobs is not None:
    energies = final_state.blobs["energy"]
    print(f"Final energies: {energies[:5]}")  # First 5 walkers

# Create custom state for specific initialization
nwalkers, ndim = 32, 2

# Initialize walkers in specific pattern: two clusters at -1 and +1
coords = np.zeros((nwalkers, ndim))
coords[:nwalkers // 2] = np.random.normal(loc=-1, scale=0.5, size=(nwalkers // 2, ndim))
coords[nwalkers // 2:] = np.random.normal(loc=1, scale=0.5, size=(nwalkers // 2, ndim))

# Pre-compute log probabilities
log_probs = np.array([log_prob(coord) for coord in coords])

# Create initialized state
init_state = emcee.State(coords, log_prob=log_probs)

# Use in sampler
sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob)
final_state = sampler.run_mcmc(init_state, 1000)
print(f"Started with precomputed log_probs: {init_state.log_prob is not None}")


def inspect_state(state, label="State"):
    """Utility function to inspect state contents."""
    print(f"\n{label} Inspection:")
    print(f" Coords shape: {state.coords.shape}")
    print(f" Coords range: [{np.min(state.coords):.3f}, {np.max(state.coords):.3f}]")
    if state.log_prob is not None:
        print(f" Log prob shape: {state.log_prob.shape}")
        print(f" Log prob range: [{np.min(state.log_prob):.3f}, {np.max(state.log_prob):.3f}]")
    else:
        print(" Log prob: None")
    print(f" Has blobs: {state.blobs is not None}")
    print(f" Has random state: {state.random_state is not None}")
    print(f" State length: {len(state)}")


# Inspect various states
init_state = emcee.State(np.random.randn(32, 2))
inspect_state(init_state, "Initial")

# After sampling
sampler = emcee.EnsembleSampler(32, 2, log_prob)
final_state = sampler.run_mcmc(init_state, 100)
inspect_state(final_state, "Final")

from multiprocessing import Pool
def log_prob_parallel(theta):
    """Picklable log density for use with a multiprocessing pool."""
    return -0.5 * np.sum(theta**2)


# State works seamlessly with parallel processing
with Pool() as pool:
    sampler = emcee.EnsembleSampler(32, 2, log_prob_parallel, pool=pool)
    # Initialize with state
    init_state = emcee.State(np.random.randn(32, 2))
    final_state = sampler.run_mcmc(init_state, 1000)
print(f"Parallel sampling completed: {final_state.coords.shape}")

Install with the Tessl CLI:
npx tessl i tessl/pypi-emcee