or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

autocorr.md, backends.md, ensemble-sampling.md, index.md, moves.md, state.md

docs/state.md

0

# State Management

1

2

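The State class in emcee provides a unified interface for handling walker ensemble states during MCMC sampling. It encapsulates walker positions, log probabilities, metadata blobs, and the random number generator state. This makes it possible to checkpoint and resume a run, to build or modify a starting state by hand, and to keep older code working, because a State can still be unpacked and indexed like the tuple `(coords, log_prob, random_state[, blobs])`.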

3

4

## Capabilities

5

6

### State Class

7

8

Container for ensemble state information with support for iteration and indexing.

9

10

```python { .api }

11

class State:

12

def __init__(self, coords, log_prob=None, blobs=None, random_state=None,

13

copy: bool = False):

14

"""

15

Initialize ensemble state.

16

17

Args:

18

coords: Walker positions [nwalkers, ndim] or existing State object

19

log_prob: Log probabilities [nwalkers] (optional)

20

blobs: Metadata blobs (optional)

21

random_state: Random number generator state (optional)

22

copy: Whether to deep copy input data (default: False)

23

"""

24

25

coords: np.ndarray # Walker positions [nwalkers, ndim]

26

log_prob: np.ndarray # Log probabilities [nwalkers]

27

blobs: any # Metadata blobs

28

random_state: any # Random number generator state

29

```

30

31

### State Properties and Methods

32

33

Special methods that let a State be measured with `len()`, printed, iterated over, and indexed like a tuple.

34

35

```python { .api }

36

def __len__(self):

37

"""

38

Get length of state tuple for unpacking.

39

40

Returns:

41

int: 3 if no blobs, 4 if blobs present

42

"""

43

44

def __repr__(self):

45

"""String representation of state."""

46

47

def __iter__(self):

48

"""

49

Iterate over state components for backward compatibility.

50

51

Yields:

52

coords, log_prob, random_state[, blobs]

53

"""

54

55

def __getitem__(self, index: int):

56

"""

57

Access state components by index.

58

59

Args:

60

index: Component index (0=coords, 1=log_prob, 2=random_state, 3=blobs)

61

62

Returns:

63

State component at given index

64

"""

65

```

66

67

## Usage Examples

68

69

### Creating State Objects

70

71

```python

72

import emcee

73

import numpy as np

74

75

# Create state from coordinates only

76

coords = np.random.randn(32, 2)

77

state = emcee.State(coords)

78

79

print(f"Coords shape: {state.coords.shape}")

80

print(f"Log prob: {state.log_prob}") # None initially

81

print(f"Blobs: {state.blobs}") # None initially

82

83

# Create state with log probabilities

84

log_prob = np.random.randn(32)

85

state = emcee.State(coords, log_prob=log_prob)

86

print(f"Log prob shape: {state.log_prob.shape}")

87

```

88

89

### State from Sampling Results

90

91

```python

92

def log_prob(theta):

93

return -0.5 * np.sum(theta**2)

94

95

# Run sampling

96

sampler = emcee.EnsembleSampler(32, 2, log_prob)

97

pos = np.random.randn(32, 2)

98

final_state = sampler.run_mcmc(pos, 100)

99

100

print(f"Final state type: {type(final_state)}")

101

print(f"Final coords shape: {final_state.coords.shape}")

102

print(f"Final log_prob shape: {final_state.log_prob.shape}")

103

104

# Get last sampled state

105

last_state = sampler.get_last_sample()

106

print(f"Same as final state: {np.array_equal(final_state.coords, last_state.coords)}")

107

```

108

109

### State Unpacking (Backward Compatibility)

110

111

```python

112

# State supports tuple unpacking for backward compatibility

113

state = emcee.State(coords, log_prob=log_prob)

114

115

# Unpack without blobs

116

pos, lp, rstate = state

117

print(f"Unpacked coords shape: {pos.shape}")

118

print(f"Unpacked log_prob shape: {lp.shape}")

119

120

# With blobs

121

def log_prob_with_blobs(theta):

122

log_p = -0.5 * np.sum(theta**2)

123

return log_p, {"energy": np.sum(theta**2)}

124

125

sampler_blobs = emcee.EnsembleSampler(32, 2, log_prob_with_blobs)

126

final_state_blobs = sampler_blobs.run_mcmc(pos, 10)

127

128

# Unpack with blobs

129

pos, lp, rstate, blobs = final_state_blobs

130

print(f"Blobs type: {type(blobs)}")

131

```

132

133

### State Indexing

134

135

```python

136

state = emcee.State(coords, log_prob=log_prob)

137

138

# Access by index

139

print(f"Index 0 (coords): {state[0].shape}")

140

print(f"Index 1 (log_prob): {state[1].shape}")

141

print(f"Index 2 (random_state): {state[2]}")

142

143

# Negative indexing

144

print(f"Index -1 (last element): {state[-1] is state[2]}")

145

146

# Length of state

147

print(f"State length: {len(state)}") # 3 without blobs, 4 with blobs

148

```

149

150

### Copying State

151

152

```python

153

# Create state with copy=True for safety

154

original_coords = np.random.randn(32, 2)

155

state_copy = emcee.State(original_coords, copy=True)

156

157

# Modify original - copied state unchanged

158

original_coords[0, 0] = 999

159

print(f"Original modified: {original_coords[0, 0]}")

160

print(f"Copy unchanged: {state_copy.coords[0, 0]}")

161

162

# Create state from another state

163

state2 = emcee.State(state_copy, copy=True)

164

print(f"State from state: {np.array_equal(state2.coords, state_copy.coords)}")

165

```

166

167

### Resuming Sampling from State

168

169

```python

170

# Save state for resuming

171

sampler = emcee.EnsembleSampler(32, 2, log_prob)

172

pos = np.random.randn(32, 2)

173

174

# Initial sampling

175

intermediate_state = sampler.run_mcmc(pos, 500)

176

print(f"Completed {sampler.iteration} steps")

177

178

# Resume from intermediate state

179

final_state = sampler.run_mcmc(intermediate_state, 500)

180

print(f"Total completed: {sampler.iteration} steps")

181

182

# State preserves random state for reproducibility

183

print(f"Random state preserved: {final_state.random_state is not None}")

184

```

185

186

### State with Blobs

187

188

```python

189

def log_prob_detailed(theta):

190

log_p = -0.5 * np.sum(theta**2)

191

192

# Return detailed metadata

193

blobs = {

194

"energy": np.sum(theta**2),

195

"grad_norm": np.linalg.norm(theta),

196

"param_sum": np.sum(theta)

197

}

198

return log_p, blobs

199

200

sampler = emcee.EnsembleSampler(32, 2, log_prob_detailed)

201

final_state = sampler.run_mcmc(pos, 100)

202

203

print(f"State has blobs: {final_state.blobs is not None}")

204

print(f"Blob keys: {final_state.blobs.dtype.names if final_state.blobs is not None else 'None'}")

205

206

# Access specific blob data

207

if final_state.blobs is not None:

208

energies = final_state.blobs["energy"]

209

print(f"Final energies: {energies[:5]}") # First 5 walkers

210

```

211

212

### Custom State Manipulation

213

214

```python

215

# Create custom state for specific initialization

216

nwalkers, ndim = 32, 2

217

218

# Initialize walkers in specific pattern

219

coords = np.zeros((nwalkers, ndim))

220

coords[:nwalkers//2] = np.random.normal(loc=-1, scale=0.5, size=(nwalkers//2, ndim))

221

coords[nwalkers//2:] = np.random.normal(loc=1, scale=0.5, size=(nwalkers//2, ndim))

222

223

# Pre-compute log probabilities

224

log_probs = np.array([log_prob(coord) for coord in coords])

225

226

# Create initialized state

227

init_state = emcee.State(coords, log_prob=log_probs)

228

229

# Use in sampler

230

sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob)

231

final_state = sampler.run_mcmc(init_state, 1000)

232

233

print(f"Started with precomputed log_probs: {init_state.log_prob is not None}")

234

```

235

236

### State Inspection and Diagnostics

237

238

```python

239

def inspect_state(state, label="State"):

240

"""Utility function to inspect state contents."""

241

242

print(f"\n{label} Inspection:")

243

print(f" Coords shape: {state.coords.shape}")

244

print(f" Coords range: [{np.min(state.coords):.3f}, {np.max(state.coords):.3f}]")

245

246

if state.log_prob is not None:

247

print(f" Log prob shape: {state.log_prob.shape}")

248

print(f" Log prob range: [{np.min(state.log_prob):.3f}, {np.max(state.log_prob):.3f}]")

249

else:

250

print(" Log prob: None")

251

252

print(f" Has blobs: {state.blobs is not None}")

253

print(f" Has random state: {state.random_state is not None}")

254

print(f" State length: {len(state)}")

255

256

# Inspect various states

257

init_state = emcee.State(np.random.randn(32, 2))

258

inspect_state(init_state, "Initial")

259

260

# After sampling

261

sampler = emcee.EnsembleSampler(32, 2, log_prob)

262

final_state = sampler.run_mcmc(init_state, 100)

263

inspect_state(final_state, "Final")

264

```

265

266

### Parallel Processing with States

267

268

```python

269

from multiprocessing import Pool

270

271

def log_prob_parallel(theta):

272

return -0.5 * np.sum(theta**2)

273

274

# State works seamlessly with parallel processing

275

with Pool() as pool:

276

sampler = emcee.EnsembleSampler(32, 2, log_prob_parallel, pool=pool)

277

278

# Initialize with state

279

init_state = emcee.State(np.random.randn(32, 2))

280

final_state = sampler.run_mcmc(init_state, 1000)

281

282

print(f"Parallel sampling completed: {final_state.coords.shape}")

283

```