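# Utilities

NumPyro provides utility functions for JAX configuration, control-flow primitives, model validation, and development. They cover hardware and precision selection, memory-bounded batching, and debugging of probabilistic programs.

Note: only `enable_x64`, `set_platform`, and `set_host_device_count` are re-exported from the top-level `numpyro` package; most of the other helpers below are defined in `numpyro.util` (for example `numpyro.util.fori_collect`).

## Capabilities
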
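### JAX Configuration

Functions for configuring JAX numerical precision, target platform, and CPU device count. These settings are only reliable when applied at the start of a program, before JAX has created arrays or initialized a backend.

```python { .api }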
def enable_x64(use_x64: bool = True) -> None:
    """
    Enable or disable 64-bit precision for JAX computations.

    JAX defaults to 32-bit floats and ints for speed. Enabling x64 makes
    newly created arrays default to 64-bit, as in NumPy. Call this at the
    start of the program: arrays created before the switch keep the dtype
    they were created with.

    NOTE(review): NumPyro's implementation may consult the
    ``JAX_ENABLE_X64`` environment variable when ``use_x64`` is False --
    confirm before relying on ``enable_x64(False)`` to force 32-bit.

    Args:
        use_x64: Whether to use 64-bit precision (default: True)

    Usage:
        # Enable double precision for numerical stability
        numpyro.enable_x64(True)

        # Disable to return to 32-bit (faster but less precise)
        numpyro.enable_x64(False)

        # Check current precision
        import jax
        print(f"Current precision: {jax.config.jax_enable_x64}")
    """

def set_platform(platform: Optional[str] = None) -> None:
    """
    Set the JAX platform used for computations.

    Only takes effect at the beginning of a program, before JAX has
    initialized a backend. Requesting a platform that is not available
    (e.g. 'gpu' on a CPU-only install) is not silently replaced by CPU;
    JAX reports an error when the backend is first used.

    Args:
        platform: 'cpu', 'gpu', or 'tpu'. When None, NumPyro reads the
            ``JAX_PLATFORM_NAME`` environment variable and uses 'cpu' if it
            is unset -- it does not auto-detect an accelerator.

    Usage:
        # Force CPU computation
        numpyro.set_platform('cpu')

        # Use the GPU (must be available)
        numpyro.set_platform('gpu')

        # Use $JAX_PLATFORM_NAME, defaulting to 'cpu'
        numpyro.set_platform(None)

        # Check current platform
        import jax
        print(f"Current platform: {jax.default_backend()}")
    """

def set_host_device_count(n: int) -> None:
    """
    Set the number of host (CPU) devices XLA exposes to JAX.

    By default XLA treats all CPU cores as a single device. Exposing ``n``
    host devices lets ``jax.pmap`` -- and therefore MCMC with the default
    ``chain_method='parallel'`` -- run chains in parallel on CPU.

    This only takes effect at the beginning of a program, before JAX
    initializes its backend: under the hood it sets
    ``XLA_FLAGS=--xla_force_host_platform_device_count=n``.

    Args:
        n: Number of CPU devices to expose

    Usage:
        # Must run before any other JAX work
        numpyro.set_host_device_count(4)

        # Then run MCMC with multiple chains
        mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000, num_chains=4)
        mcmc.run(rng_key, data)  # chains run in parallel on 4 CPU devices
    """

def set_rng_seed(rng_seed: Optional[int] = None) -> None:
    """
    Seed Python's ``random`` module and NumPy's global random generator.

    This does NOT affect JAX. JAX random numbers are driven by explicit
    PRNG keys (``jax.random.PRNGKey``), so NumPyro inference is made
    reproducible by the key passed to e.g. ``MCMC.run``.

    Args:
        rng_seed: Seed value (None reseeds from system entropy)

    Usage:
        # Make Python/NumPy randomness (e.g. data shuffling) reproducible
        numpyro.util.set_rng_seed(42)

        # JAX randomness is controlled by keys instead
        rng_key = jax.random.PRNGKey(42)
    """
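```

### Control Flow Primitives

JAX-compatible control-flow functions. They follow the semantics of the corresponding `jax.lax` primitives, so branch and loop bodies are traced and must operate on JAX-compatible values of fixed structure, shape, and dtype.

```python { .api }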
def cond(pred: "ArrayLike", true_operand: Any, true_fun: Callable,
         false_operand: Any, false_fun: Callable) -> Any:
    """
    JAX-compatible conditional execution primitive.

    Uses the operand-per-branch convention: ``true_fun(true_operand)`` is
    evaluated when ``pred`` is True and ``false_fun(false_operand)``
    otherwise. Under JAX transformations both branches are traced, so they
    must return values of matching structure, shape and dtype.

    Args:
        pred: Scalar boolean condition for branching
        true_operand: Operand passed to true_fun if pred is True
        true_fun: Function to call if pred is True
        false_operand: Operand passed to false_fun if pred is False
        false_fun: Function to call if pred is False

    Returns:
        Result of the executed branch

    Usage:
        def scaled(x):
            # pred must be a scalar; both branches return the same shape/dtype
            return numpyro.util.cond(x > 0.5, x, lambda v: v * 2.0,
                                     x, lambda v: v * 0.1)

    Note:
        Sample sites inside the branches are generally not supported by this
        primitive, because the branches are traced under ``jax.lax.cond``.
        For random choices inside branches of a model, see
        ``numpyro.contrib.control_flow`` (NOTE(review): confirm it provides
        ``cond`` in the installed NumPyro version).
    """

def while_loop(cond_fun: Callable, body_fun: Callable, init_val: Any) -> Any:
    """
    JAX-compatible while loop primitive.

    Executes ``body_fun`` repeatedly while ``cond_fun`` returns True. The
    loop state must keep the same structure, shapes and dtypes on every
    iteration. Inside the loop the state is traced, so it cannot be used to
    build Python values such as sample-site names.

    Args:
        cond_fun: Function that takes loop state and returns a scalar boolean
        body_fun: Function that takes loop state and returns new state
        init_val: Initial loop state

    Returns:
        Final loop state

    Usage:
        def draw_normals(key, n_steps):
            def cond_fun(state):
                step, _, _ = state
                return step < n_steps

            def body_fun(state):
                step, key, samples = state
                key, subkey = random.split(key)
                new_sample = random.normal(subkey)
                return step + 1, key, samples.at[step].set(new_sample)

            init_samples = jnp.zeros(n_steps)  # n_steps must be static here
            _, _, final_samples = numpyro.util.while_loop(
                cond_fun, body_fun, (0, key, init_samples)
            )
            return final_samples
    """

def fori_loop(lower: int, upper: int, body_fun: Callable, init_val: Any) -> Any:
    """
    JAX-compatible for loop primitive.

    Executes ``body_fun`` for indices ``lower, lower + 1, ..., upper - 1``,
    threading the loop state through each call.

    Args:
        lower: Starting index (inclusive)
        upper: Ending index (exclusive)
        body_fun: Function that takes (index, state) and returns new state
        init_val: Initial loop state

    Returns:
        Final loop state

    Usage:
        def mean_of_normals(key, n_samples):
            def body_fun(i, state):
                key, total = state
                key, subkey = random.split(key)
                sample = random.normal(subkey)
                return key, total + sample

            _, final_total = numpyro.util.fori_loop(0, n_samples, body_fun, (key, 0.0))
            return final_total / n_samples
    """
```

### Memory-Efficient Utilities

Functions for bounding memory usage in large-scale computations.

```python { .api }
def soft_vmap(fn: Callable, xs: "ArrayLike", batch_ndims: int = 1,
              chunk_size: Optional[int] = None) -> "ArrayLike":
    """
    Vectorized map over leading batch axes, evaluated chunk by chunk.

    Maps ``fn`` over the ``batch_ndims`` leading axes of ``xs`` like
    ``jax.vmap``. The vectorized function is applied to one chunk of the
    batch at a time, so peak memory grows with ``chunk_size`` rather than
    with the full batch size.

    Args:
        fn: Function applied to a single (unbatched) element of ``xs``
        xs: Input array whose leading ``batch_ndims`` axes are mapped over
        batch_ndims: Number of leading axes of ``xs`` to map over
        chunk_size: Number of batch elements processed per chunk; None
            processes the whole batch as a single chunk

    Returns:
        Results of ``fn`` with the batch axes of ``xs`` as leading axes

    Usage:
        weight_matrix = jnp.ones((1000, 64))

        def project(x):  # x is a single row of shape (1000,)
            return x @ weight_matrix

        large_data = jnp.ones((10000, 1000))

        # Only 100 rows go through the vectorized fn at a time
        results = numpyro.util.soft_vmap(project, large_data, chunk_size=100)
        # results.shape == (10000, 64)
    """

def fori_collect(lower: int, upper: int, body_fun: Callable, init_val: Any,
                 transform: Optional[Callable] = None, progbar: bool = True,
                 return_last_val: bool = False, collection_size: Optional[int] = None,
                 **progbar_opts) -> Union[tuple, "ArrayLike"]:
    """
    Run ``body_fun`` repeatedly and collect the (transformed) loop states.

    Works like ``jax.lax.fori_loop``, except that the state after each
    iteration is recorded and a progress bar can be shown. ``body_fun``
    maps the loop state to the next loop state; the value stored for each
    iteration is ``transform(state)``.

    Args:
        lower: Number of initial iterations whose values are NOT collected
            (e.g. warmup); collection starts after ``lower`` iterations
        upper: Total number of times ``body_fun`` is run
        body_fun: Function ``state -> new_state``; the new state must have
            the same structure, shapes and dtypes as ``init_val``
        init_val: Initial state (array or pytree of arrays)
        transform: Optional function applied to each state before it is
            stored (None stores the state itself)
        progbar: Whether to show a progress bar (disabling it is faster)
        return_last_val: Whether to also return the final state
        collection_size: Size of the returned collection; defaults to
            ``upper - lower``
        **progbar_opts: Additional progress bar options

    Returns:
        The collected values stacked along a new leading axis. If
        ``return_last_val`` is True, a tuple ``(collection, last_state)``.

    Usage:
        # Collect the trajectory of a Gaussian random walk
        def step(state):
            key, position = state
            key, subkey = random.split(key)
            return key, position + random.normal(subkey)

        init_state = (random.PRNGKey(0), 0.0)
        positions = numpyro.util.fori_collect(
            0, 1000, step, init_state,
            transform=lambda state: state[1],  # store only the position
            progbar=False,
        )
        # positions.shape == (1000,)
    """
```

### Model Validation and Debugging

Utilities for inspecting traces, checking model/guide compatibility, and debugging probabilistic programs.

```python { .api }
def format_shapes(trace: dict, last_site: Optional[str] = None) -> str:
    """
    Format the shapes of the sites in a trace for debugging.

    Produces a readable table of every site in a model trace with its
    shape, which helps diagnose broadcasting and plate issues.

    Args:
        trace: Execution trace, e.g. from
            ``numpyro.handlers.trace(...).get_trace(...)``
        last_site: Name of last site to include (None for all sites)

    Returns:
        Formatted multi-line string showing site shapes

    Usage:
        from numpyro.handlers import seed, trace

        def model():
            with numpyro.plate("batch", 10):
                x = numpyro.sample("x", dist.Normal(0, 1))  # shape (10,)
                with numpyro.plate("features", 5):
                    # the inner plate is allocated dim -2, so y has shape (5, 10)
                    y = numpyro.sample("y", dist.Normal(x, 1))

        # Sampling needs a PRNG key, hence the seed handler; get_trace()
        # returns the trace (calling the traced model returns its value)
        model_trace = trace(seed(model, rng_seed=0)).get_trace()
        print(numpyro.util.format_shapes(model_trace))
    """

def check_model_guide_match(model_trace: dict, guide_trace: dict) -> None:
    """
    Check that a model trace and a guide trace have compatible structure.

    For SVI, every latent (unobserved) sample site in the model must be
    matched by a guide sample site of the same name and compatible shape.
    Observed sites belong to the model only.

    Args:
        model_trace: Trace from model execution
        guide_trace: Trace from guide execution

    Raises:
        ValueError: If model and guide are incompatible
            (NOTE(review): some mismatches, such as a site missing from one
            side, may be reported as warnings instead -- confirm against the
            installed NumPyro version)

    Usage:
        from numpyro.handlers import seed, trace

        # Both programs sample, so each needs a PRNG key
        model_trace = trace(seed(model, rng_seed=0)).get_trace(data)
        guide_trace = trace(seed(guide, rng_seed=1)).get_trace(data)

        try:
            numpyro.util.check_model_guide_match(model_trace, guide_trace)
            print("✓ Model and guide are compatible")
        except ValueError as e:
            print(f"✗ Compatibility error: {e}")
    """

def validate_model(model: Callable, *model_args, **model_kwargs) -> dict:
    """
    Run comprehensive validation and structure analysis on a model.

    NOTE(review): NumPyro's top-level ``numpyro`` package does not export
    this helper; confirm which module provides it before relying on it.

    Args:
        model: Model function to validate
        *model_args: Positional arguments passed to ``model``
        **model_kwargs: Keyword arguments passed to ``model``

    Returns:
        Dictionary of validation results and model information. The usage
        below reads the keys ``'sample_sites'``, ``'structure'`` and
        ``'is_valid'``.

    Usage:
        def my_model(x, y=None):
            alpha = numpyro.sample("alpha", dist.Normal(0, 1))
            with numpyro.plate("data", len(x)):
                numpyro.sample("y", dist.Normal(alpha + x, 1), obs=y)

        x_data = jnp.linspace(0, 1, 100)
        validation = numpyro.validate_model(my_model, x_data)

        print(f"Number of sample sites: {len(validation['sample_sites'])}")
        print(f"Model structure: {validation['structure']}")
        print(f"Validation passed: {validation['is_valid']}")
    """
```

### Development and Performance Utilities

Helper functions used when developing and optimizing NumPyro code.

```python { .api }
def maybe_jit(fn: Callable, *args, **kwargs) -> Callable:
    """
    Wrap ``fn`` with ``jax.jit`` unless control-flow primitives are disabled.

    Args:
        fn: Function to potentially JIT compile
        *args: Extra positional arguments forwarded to ``jax.jit``
        **kwargs: Extra keyword arguments forwarded to ``jax.jit``
            (e.g. ``static_argnums``); these are jit options, not the call
            arguments of ``fn``

    Returns:
        ``jax.jit(fn, *args, **kwargs)``, or ``fn`` unchanged while
        control-flow primitives are disabled

    NOTE(review): these semantics follow the reviewer's reading of
    ``numpyro.util.maybe_jit``; confirm against the installed version.

    Usage:
        def expensive_computation(x):
            return jnp.sum(x ** 2)

        optimized_fn = numpyro.util.maybe_jit(expensive_computation)
        result = optimized_fn(large_array)  # jitted unless prims are disabled
    """

def progress_bar_factory(num_samples: int, num_chains: int = 1) -> Callable:
    """
    Create a progress-bar decorator for iterative algorithms.

    The decorator wraps a loop body with the ``(i, state)`` signature used
    by ``jax.lax.fori_loop`` and updates the progress bar as iterations run.

    Args:
        num_samples: Total number of samples/iterations
        num_chains: Number of parallel chains

    Returns:
        Progress bar decorator function

    Usage:
        progress_bar = numpyro.util.progress_bar_factory(1000, num_chains=4)

        @progress_bar
        def sampling_step(i, state):
            new_state = custom_update(state)  # your sampling logic
            return new_state

        # Progress is displayed while the loop runs
        final_state = jax.lax.fori_loop(0, 1000, sampling_step, init_state)
    """

def cached_by(outer_fn: Callable, *keys) -> Callable:
    """
    Decorator factory that reuses a previously decorated function object.

    The decorated *function object*, not its results, is memoized on
    ``outer_fn`` under ``keys``. The first time a key combination is seen,
    the freshly defined inner function is stored and returned; later uses
    return the stored function instead. This is useful inside a function
    that builds closures for ``jax.jit``: returning the same function
    object lets JAX reuse its compilation cache.

    NOTE(review): NumPyro bounds this cache to a small number of entries and
    hashes the keys; confirm details against the installed version.

    Args:
        outer_fn: Function on which the cache is stored
        *keys: Hashable values identifying the inner function

    Returns:
        A decorator that returns the cached function for ``keys``

    Usage:
        def make_potential(model, data_shape):
            @numpyro.util.cached_by(make_potential, model, data_shape)
            def potential(params):
                return log_density_of(model, params)

            # Same (model, data_shape) -> same function object -> jit cache hit
            return jax.jit(potential)
    """

def identity(x: Any, *args, **kwargs) -> Any:
    """
    Return the input unchanged.

    Useful as a placeholder or default function, e.g. a default transform.
    Extra arguments are accepted so it can stand in for callbacks with
    richer signatures.

    Args:
        x: Input value
        *args: Ignored additional arguments
        **kwargs: Ignored keyword arguments

    Returns:
        ``x`` itself (the same object, not a copy)
    """
    return x

def not_jax_tracer(x: Any) -> bool:
    """
    Check whether a value is NOT a JAX tracer.

    Lets code branch on whether a value is concrete or abstract (traced)
    inside JAX transformations such as ``jit`` or ``vmap``.

    Args:
        x: Value to check

    Returns:
        True if x is not a JAX tracer, False otherwise

    Usage:
        def conditional_computation(x):
            if numpyro.util.not_jax_tracer(x):
                # Only runs with concrete values, never while tracing
                print(f"Concrete value: {x}")
            return x ** 2
    """

def is_prng_key(key: Any) -> bool:
    """
    Check whether a value is a valid JAX PRNG key.

    Args:
        key: Value to check

    Returns:
        True if key is a valid PRNG key, False otherwise

    Usage:
        from jax import random

        key = random.PRNGKey(0)
        if numpyro.util.is_prng_key(key):
            subkey = random.split(key)[0]
        else:
            raise ValueError("Invalid PRNG key")
    """
```

### Context Managers and Control

Utilities for conditional context management and miscellaneous execution control.

```python { .api }
def optional(condition: bool, context_manager: Any) -> Any:
    """
    Enter ``context_manager`` only when ``condition`` is True.

    Args:
        condition: Whether to apply the context manager
        context_manager: Context manager to enter if condition is True

    Returns:
        A context manager that enters ``context_manager`` when
        ``condition`` is True and otherwise does nothing

    Usage:
        # Conditionally enable validation
        use_validation = True

        with numpyro.util.optional(use_validation, numpyro.validation_enabled()):
            result = model()  # validation active only if use_validation=True
    """

def control_flow_prims_disabled() -> bool:
    """
    Report whether control flow primitives (cond, while_loop) are disabled.

    NOTE(review): in NumPyro's ``numpyro.util`` this name appears to be a
    context manager that *disables* the primitives for the body of a
    ``with`` block, rather than a boolean query. A context-manager object is
    always truthy, so the ``if`` form in the Usage below would always take
    its first branch. Confirm against the installed version; if so, the
    intended use is

        with numpyro.util.control_flow_prims_disabled():
            result = numpyro.util.cond(pred, true_op, true_fn, false_op, false_fn)

    and the return annotation should be updated.

    Returns:
        True if control flow primitives (cond, while_loop) are disabled

    Usage:
        if numpyro.control_flow_prims_disabled():
            # Use alternative implementation without control flow
            result = alternative_implementation()
        else:
            result = numpyro.cond(pred, true_op, true_fn, false_op, false_fn)
    """

def nested_attrgetter(*collect_fields: str) -> Callable:
    """
    Create a getter for (possibly nested) attributes given as dotted paths.

    Each path, such as ``"adapt_state.step_size"``, is resolved attribute
    by attribute. It therefore works on objects such as namedtuples, but
    is not meant for dictionary keys (e.g. entries of
    ``svi_result.params``).

    Args:
        *collect_fields: Dot-separated attribute paths to extract

    Returns:
        Function that takes an object and extracts the requested fields
        (presumably as a tuple in the order given -- NOTE(review): confirm)

    Usage:
        from collections import namedtuple

        Adapt = namedtuple("Adapt", ["step_size"])
        State = namedtuple("State", ["z", "adapt_state"])
        state = State(z=1.0, adapt_state=Adapt(step_size=0.1))

        getter = numpyro.util.nested_attrgetter("z", "adapt_state.step_size")
        z, step_size = getter(state)
    """

def find_stack_level() -> int:
    """
    Find an appropriate ``stacklevel`` for ``warnings.warn``.

    Intended for warnings issued deep inside library call chains, so that
    the reported location points at the caller's code rather than library
    internals (NOTE(review): presumably the first stack frame outside
    NumPyro -- confirm).

    Returns:
        Stack level to pass as ``stacklevel``

    Usage:
        import warnings

        warnings.warn("deprecated option", FutureWarning,
                      stacklevel=numpyro.util.find_stack_level())
    """
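```

## Usage Examples

The examples below assume they run as a script. JAX configuration calls (`enable_x64`, `set_platform`, `set_host_device_count`) must run before any other JAX work.

```python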
import jax
import jax.numpy as jnp
from jax import jit, random

import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO

# JAX configuration (must run at program start)
def setup_jax_environment(platform="gpu", num_host_devices=4, rng_seed=42):
    """Configure JAX for NumPyro; call before any other JAX work.

    Args:
        platform: 'cpu', 'gpu' or 'tpu'. The platform must actually be
            available: requesting 'gpu' does not fall back to CPU.
        num_host_devices: Number of CPU devices to expose for parallel chains.
        rng_seed: Seed for Python/NumPy randomness (JAX uses explicit keys).
    """

    # Enable 64-bit precision for numerical stability
    numpyro.enable_x64(True)

    # Select the accelerator; errors later if it is not available
    numpyro.set_platform(platform)

    # Expose several CPU devices so chains can run in parallel
    numpyro.set_host_device_count(num_host_devices)

    # Seeds Python's and NumPy's global RNGs only (not exported at top level)
    numpyro.util.set_rng_seed(rng_seed)

    # Querying the backend initializes it, so this must come last
    print(f"JAX platform: {jax.default_backend()}")
    print(f"JAX devices: {jax.device_count()}")
    print(f"64-bit enabled: {jax.config.jax_enable_x64}")

# Control flow in probabilistic models
def control_flow_example():
    """Example using JAX-compatible control flow.

    Returns:
        A tuple ``(adaptive_model, iterative_sampler)`` of example callables.
    """

    def adaptive_model(x):
        # Model switches behavior based on a scalar input.
        # NOTE(review): sample sites inside `numpyro.util.cond` branches are
        # traced under `jax.lax.cond`, which NumPyro's effect handlers do not
        # reliably support; for random choices inside branches prefer
        # `numpyro.contrib.control_flow` (confirm availability).
        def simple_model(x):
            return numpyro.sample("y", dist.Normal(x, 0.1))

        def complex_model(x):
            hidden = numpyro.sample("hidden", dist.Normal(0, 1))
            return numpyro.sample("y", dist.Normal(x + hidden, 0.5))

        is_complex = x > 0.5
        return numpyro.util.cond(is_complex, x, complex_model, x, simple_model)

    # Iterative sampling with while loop
    def iterative_sampler(key, threshold=1.0):
        """Accumulate standard-normal draws until |total| >= threshold."""

        def cond_fun(state):
            _, _, total = state
            return jnp.abs(total) < threshold

        def body_fun(state):
            step, key, total = state
            key, subkey = random.split(key)
            # Draw with the explicit key. A sample site named f"x_{step}"
            # would embed a tracer in the site name, because `step` is
            # traced inside the loop.
            new_sample = dist.Normal(0, 1).sample(subkey)
            return step + 1, key, total + new_sample

        _, _, final_total = numpyro.util.while_loop(
            cond_fun, body_fun, (0, key, 0.0)
        )
        return final_total

    return adaptive_model, iterative_sampler

# Memory-efficient processing
def large_scale_example(n_data=100000, chunk_size=1000):
    """Example of memory-efficient utilities for large datasets.

    Args:
        n_data: Number of simulated rows.
        chunk_size: Rows per soft_vmap chunk and per running-sum step.

    Returns:
        ``(final_sum, running_sums)``: the grand total, and the running sum
        after each chunk as collected by ``fori_collect``.
    """

    # Simulate large dataset
    x_large = random.normal(random.PRNGKey(0), (n_data, 50))

    def expensive_transform(row):
        # soft_vmap maps over the leading axis, so `row` is one (50,) row
        return jnp.sum(row ** 2, axis=-1)

    # Process in chunks to bound peak memory
    results = numpyro.util.soft_vmap(
        expensive_transform,
        x_large,
        chunk_size=chunk_size,
    )

    print(f"Processed {n_data} samples in chunks")
    print(f"Result shape: {results.shape}")

    # A slice with traced bounds (results[i*k:(i+1)*k]) cannot be compiled,
    # so give each chunk a fixed row that a traced index can select.
    n_chunks = n_data // chunk_size
    chunks = results[: n_chunks * chunk_size].reshape(n_chunks, chunk_size)

    # Collect results with progress tracking
    def progressive_computation():
        def compute_step(state):
            # fori_collect's body maps state -> state
            i, running_sum = state
            return i + 1, running_sum + jnp.sum(chunks[i])

        # `transform` selects what is stored per iteration; with
        # return_last_val the result is (collection, last_state)
        running_sums, (_, final_sum) = numpyro.util.fori_collect(
            0, n_chunks,
            compute_step,
            (0, 0.0),
            transform=lambda state: state[1],
            progbar=True,
            return_last_val=True,
        )

        return final_sum, running_sums

    return progressive_computation()

# Model validation workflow
def validation_workflow_example():
    """Comprehensive model validation example.

    Prints a report covering structure validation, trace shapes,
    model/guide compatibility, and MCMC smoke tests at 32- and 64-bit
    precision. The caller's x64 setting is restored afterwards.
    """
    from numpyro.handlers import seed, trace

    def potentially_problematic_model(x, y=None):
        # Model with potential issues
        alpha = numpyro.sample("alpha", dist.Normal(0, 1))
        beta = numpyro.sample("beta", dist.Normal(0, 1))

        # Potential broadcasting issue
        with numpyro.plate("data", len(x)):
            mu = alpha + beta * x  # Check shapes here
            numpyro.sample("y", dist.Normal(mu, 1), obs=y)

    def guide(x, y=None):
        # Variational guide
        alpha_loc = numpyro.param("alpha_loc", 0.0)
        alpha_scale = numpyro.param("alpha_scale", 1.0, constraint=constraints.positive)
        beta_loc = numpyro.param("beta_loc", 0.0)
        beta_scale = numpyro.param("beta_scale", 1.0, constraint=constraints.positive)

        numpyro.sample("alpha", dist.Normal(alpha_loc, alpha_scale))
        numpyro.sample("beta", dist.Normal(beta_loc, beta_scale))

    # Generate test data
    x_test = jnp.linspace(0, 1, 100)
    y_test = 1.5 + 2.0 * x_test + 0.1 * random.normal(random.PRNGKey(0), (100,))

    print("=== Model Validation Report ===")

    # 1. Validate model structure
    # NOTE(review): validate_model is not a top-level NumPyro export; confirm
    # which module provides it.
    try:
        validation_result = numpyro.validate_model(potentially_problematic_model, x_test, y_test)
        print("✓ Model structure validation passed")
        print(f" Sample sites: {len(validation_result.get('sample_sites', []))}")

    except Exception as e:
        print(f"✗ Model validation failed: {e}")
        return

    # 2. Check model shapes (latent sites need a PRNG key, hence seed)
    model_trace = None
    try:
        model_trace = trace(seed(potentially_problematic_model, rng_seed=0)).get_trace(x_test, y_test)
        shape_info = numpyro.util.format_shapes(model_trace)
        print("✓ Shape analysis:")
        print(shape_info)

    except Exception as e:
        print(f"✗ Shape analysis failed: {e}")

    # 3. Validate model-guide compatibility (needs the model trace from step 2)
    if model_trace is None:
        print("✗ Model-guide compatibility skipped: no model trace")
    else:
        try:
            guide_trace = trace(seed(guide, rng_seed=1)).get_trace(x_test, y_test)
            numpyro.util.check_model_guide_match(model_trace, guide_trace)
            print("✓ Model-guide compatibility verified")

        except Exception as e:
            print(f"✗ Model-guide compatibility failed: {e}")

    # 4. Test with different JAX configurations.
    # x64 is best set once at program start; x_test/y_test keep the dtype
    # they were created with, so this loop is only a rough smoke test.
    original_x64 = jax.config.jax_enable_x64
    try:
        for use_x64 in [False, True]:
            numpyro.enable_x64(use_x64)
            precision = "64-bit" if use_x64 else "32-bit"

            try:
                # Quick MCMC test
                mcmc = MCMC(NUTS(potentially_problematic_model),
                            num_warmup=100, num_samples=100, num_chains=2)
                mcmc.run(random.PRNGKey(0), x_test, y_test)
                print(f"✓ {precision} MCMC test passed")

            except Exception as e:
                print(f"✗ {precision} MCMC test failed: {e}")
    finally:
        # Restore original precision even if the loop is interrupted
        numpyro.enable_x64(original_x64)

# Performance optimization example
def performance_optimization_example():
    """Compare eager, cached-jit and maybe_jit execution of a model.

    First-call timings include JIT compilation. Each result is synchronized
    with ``block_until_ready`` because JAX dispatches work asynchronously.
    """
    import time

    from numpyro.handlers import seed

    def expensive_model(x):
        # Model with expensive computations
        weights = numpyro.sample("weights", dist.Normal(0, 1).expand((100, 50)))

        # x: (n, 100) @ weights: (100, 50) -> (n, 50)
        transformed = x @ weights
        result = numpyro.sample("result", dist.Normal(transformed, 0.1))
        return result

    def run_model(rng_key, x):
        # numpyro.sample needs a PRNG key outside inference. Passing the key
        # as an argument keeps it an explicit (jit-friendly) input.
        return seed(expensive_model, rng_seed=rng_key)(x)

    def compile_model(x_shape):
        # cached_by memoizes the inner *function object* per x_shape, so
        # repeated calls return the same function and reuse JAX's jit cache
        @numpyro.util.cached_by(compile_model, x_shape)
        def compiled_fn(rng_key, x):
            return run_model(rng_key, x)

        return jit(compiled_fn)

    # jit unless control-flow primitives are disabled
    adaptive_model = numpyro.util.maybe_jit(run_model)

    # Test data
    rng_key = random.PRNGKey(1)
    x = random.normal(random.PRNGKey(0), (1000, 100))

    print("Performance comparison:")

    def timed(fn):
        start_time = time.perf_counter()
        fn(rng_key, x).block_until_ready()
        return time.perf_counter() - start_time

    original_time = timed(run_model)
    print(f"Original model: {original_time:.3f}s")

    cached_time = timed(compile_model(x.shape))
    print(f"Cached/compiled: {cached_time:.3f}s")

    adaptive_time = timed(adaptive_model)
    print(f"Adaptive JIT: {adaptive_time:.3f}s")

    speedup = original_time / min(cached_time, adaptive_time)
    print(f"Speedup: {speedup:.1f}x")
```

## Types

```python { .api }
from typing import Optional, Union, Callable, Dict, Any, Tuple, ContextManager, Literal
from jax import Array
import jax.numpy as jnp

ArrayLike = Union[Array, jnp.ndarray, float, int]
# Platform names accepted by set_platform. Literal, not Union, because the
# members are string values rather than types.
Platform = Literal["cpu", "gpu", "tpu"]
ProgressBarOptions = Dict[str, Any]

class ValidationResult:
    """Result from model validation.

    NOTE(review): ``validate_model`` is annotated to return a plain ``dict``,
    and its usage reads keys with these same names (``'sample_sites'``,
    ``'structure'``, ``'is_valid'``). This class presumably documents that
    dict's layout; consider a ``TypedDict`` if so.
    """

    is_valid: bool  # overall pass/fail flag
    sample_sites: Dict[str, Any]  # presumably keyed by site name
    param_sites: Dict[str, Any]
    deterministic_sites: Dict[str, Any]
    warnings: list
    errors: list
    structure: Dict[str, Any]

class TraceInfo:
    """Information about model trace structure."""

    sites: Dict[str, Any]
    shapes: Dict[str, tuple]
    plate_stack: list
    dependencies: Dict[str, list]

# --- Control-flow callables -------------------------------------------------
CondFun = Callable[[Any], bool]     # state -> keep looping?
BodyFun = Callable[[Any], Any]      # state -> next state
TrueFun = Callable[[Any], Any]      # operand -> branch result
FalseFun = Callable[[Any], Any]     # operand -> branch result

# --- Loops --------------------------------------------------------------------
LoopState = Any
LoopIndex = int
ForBodyFun = Callable[[LoopIndex, LoopState], LoopState]
# NOTE(review): NumPyro's fori_collect body_fun takes and returns only the
# loop state (what is stored is chosen via `transform`); confirm and update
# this alias if so.
CollectBodyFun = Callable[[LoopIndex, LoopState], Tuple[LoopState, Any]]

# --- Miscellaneous helpers ---------------------------------------------------
CacheKey = Any
CacheFun = Callable[..., CacheKey]
TransformFun = Optional[Callable[[Any], Any]]
ProgressBarFun = Callable[[Callable], Callable]  # decorator: fn -> fn

# --- Context managers ---------------------------------------------------------
ConditionalContext = Union[ContextManager, None]
OptionalContext = ContextManager

# --- Validation ---------------------------------------------------------------
ModelFun = Callable[..., Any]
GuideFun = Callable[..., Any]
TraceDict = Dict[str, Any]
SiteDict = Dict[str, Any]
860
```