# State Management

The State class in emcee provides a unified interface for handling the state of the walker ensemble during MCMC sampling. It encapsulates walker positions, log probabilities, metadata blobs, and the random number generator state, enabling checkpointing, state manipulation, and backward compatibility with the tuple-based interface of earlier emcee versions.

## Capabilities

### State Class

Container for ensemble state information with support for iteration and indexing.

```python { .api }
class State:
    def __init__(self, coords, log_prob=None, blobs=None, random_state=None,
                 copy: bool = False):
        """
        Initialize ensemble state.

        Args:
            coords: Walker positions [nwalkers, ndim] or existing State object
            log_prob: Log probabilities [nwalkers] (optional)
            blobs: Metadata blobs (optional)
            random_state: Random number generator state (optional)
            copy: Whether to deep copy input data (default: False)
        """

    coords: np.ndarray      # Walker positions [nwalkers, ndim]
    log_prob: np.ndarray    # Log probabilities [nwalkers]
    blobs: any              # Metadata blobs
    random_state: any       # Random number generator state
```

### State Properties and Methods

Methods for accessing and manipulating state information.

```python { .api }
def __len__(self):
    """
    Get length of state tuple for unpacking.

    Returns:
        int: 3 if no blobs, 4 if blobs present
    """

def __repr__(self):
    """String representation of state."""

def __iter__(self):
    """
    Iterate over state components for backward compatibility.

    Yields:
        coords, log_prob, random_state[, blobs]
    """

def __getitem__(self, index: int):
    """
    Access state components by index.

    Args:
        index: Component index (0=coords, 1=log_prob, 2=random_state, 3=blobs)

    Returns:
        State component at given index
    """
```

## Usage Examples

### Creating State Objects

```python
import emcee
import numpy as np

# Create a state from coordinates only
coords = np.random.randn(32, 2)
state = emcee.State(coords)

print(f"Coords shape: {state.coords.shape}")
print(f"Log prob: {state.log_prob}")  # None initially
print(f"Blobs: {state.blobs}")  # None initially

# Create a state with log probabilities
# (named log_probs so it does not shadow the log_prob function defined below)
log_probs = np.random.randn(32)
state = emcee.State(coords, log_prob=log_probs)
print(f"Log prob shape: {state.log_prob.shape}")
```

### State from Sampling Results

```python
def log_prob(theta):
    return -0.5 * np.sum(theta**2)

# Run sampling
sampler = emcee.EnsembleSampler(32, 2, log_prob)
pos = np.random.randn(32, 2)
final_state = sampler.run_mcmc(pos, 100)

print(f"Final state type: {type(final_state)}")
print(f"Final coords shape: {final_state.coords.shape}")
print(f"Final log_prob shape: {final_state.log_prob.shape}")

# Get last sampled state
last_state = sampler.get_last_sample()
print(f"Same as final state: {np.array_equal(final_state.coords, last_state.coords)}")
```

### State Unpacking (Backward Compatibility)

```python
# State supports tuple unpacking for backward compatibility
state = emcee.State(coords, log_prob=log_probs)

# Unpack without blobs
pos, lp, rstate = state
print(f"Unpacked coords shape: {pos.shape}")
print(f"Unpacked log_prob shape: {lp.shape}")

# With blobs
def log_prob_with_blobs(theta):
    log_p = -0.5 * np.sum(theta**2)
    return log_p, {"energy": np.sum(theta**2)}

sampler_blobs = emcee.EnsembleSampler(32, 2, log_prob_with_blobs)
final_state_blobs = sampler_blobs.run_mcmc(pos, 10)

# Unpack with blobs
pos, lp, rstate, blobs = final_state_blobs
print(f"Blobs type: {type(blobs)}")
```

### State Indexing

```python
state = emcee.State(coords, log_prob=log_probs)

# Access by index
print(f"Index 0 (coords): {state[0].shape}")
print(f"Index 1 (log_prob): {state[1].shape}")
print(f"Index 2 (random_state): {state[2]}")

# Negative indexing
print(f"Index -1 (last element): {state[-1] is state[2]}")

# Length of state
print(f"State length: {len(state)}")  # 3 without blobs, 4 with blobs
```

### Copying State

```python
# Create state with copy=True for safety
original_coords = np.random.randn(32, 2)
state_copy = emcee.State(original_coords, copy=True)

# Modify original - copied state unchanged
original_coords[0, 0] = 999
print(f"Original modified: {original_coords[0, 0]}")
print(f"Copy unchanged: {state_copy.coords[0, 0]}")

# Create state from another state
state2 = emcee.State(state_copy, copy=True)
print(f"State from state: {np.array_equal(state2.coords, state_copy.coords)}")
```

### Resuming Sampling from State

```python
# Save state for resuming
sampler = emcee.EnsembleSampler(32, 2, log_prob)
pos = np.random.randn(32, 2)

# Initial sampling
intermediate_state = sampler.run_mcmc(pos, 500)
print(f"Completed {sampler.iteration} steps")

# Resume from intermediate state
final_state = sampler.run_mcmc(intermediate_state, 500)
print(f"Total completed: {sampler.iteration} steps")

# State preserves random state for reproducibility
print(f"Random state preserved: {final_state.random_state is not None}")
```

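A related shortcut, sketched here as the emcee 3.x behaviour: `run_mcmc` also accepts `None` as the initial state, in which case sampling continues from the sampler's stored last sample, so you do not need to thread the intermediate state through your own code.

```python
# Continue the same sampler without passing a state explicitly.
# run_mcmc(None, ...) resumes from the stored last sample; this assumes
# the sampler above has already completed at least one run.
final_state = sampler.run_mcmc(None, 500)
print(f"Total completed: {sampler.iteration} steps")
```
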
### State with Blobs

```python
def log_prob_detailed(theta):
    log_p = -0.5 * np.sum(theta**2)

    # Return detailed metadata
    blobs = {
        "energy": np.sum(theta**2),
        "grad_norm": np.linalg.norm(theta),
        "param_sum": np.sum(theta)
    }
    return log_p, blobs

sampler = emcee.EnsembleSampler(32, 2, log_prob_detailed)
final_state = sampler.run_mcmc(pos, 100)

print(f"State has blobs: {final_state.blobs is not None}")

# With plain dict blobs (and no blobs_dtype), state.blobs is an object array
# holding one dict per walker, so access it per walker rather than by field name
if final_state.blobs is not None:
    print(f"Blob keys: {list(final_state.blobs[0].keys())}")
    energies = np.array([b["energy"] for b in final_state.blobs])
    print(f"Final energies: {energies[:5]}")  # First 5 walkers
```

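If you prefer named fields over raw dicts, `EnsembleSampler` accepts a `blobs_dtype` argument. The sketch below assumes that dtype is also applied to the per-step `State` blobs, matching the pattern described in emcee's blobs documentation; the field names are illustrative.

```python
# Structured blobs: return blob values as extra return values and declare
# their dtype up front, so named fields are available on state.blobs.
def log_prob_fields(theta):
    log_p = -0.5 * np.sum(theta**2)
    return log_p, np.sum(theta**2), np.linalg.norm(theta)

dtype = [("energy", float), ("grad_norm", float)]
sampler = emcee.EnsembleSampler(32, 2, log_prob_fields, blobs_dtype=dtype)
final_state = sampler.run_mcmc(np.random.randn(32, 2), 100)

print(f"Blob fields: {final_state.blobs.dtype.names}")
print(f"Energies: {final_state.blobs['energy'][:5]}")
```
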
### Custom State Manipulation

```python
# Create custom state for specific initialization
nwalkers, ndim = 32, 2

# Initialize walkers in specific pattern
coords = np.zeros((nwalkers, ndim))
coords[:nwalkers//2] = np.random.normal(loc=-1, scale=0.5, size=(nwalkers//2, ndim))
coords[nwalkers//2:] = np.random.normal(loc=1, scale=0.5, size=(nwalkers//2, ndim))

# Pre-compute log probabilities
log_probs = np.array([log_prob(coord) for coord in coords])

# Create initialized state
init_state = emcee.State(coords, log_prob=log_probs)

# Use in sampler
sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob)
final_state = sampler.run_mcmc(init_state, 1000)

print(f"Started with precomputed log_probs: {init_state.log_prob is not None}")
```

### State Inspection and Diagnostics

```python
def inspect_state(state, label="State"):
    """Utility function to inspect state contents."""

    print(f"\n{label} Inspection:")
    print(f"  Coords shape: {state.coords.shape}")
    print(f"  Coords range: [{np.min(state.coords):.3f}, {np.max(state.coords):.3f}]")

    if state.log_prob is not None:
        print(f"  Log prob shape: {state.log_prob.shape}")
        print(f"  Log prob range: [{np.min(state.log_prob):.3f}, {np.max(state.log_prob):.3f}]")
    else:
        print("  Log prob: None")

    print(f"  Has blobs: {state.blobs is not None}")
    print(f"  Has random state: {state.random_state is not None}")
    print(f"  State length: {len(state)}")

# Inspect various states
init_state = emcee.State(np.random.randn(32, 2))
inspect_state(init_state, "Initial")

# After sampling
sampler = emcee.EnsembleSampler(32, 2, log_prob)
final_state = sampler.run_mcmc(init_state, 100)
inspect_state(final_state, "Final")
```

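The same utility works for monitoring a run in progress. In emcee 3.x, `EnsembleSampler.sample()` is a generator that yields the current `State` at each stored iteration; the sketch below inspects the ensemble every 100 steps (the interval is illustrative).

```python
# Monitor the evolving state with the sample() generator
sampler = emcee.EnsembleSampler(32, 2, log_prob)
init_state = emcee.State(np.random.randn(32, 2))

for i, state in enumerate(sampler.sample(init_state, iterations=200)):
    if (i + 1) % 100 == 0:
        inspect_state(state, label=f"Step {i + 1}")
```
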
### Parallel Processing with States

```python
from multiprocessing import Pool

def log_prob_parallel(theta):
    return -0.5 * np.sum(theta**2)

# State works seamlessly with parallel processing.
# On platforms that spawn worker processes (e.g. Windows, macOS), run this
# inside an `if __name__ == "__main__":` guard.
with Pool() as pool:
    sampler = emcee.EnsembleSampler(32, 2, log_prob_parallel, pool=pool)

    # Initialize with state
    init_state = emcee.State(np.random.randn(32, 2))
    final_state = sampler.run_mcmc(init_state, 1000)

print(f"Parallel sampling completed: {final_state.coords.shape}")
```

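### Checkpointing a State

States also serve as lightweight checkpoints. The sketch below simply pickles the `State` object, which holds NumPy arrays, optional blobs, and the RNG state, all of which are picklable; the `checkpoint.pkl` file name is illustrative. For long production runs, emcee's HDF5 backend (`emcee.backends.HDFBackend`) is generally the more robust option.

```python
import pickle

# Save the current ensemble state to disk
sampler = emcee.EnsembleSampler(32, 2, log_prob)
state = sampler.run_mcmc(np.random.randn(32, 2), 200)

with open("checkpoint.pkl", "wb") as f:
    pickle.dump(state, f)

# Later (possibly in a new process): restore the state and keep sampling
with open("checkpoint.pkl", "rb") as f:
    restored = pickle.load(f)

sampler = emcee.EnsembleSampler(32, 2, log_prob)
state = sampler.run_mcmc(restored, 200)
```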