# PyMC Ordinary Differential Equations (ODE)

PyMC provides comprehensive support for probabilistic modeling with ordinary differential equations. This enables Bayesian parameter estimation for dynamic systems, time series modeling with mechanistic models, and uncertainty quantification in scientific models described by ODEs.

## Core ODE Interface

The `pymc.ode` module wraps a numerical ODE solver as a PyTensor operation, so an ODE solution can be used like any other tensor inside a PyMC model, including with gradient-based samplers such as NUTS:

```python { .api }
from pymc import ode
10
import numpy as np
11
12
# Main ODE class (available as pymc.ode.DifferentialEquation)
class DifferentialEquation:
    """
    Solve ordinary differential equations as part of PyMC models.

    Accessed as ``pymc.ode.DifferentialEquation`` (``from pymc import ode``).
    Calling an instance returns a tensor that can be used like any other
    model variable (e.g. as the mean of a likelihood).

    Parameters:
    - func (callable): Right-hand side of the system, called as
      ``func(y, t, theta)`` and returning dy/dt as a list (or vector) with
      ``n_states`` entries. ``y`` and ``theta`` arrive as symbolic vectors:
      index them (``y[0]``) instead of tuple-unpacking them (``a, b = y``),
      which can fail because their length is not known statically.
    - times (array): Fixed numeric time points at which to solve the ODE
      (a plain array, not a random variable)
    - n_states (int, keyword-only): Number of state variables in the ODE system
    - n_theta (int, keyword-only): Number of parameters in the ODE system
    - t0 (float): Time corresponding to the initial conditions ``y0``
    """

    def __init__(self, func, times, *, n_states, n_theta, t0=0):
        ...

    def __call__(self, y0, theta, return_sens=False):
        """
        Solve the ODE system.

        The initial conditions come FIRST, then the parameters. Passing both
        by keyword (``ode_solution(y0=..., theta=...)``) prevents swapping
        them.

        Parameters:
        - y0: Initial conditions, ``n_states`` values
        - theta: ODE parameters, ``n_theta`` values
        - return_sens (bool): If True, also return the sensitivities of the
          solution with respect to ``y0`` and ``theta``

        Returns:
        - Solution tensor with shape (n_times, n_states), or a
          ``(solution, sensitivities)`` pair when ``return_sens`` is True
        """
        ...
```

## Capabilities

### Basic ODE Integration

Solve systems of ordinary differential equations within probabilistic models:

```python { .api }
import pymc as pm
from pymc import ode
import numpy as np


# Define ODE system
def simple_exponential(y, t, theta):
    """Simple exponential growth: dy/dt = theta[0] * y[0].

    Returns the single derivative as a one-element list (n_states == 1),
    following the same list-of-derivatives convention as the multi-state
    examples. Indexing ``y[0]`` instead of scaling the whole state vector
    also keeps the function usable with plain Python sequences.
    """
    return [theta[0] * y[0]]

# Solve in model context.
# observed_data is a placeholder for your measurements at `times`
# (one row per time point, shape (50, 1) here).
times = np.linspace(0, 10, 50)

with pm.Model() as exponential_model:
    # Parameter prior
    growth_rate = pm.Normal('growth_rate', mu=0.1, sigma=0.05)

    # Initial condition
    y0 = pm.Normal('y0', mu=1.0, sigma=0.1)

    # Define ODE
    ode_solution = ode.DifferentialEquation(
        func=simple_exponential,
        times=times,
        n_states=1,
        n_theta=1,
        t0=0,
    )

    # Solve ODE. The call signature is (y0, theta); passing both by keyword
    # prevents silently swapping initial conditions and parameters (with
    # one value each, a swap would not raise any error).
    solution = ode_solution(y0=[y0], theta=[growth_rate])

    # Likelihood for observed data
    sigma = pm.HalfNormal('sigma', sigma=0.1)
    obs = pm.Normal('obs', mu=solution, sigma=sigma, observed=observed_data)
```

### Multi-Dimensional ODE Systems

Handle complex dynamical systems with multiple interacting state variables. The ODE function returns one derivative per state, in the same order as `y`:

```python { .api }
def lotka_volterra(y, t, theta):
    """
    Lotka-Volterra predator-prey model.

    y[0]: prey population
    y[1]: predator population
    theta[0]: prey growth rate (alpha)
    theta[1]: predation rate (beta)
    theta[2]: predator efficiency (gamma)
    theta[3]: predator death rate (delta)

    Returns [dprey/dt, dpredator/dt].
    """
    # Index rather than tuple-unpack: PyMC passes y and theta as symbolic
    # vectors whose length is not known statically, and unpacking those
    # fails. Indexing works for symbolic vectors, arrays and lists alike.
    prey, predator = y[0], y[1]
    alpha, beta, gamma, delta = theta[0], theta[1], theta[2], theta[3]

    dprey_dt = alpha * prey - beta * prey * predator
    dpredator_dt = gamma * prey * predator - delta * predator

    return [dprey_dt, dpredator_dt]

# Model with multiple state variables.
# observation_times, prey_data and predator_data are placeholders for your
# data: observation_times must be a sorted numeric array (not a random
# variable), and each data array needs one value per observation time.
with pm.Model() as lotka_volterra_model:
    # Population dynamics parameters
    alpha = pm.LogNormal('alpha', mu=np.log(1.0), sigma=0.2)  # Prey growth
    beta = pm.LogNormal('beta', mu=np.log(0.5), sigma=0.2)  # Predation
    gamma = pm.LogNormal('gamma', mu=np.log(0.3), sigma=0.2)  # Predator efficiency
    delta = pm.LogNormal('delta', mu=np.log(0.8), sigma=0.2)  # Predator death

    # Initial populations
    prey_0 = pm.LogNormal('prey_0', mu=np.log(10), sigma=0.1)
    predator_0 = pm.LogNormal('predator_0', mu=np.log(5), sigma=0.1)

    # Solve the ODE system
    ode_solution = ode.DifferentialEquation(
        func=lotka_volterra,
        times=observation_times,
        n_states=2,
        n_theta=4,
        t0=0,
    )

    # The call signature is (y0, theta); keywords prevent swapping them.
    solution = ode_solution(
        y0=[prey_0, predator_0],
        theta=[alpha, beta, gamma, delta],
    )

    # Separate the state trajectories (columns follow the order of y)
    prey_solution = solution[:, 0]
    predator_solution = solution[:, 1]

    # Observation models
    prey_sigma = pm.HalfNormal('prey_sigma', sigma=1.0)
    predator_sigma = pm.HalfNormal('predator_sigma', sigma=1.0)

    prey_obs = pm.Normal('prey_obs', mu=prey_solution, sigma=prey_sigma,
                         observed=prey_data)
    predator_obs = pm.Normal('predator_obs', mu=predator_solution,
                             sigma=predator_sigma, observed=predator_data)
```

### Pharmacokinetic/Pharmacodynamic Models

Model drug concentration over time (pharmacokinetics); a pharmacodynamic effect model can then be built on the predicted concentrations:

```python { .api }
def pk_two_compartment(y, t, theta):
    """
    Two-compartment pharmacokinetic model.

    y[0]: central compartment concentration
    y[1]: peripheral compartment concentration
    theta[0]: elimination rate (ke)
    theta[1]: distribution rate central->peripheral (k12)
    theta[2]: distribution rate peripheral->central (k21)

    Returns [dcentral/dt, dperipheral/dt].
    """
    # Index rather than tuple-unpack: PyMC passes symbolic vectors of
    # statically unknown length, which cannot be unpacked.
    central, peripheral = y[0], y[1]
    ke, k12, k21 = theta[0], theta[1], theta[2]

    # Central loses drug to elimination and to the peripheral compartment,
    # and regains it from the peripheral compartment.
    dcentral_dt = -ke * central - k12 * central + k21 * peripheral
    dperipheral_dt = k12 * central - k21 * peripheral

    return [dcentral_dt, dperipheral_dt]

# Pharmacokinetic model with a single IV dose given at t0.
# sample_times and concentration_data are placeholders for your data
# (one measured central concentration per sample time).
with pm.Model() as pk_model:
    # PK parameters
    ke = pm.LogNormal('ke', mu=np.log(0.1), sigma=0.3)  # Elimination rate
    k12 = pm.LogNormal('k12', mu=np.log(0.05), sigma=0.3)  # Central to peripheral
    k21 = pm.LogNormal('k21', mu=np.log(0.03), sigma=0.3)  # Peripheral to central

    # Initial concentrations (immediately after the IV dose)
    dose = 100  # mg
    volume_central = 10  # L
    initial_central = dose / volume_central
    initial_peripheral = 0.0

    # Solve ODE
    ode_solution = ode.DifferentialEquation(
        func=pk_two_compartment,
        times=sample_times,
        n_states=2,
        n_theta=3,
        t0=0,
    )

    # The call signature is (y0, theta); keywords prevent swapping them.
    concentrations = ode_solution(
        y0=[initial_central, initial_peripheral],
        theta=[ke, k12, k21],
    )

    # Observable is central compartment concentration
    central_conc = concentrations[:, 0]

    # Proportional error model: the noise scale grows with the predicted
    # concentration.
    sigma_prop = pm.HalfNormal('sigma_prop', sigma=0.1)
    observed_conc = pm.Normal('observed_conc',
                              mu=central_conc,
                              sigma=sigma_prop * central_conc,
                              observed=concentration_data)
```

### Epidemiological Models

Model disease spread using compartmental models:

```python { .api }
def sir_model(y, t, theta):
    """
    SIR epidemiological model.

    y[0]: susceptible population (S)
    y[1]: infectious population (I)
    y[2]: recovered population (R)
    theta[0]: transmission rate (beta)
    theta[1]: recovery rate (gamma)

    Returns [dS/dt, dI/dt, dR/dt]; the three derivatives sum to zero, so the
    total population is conserved.
    """
    # Index rather than tuple-unpack: PyMC passes symbolic vectors of
    # statically unknown length, which cannot be unpacked.
    S, I, R = y[0], y[1], y[2]
    beta, gamma = theta[0], theta[1]
    N = S + I + R  # Total population

    dS_dt = -beta * S * I / N
    dI_dt = beta * S * I / N - gamma * I
    dR_dt = gamma * I

    return [dS_dt, dI_dt, dR_dt]

# Epidemiological model.
# case_data is a placeholder: one reported count per day (365 values).
with pm.Model() as sir_epidemic:
    # Disease parameters
    beta = pm.LogNormal('beta', mu=np.log(0.3), sigma=0.2)  # Transmission rate
    gamma = pm.LogNormal('gamma', mu=np.log(0.1), sigma=0.2)  # Recovery rate

    # Initial conditions
    N = 1000000  # Total population
    I0 = 10  # Initial infected
    S0 = N - I0  # Initial susceptible
    R0 = 0  # Initial recovered

    # Solve epidemic dynamics
    ode_solution = ode.DifferentialEquation(
        func=sir_model,
        times=np.arange(0, 365),  # One year, daily
        n_states=3,
        n_theta=2,
        t0=0,
    )

    # The call signature is (y0, theta); keywords prevent swapping them.
    solution = ode_solution(y0=[S0, I0, R0], theta=[beta, gamma])

    # Extract infectious population over time
    I_t = solution[:, 1]

    # Observation model for reported cases.
    # NOTE: I_t is prevalence (people currently infectious). If case_data
    # counts *new* daily cases, base the mean on daily incidence instead,
    # e.g. the decrease in S between consecutive days.
    reporting_rate = pm.Beta('reporting_rate', alpha=2, beta=8)
    expected_reports = I_t * reporting_rate

    # Negative binomial for overdispersed count data
    alpha = pm.HalfNormal('alpha', sigma=10)
    reported_cases = pm.NegativeBinomial('reported_cases',
                                         mu=expected_reports,
                                         alpha=alpha,
                                         observed=case_data)
```

## Advanced Features

### Time-Varying Parameters

The ODE function receives the current time `t`, so parameters that change over time can be computed inside it from `t` and the entries of `theta`:

```python { .api }
def time_varying_ode(y, t, theta):
    """Logistic-type growth with a time-varying growth rate.

    dy/dt = r(t) * y - c * y**2,  with  r(t) = theta[0] * exp(-theta[1] * t)

    DifferentialEquation always calls ``func(y, t, theta)`` with ``theta`` a
    vector of parameter values (never a callable), so any time dependence
    is computed here from ``t`` and the entries of ``theta``.

    theta[0]: initial growth rate (r0)
    theta[1]: decay rate of the growth rate (k)
    theta[2]: crowding coefficient (c)
    """
    growth_rate_t = theta[0] * pm.math.exp(-theta[1] * t)
    return [growth_rate_t * y[0] - theta[2] * y[0] ** 2]


with pm.Model() as time_varying_model:
    # Parameters of the time-varying growth rate
    r0 = pm.LogNormal('r0', mu=np.log(0.1), sigma=0.5)
    k = pm.LogNormal('k', mu=np.log(0.05), sigma=0.5)
    crowding = 0.01  # fixed crowding coefficient

    # Initial condition
    y0 = pm.LogNormal('y0', mu=0, sigma=0.5)

    ode_solution = ode.DifferentialEquation(
        func=time_varying_ode,
        times=np.linspace(0, 10, 50),
        n_states=1,
        n_theta=3,
        t0=0,
    )
    solution = ode_solution(y0=[y0], theta=[r0, k, crowding])

    # For piecewise or interpolated rates, pass the knot values as extra
    # entries of `theta` and interpolate them from `t` inside the ODE function.
```

### Stochastic Differential Equations

PyMC's ODE module solves deterministic ODEs only. Stochastic elements can still be incorporated through parameter uncertainty (priors) and the observation model; for dynamics that are themselves stochastic, PyMC provides the `pm.EulerMaruyama` distribution for SDEs.

### Integration with Experimental Design

Use ODE models to compare candidate experimental designs, such as alternative sampling schedules:

```python { .api }
# Compare candidate sampling schedules for parameter estimation.
# DifferentialEquation needs `times` as a fixed numeric array -- it cannot be
# a random variable -- so each candidate schedule gets its own ODE op and model.
# (model_equations, n_states, n_params, y0, T_max and n_samples are placeholders.)
candidate_designs = {
    'uniform': np.linspace(0, T_max, n_samples),
    'early_dense': T_max * np.linspace(0, 1, n_samples) ** 2,
}

design_models = {}
for design_name, design_times in candidate_designs.items():
    ode_solution = ode.DifferentialEquation(
        func=model_equations,
        times=design_times,
        n_states=n_states,
        n_theta=n_params,
        t0=0,
    )

    with pm.Model() as design_model:
        # Parameters
        theta = pm.LogNormal('theta', mu=0, sigma=1, shape=n_params)

        # Predicted states at this design's sampling times
        # (call signature is y0, theta)
        predictions = ode_solution(y0=y0, theta=theta)

    design_models[design_name] = design_model

# Score each design with an expected-information criterion, e.g. computed from
# prior predictive simulations of each model
# (implementation depends on the specific design criterion)
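```

## Performance Considerations

### Solver Selection

`pm.ode.DifferentialEquation` integrates the system and its forward sensitivities with SciPy's `odeint` and does not expose solver or tolerance settings. For larger or stiff systems, a dedicated solver package can be much faster:

```python { .api }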
# pm.ode.DifferentialEquation has no solver or tolerance arguments: it always
# integrates with scipy.integrate.odeint, and passing e.g. solver=, rtol= or
# atol= raises a TypeError.
ode_solution = ode.DifferentialEquation(
    func=model_equations,
    times=times,
    n_states=n_states,
    n_theta=n_theta,
    t0=0,
)

# For larger or stiff systems, consider the separately installed `sunode`
# package (pip install sunode). It wraps the SUNDIALS CVODES solver and can be
# substantially faster, but it has its own API rather than being a drop-in
# option of DifferentialEquation.
```

### Computational Efficiency

- Keep the right-hand-side function cheap: it is evaluated (together with its sensitivities) at every solver step.
- Minimize the number of time points when not all are needed for observations.
- Gradient-based sampling (NUTS, the default) explores posteriors efficiently, but each gradient requires solving the augmented sensitivity system, so cost per draw grows with the number of states and parameters.
- Profile your models to identify computational bottlenecks.

## Common Patterns

### Parameter Transformations

```python { .api }
# Log-transform positive parameters for unconstrained sampling.
# (Positive distributions such as pm.LogNormal are already sampled on the log
# scale automatically; this explicit version makes the transform visible.)
# n_params, ode_solution and y0 are placeholders from an existing setup.
with pm.Model() as transformed_model:
    log_params = pm.Normal('log_params', mu=0, sigma=1, shape=n_params)
    params = pm.Deterministic('params', pm.math.exp(log_params))

    # Use transformed parameters in the ODE (call signature is y0, theta)
    solution = ode_solution(y0=y0, theta=params)
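```

### Missing Data and Irregular Observations

Irregularly spaced observations need no special handling: pass the actual (sorted) observation times as `times`. Missing values can be masked out of the likelihood:

```python { .api }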
# Handle missing observations: only non-NaN entries enter the likelihood.
# `model` is the pm.Model that defines `solution` and `sigma`; observed_data
# must have the same shape as `solution`.
with model:
    mask = ~np.isnan(observed_data)
    solution_observed = solution[mask]
    data_observed = observed_data[mask]

    obs = pm.Normal('obs', mu=solution_observed, sigma=sigma, observed=data_observed)
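```

### Model Comparison

Compare alternative ODE structures with PSIS-LOO cross-validation, which needs the pointwise log-likelihood of each fitted model:

```python { .api }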
import arviz as az

# Compare different ODE structures
with pm.Model() as model_1:
    # Simple exponential growth ODE + likelihood
    ...
    # Store pointwise log-likelihoods, which LOO requires
    idata_1 = pm.sample(idata_kwargs={'log_likelihood': True})

with pm.Model() as model_2:
    # Logistic growth with carrying capacity
    ...
    idata_2 = pm.sample(idata_kwargs={'log_likelihood': True})

# Use information criteria for model selection (ArviZ functions take the
# InferenceData objects; no model argument is needed)
loo_1 = az.loo(idata_1)
loo_2 = az.loo(idata_2)
comparison = az.compare({'exponential': idata_1, 'logistic': idata_2}, ic='loo')
```

PyMC's ODE integration provides a powerful framework for mechanistic modeling in a Bayesian context, enabling principled uncertainty quantification for dynamic systems across scientific domains.