CtrlK
BlogDocsLog inGet started
Tessl Logo

tessl/pypi-emcee

The Python ensemble sampling toolkit for affine-invariant MCMC

Pending
Overview
Eval results
Files

docs/state.md

State Management

The `State` class in emcee provides a unified interface for handling the state of a walker ensemble during MCMC sampling. It bundles walker positions, log probabilities, metadata blobs, and the random number generator state into a single object. This supports checkpointing and resuming runs, direct manipulation of walker states, and backward-compatible tuple unpacking.

Capabilities

State Class

Container for ensemble state information with support for iteration and indexing.

class State:
    """Snapshot of a walker ensemble at one point in an MCMC run.

    Bundles walker positions with their log probabilities, optional metadata
    blobs and the random number generator state, so the snapshot can be fed
    back into a sampler, copied, indexed, or tuple-unpacked.
    """

    def __init__(self, coords, log_prob=None, blobs=None, random_state=None, 
                 copy: bool = False):
        """
        Initialize ensemble state.
        
        Args:
            coords: Walker positions [nwalkers, ndim], or an existing State
                object to build this one from
            log_prob: Log probabilities [nwalkers] (optional; None means
                not yet computed)
            blobs: Metadata blobs (optional; None means no blobs)
            random_state: Random number generator state (optional)
            copy: Whether to deep copy input data (default: False); use True
                when the caller may later mutate the arrays passed in
        """
    
    coords: np.ndarray  # Walker positions [nwalkers, ndim]
    log_prob: np.ndarray  # Log probabilities [nwalkers]; None until computed
    blobs: object  # Metadata blobs; None when no blobs are tracked
    random_state: object  # Random number generator state; may be None

State Properties and Methods

Special methods that let a State be measured with `len()`, printed, iterated (tuple-unpacked), and indexed like a tuple.

def __len__(self):
    """
    Return the number of components produced when the state is unpacked.
    
    Returns:
        int: 3 when ``blobs`` is None (coords, log_prob, random_state);
        4 when blobs are present (blobs is appended last).
    """

def __repr__(self):
    """Return a string representation of the state for printing and debugging."""

def __iter__(self):
    """
    Iterate over the state's components so it can be tuple-unpacked.
    
    This keeps older code such as ``pos, lp, rstate = state`` working.
    Note that the order differs from the constructor's parameter order:
    ``random_state`` comes before ``blobs``.
    
    Yields:
        coords, log_prob, random_state[, blobs] -- blobs only when present
    """

def __getitem__(self, index: int):
    """
    Access state components by position, in the same order as iteration.
    
    Args:
        index: Component index (0=coords, 1=log_prob, 2=random_state,
            3=blobs). Negative indices count back from ``len(state)``, so
            -1 is random_state when no blobs are present.
            NOTE(review): index 3 is presumably only valid when blobs are
            present (``len(state) == 4``) -- confirm the error raised otherwise.
        
    Returns:
        State component at given index
    """

Usage Examples

Creating State Objects

import emcee
import numpy as np

# Minimal construction: only the walker positions are required, so the
# log probabilities and blobs start out unset.
coords = np.random.randn(32, 2)
state = emcee.State(coords=coords)
print(f"Coords shape: {state.coords.shape}")
print(f"Log prob: {state.log_prob}")  # expected: None
print(f"Blobs: {state.blobs}")  # expected: None

# Attach one log probability per walker through the log_prob keyword.
log_prob = np.random.randn(32)
state = emcee.State(coords=coords, log_prob=log_prob)
print(f"Log prob shape: {state.log_prob.shape}")

State from Sampling Results

def log_prob(theta):
    """Log density of a standard normal, up to an additive constant."""
    return -0.5 * (theta**2).sum()

# Sample for 100 steps; run_mcmc hands back the final ensemble State.
sampler = emcee.EnsembleSampler(nwalkers=32, ndim=2, log_prob_fn=log_prob)
pos = np.random.randn(32, 2)
final_state = sampler.run_mcmc(pos, nsteps=100)

print(f"Final state type: {type(final_state)}")
print(f"Final coords shape: {final_state.coords.shape}")
print(f"Final log_prob shape: {final_state.log_prob.shape}")

# The sampler also keeps its most recent State around.
last_state = sampler.get_last_sample()
same_coords = np.array_equal(final_state.coords, last_state.coords)
print(f"Same as final state: {same_coords}")

State Unpacking (Backward Compatibility)

# State supports tuple unpacking for backward compatibility.
# NOTE: at this point `log_prob` names the log-density *function* defined
# above, so pass an explicit array of per-walker log probabilities instead.
state = emcee.State(coords, log_prob=np.random.randn(32))

# Unpack without blobs: (coords, log_prob, random_state)
pos, lp, rstate = state
print(f"Unpacked coords shape: {pos.shape}")
print(f"Unpacked log_prob shape: {lp.shape}")

# With blobs: any values returned after log p are recorded as blobs
def log_prob_with_blobs(theta):
    log_p = -0.5 * np.sum(theta**2)
    return log_p, {"energy": np.sum(theta**2)}

sampler_blobs = emcee.EnsembleSampler(32, 2, log_prob_with_blobs)
final_state_blobs = sampler_blobs.run_mcmc(pos, 10)

# Unpack with blobs: blobs come last, after random_state
pos, lp, rstate, blobs = final_state_blobs
print(f"Blobs type: {type(blobs)}")

State Indexing

# `log_prob` names the log-density function defined earlier, so supply an
# explicit array of per-walker log probabilities here.
state = emcee.State(coords, log_prob=np.random.randn(32))

# Access by index
print(f"Index 0 (coords): {state[0].shape}")
print(f"Index 1 (log_prob): {state[1].shape}")
print(f"Index 2 (random_state): {state[2]}")

# Negative indexing: -1 is the last component (random_state when there are no blobs)
print(f"Index -1 (last element): {state[-1] is state[2]}")

# Length of state
print(f"State length: {len(state)}")  # 3 without blobs, 4 with blobs

Copying State

# copy=True deep-copies the inputs, so later edits to the source array do
# not leak into the State.
original_coords = np.random.randn(32, 2)
state_copy = emcee.State(original_coords, copy=True)

original_coords[0, 0] = 999  # mutate the source array after construction
print(f"Original modified: {original_coords[0, 0]}")
print(f"Copy unchanged: {state_copy.coords[0, 0]}")

# A State can also be built from another State.
state2 = emcee.State(state_copy, copy=True)
coords_match = np.array_equal(state2.coords, state_copy.coords)
print(f"State from state: {coords_match}")

Resuming Sampling from State

# Resuming: run 500 steps, then hand the returned State back to run_mcmc to
# continue the same chain for another 500.
sampler = emcee.EnsembleSampler(nwalkers=32, ndim=2, log_prob_fn=log_prob)
pos = np.random.randn(32, 2)

intermediate_state = sampler.run_mcmc(pos, nsteps=500)
print(f"Completed {sampler.iteration} steps")

final_state = sampler.run_mcmc(intermediate_state, nsteps=500)
print(f"Total completed: {sampler.iteration} steps")

# The returned State also carries the random number generator state.
print(f"Random state preserved: {final_state.random_state is not None}")

State with Blobs

# Named blobs: return the extra values as scalars after log p and declare
# their names/types with `blobs_dtype`. The sampler then stores them in a
# structured array that can be indexed by field name. (emcee does not split a
# returned dict into named fields; it would be stored as an opaque object.)
blob_dtype = [("energy", float), ("grad_norm", float), ("param_sum", float)]

def log_prob_detailed(theta):
    energy = np.sum(theta**2)
    log_p = -0.5 * energy
    # Extra return values become the blobs, in the same order as blob_dtype
    return log_p, energy, np.linalg.norm(theta), np.sum(theta)

sampler = emcee.EnsembleSampler(32, 2, log_prob_detailed, blobs_dtype=blob_dtype)
final_state = sampler.run_mcmc(pos, 100)

print(f"State has blobs: {final_state.blobs is not None}")
print(f"Blob keys: {final_state.blobs.dtype.names if final_state.blobs is not None else 'None'}")

# Access specific blob data by field name
if final_state.blobs is not None:
    energies = final_state.blobs["energy"]
    print(f"Final energies: {energies[:5]}")  # First 5 walkers

Custom State Manipulation

# Hand-built starting ensemble: half the walkers scattered around -1 and half
# around +1, with their log probabilities computed up front.
nwalkers, ndim = 32, 2
coords = np.concatenate([
    np.random.normal(loc=-1, scale=0.5, size=(nwalkers // 2, ndim)),
    np.random.normal(loc=1, scale=0.5, size=(nwalkers // 2, ndim)),
])

log_probs = np.array(list(map(log_prob, coords)))

init_state = emcee.State(coords, log_prob=log_probs)

# The sampler starts from init_state, reusing its precomputed log_probs.
sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob)
final_state = sampler.run_mcmc(init_state, nsteps=1000)

print(f"Started with precomputed log_probs: {init_state.log_prob is not None}")

State Inspection and Diagnostics

def inspect_state(state, label="State"):
    """Print a short diagnostic summary of an ensemble State.

    Args:
        state: Object exposing ``coords``, ``log_prob``, ``blobs`` and
            ``random_state`` attributes and supporting ``len()``.
        label: Heading printed before the summary (default ``"State"``).
    """
    def _span(values):
        # "[min, max]" of an array, each to 3 decimals
        return f"[{np.min(values):.3f}, {np.max(values):.3f}]"

    print(f"\n{label} Inspection:")
    print(f"  Coords shape: {state.coords.shape}")
    print(f"  Coords range: {_span(state.coords)}")

    lp = state.log_prob
    if lp is None:
        print("  Log prob: None")
    else:
        print(f"  Log prob shape: {lp.shape}")
        print(f"  Log prob range: {_span(lp)}")

    print(f"  Has blobs: {state.blobs is not None}")
    print(f"  Has random state: {state.random_state is not None}")
    print(f"  State length: {len(state)}")

# Compare a freshly built State (log_prob not yet computed) with the one
# returned after 100 sampling steps.
init_state = emcee.State(np.random.randn(32, 2))
inspect_state(init_state, label="Initial")

sampler = emcee.EnsembleSampler(nwalkers=32, ndim=2, log_prob_fn=log_prob)
final_state = sampler.run_mcmc(init_state, nsteps=100)
inspect_state(final_state, label="Final")

Parallel Processing with States

from multiprocessing import Pool

def log_prob_parallel(theta):
    """Standard-normal log density (up to a constant).

    Defined at module level so worker processes can pickle and call it.
    """
    return -0.5 * np.sum(theta**2)

# State works with parallel processing. The __main__ guard is required for
# multiprocessing when workers are started with "spawn" (the default on
# Windows and macOS): each worker re-imports the main module, and creating
# the Pool unguarded at import time fails during worker bootstrapping.
if __name__ == "__main__":
    with Pool() as pool:
        sampler = emcee.EnsembleSampler(32, 2, log_prob_parallel, pool=pool)

        # Initialize with state
        init_state = emcee.State(np.random.randn(32, 2))
        final_state = sampler.run_mcmc(init_state, 1000)

        print(f"Parallel sampling completed: {final_state.coords.shape}")

Install with Tessl CLI

npx tessl i tessl/pypi-emcee

docs

autocorr.md

backends.md

ensemble-sampling.md

index.md

moves.md

state.md

tile.json