or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

diagnostics.md, distributions.md, handlers.md, index.md, inference.md, optimization.md, primitives.md, utilities.md

docs/utilities.md

0

# Utilities

1

2

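NumPyro provides utility functions for configuring JAX (precision, backend platform, CPU device count), JAX-compatible control flow, memory-bounded mapping and looping, model/guide checking, and small development helpers. Only `enable_x64`, `set_platform` and `set_host_device_count` are exported at the top level of the `numpyro` package. Most of the other helpers documented here live in the `numpyro.util` module and are called as `numpyro.util.<name>`.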

3

4

## Capabilities

5

6

### JAX Configuration

7

8

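Functions for configuring JAX precision, backend platform, CPU device count and global seeding. Call them at the very start of a program. The platform and host-device settings only take effect before JAX initializes its backend.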

9

10

```python { .api }

11

def enable_x64(use_x64: bool = True) -> None:
    """
    Switch JAX between 64-bit and 32-bit default precision.

    JAX defaults to 32-bit floats and integers because they are faster,
    especially on accelerators. Turning on 64-bit mode trades speed for
    numerical accuracy.

    Args:
        use_x64: True selects 64-bit precision, False returns to 32-bit
            (default: True)

    Usage:
        # Turn on double precision for numerically delicate models
        numpyro.enable_x64(True)

        # Go back to faster single precision
        numpyro.enable_x64(False)

        # Inspect the active setting
        import jax
        print(f"Current precision: {jax.config.jax_enable_x64}")
    """
    ...

32

33

def set_platform(platform: Optional[str] = None) -> None:
    """
    Select the JAX backend used for computations.

    Only effective at the very beginning of a program, before JAX has
    initialized a backend. The choice is not checked up front. If the
    requested platform is not available, JAX reports an error when the
    backend is first used; it does not silently fall back to CPU.

    Args:
        platform: 'cpu', 'gpu' or 'tpu'. When None, the value of the
            JAX_PLATFORM_NAME environment variable is used, or 'cpu' if that
            variable is unset. None does NOT auto-detect the fastest device.

    Usage:
        # Force CPU computation
        numpyro.set_platform('cpu')

        # Use the GPU (requires a GPU-enabled jaxlib and a visible GPU)
        numpyro.set_platform('gpu')

        # Defer to $JAX_PLATFORM_NAME, defaulting to 'cpu'
        numpyro.set_platform(None)

        # Check current platform
        import jax
        print(f"Current platform: {jax.default_backend()}")
    """

54

55

def set_host_device_count(n: int) -> None:
    """
    Expose ``n`` logical CPU devices to XLA.

    By default XLA treats all CPU cores as a single device. This helper sets
    the XLA_FLAGS environment variable
    (``--xla_force_host_platform_device_count=n``). That lets ``jax.pmap``,
    and therefore MCMC with ``chain_method="parallel"``, run one chain per
    logical device on CPU.

    XLA reads this flag only when the JAX backend is initialized. Call this
    at the very start of the program, before any JAX computation.

    Args:
        n: Number of logical CPU devices to expose

    Usage:
        # Must run before any JAX array is created
        numpyro.set_host_device_count(4)

        # Then run MCMC with multiple chains (default chain_method="parallel")
        mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000, num_chains=4)
        mcmc.run(rng_key, data)  # one chain per CPU device
    """

73

74

def set_rng_seed(rng_seed: Optional[int] = None) -> None:
    """
    Seed Python's ``random`` module and NumPy's global random state.

    This does NOT make JAX sampling reproducible. JAX, and therefore NumPyro
    inference, draws randomness only from explicit PRNG keys such as
    ``jax.random.PRNGKey(0)``. Use this for helper code that relies on
    ``random`` or ``np.random``.

    Args:
        rng_seed: Seed value (None reseeds from system entropy)

    Usage:
        # Reproducible Python/NumPy randomness
        numpyro.util.set_rng_seed(42)

        # JAX randomness is controlled by the key passed to inference
        mcmc.run(jax.random.PRNGKey(42), data)

        # Reseed Python/NumPy from system entropy
        numpyro.util.set_rng_seed(None)
    """

88

```

89

90

### Control Flow Primitives

91

92

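Thin wrappers around `jax.lax.cond`, `jax.lax.while_loop` and `jax.lax.fori_loop`. They work under JIT compilation and fall back to plain Python control flow when control-flow primitives are disabled. They do not support `numpyro.sample` statements inside branches or loop bodies; use `numpyro.contrib.control_flow` (`cond`, `scan`) for that.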

93

94

```python { .api }

95

def cond(pred: ArrayLike, true_operand: Any, true_fun: Callable,
         false_operand: Any, false_fun: Callable) -> Any:
    """
    Conditional that works under JAX transformations (``numpyro.util.cond``).

    Thin wrapper over ``jax.lax.cond``. Both branches are traced, so they must
    return values with identical structure, shapes and dtypes. When
    control-flow primitives are disabled (see control_flow_prims_disabled),
    it evaluates ``pred`` eagerly and calls only one branch in Python.

    Sample statements (``numpyro.sample``) inside ``true_fun``/``false_fun``
    are not supported by NumPyro's effect handlers here. For branches that
    contain sample sites, use
    ``numpyro.contrib.control_flow.cond(pred, true_fun, false_fun, operand)``.

    Args:
        pred: Scalar boolean selecting the branch
        true_operand: Argument passed to true_fun when pred is True
        true_fun: Branch evaluated when pred is True
        false_operand: Argument passed to false_fun when pred is False
        false_fun: Branch evaluated when pred is False

    Returns:
        Output of the selected branch

    Usage:
        def model(x):  # x is a scalar
            # Deterministic branches: pick the noise scale from the input
            scale = numpyro.util.cond(x > 0.5, x, lambda _: 2.0, x, lambda _: 0.1)
            return numpyro.sample("y", dist.Normal(x, scale))
    """

126

127

def while_loop(cond_fun: Callable, body_fun: Callable, init_val: Any) -> Any:
    """
    While loop that works under JAX transformations (``numpyro.util.while_loop``).

    Thin wrapper over ``jax.lax.while_loop``. It repeatedly applies
    ``body_fun`` while ``cond_fun`` returns True, and the loop state must keep
    the same structure, shapes and dtypes on every iteration. When
    control-flow primitives are disabled (see control_flow_prims_disabled),
    it runs as a plain Python loop. ``numpyro.sample`` statements inside
    ``body_fun`` are not supported; draw from ``jax.random`` with explicit
    keys instead.

    Args:
        cond_fun: Function that takes loop state and returns a scalar boolean
        body_fun: Function that takes loop state and returns new state
        init_val: Initial loop state

    Returns:
        Final loop state

    Usage:
        def random_walk(key, n_steps):
            def cond_fun(state):
                step, _, _ = state
                return step < n_steps

            def body_fun(state):
                step, key, samples = state
                key, subkey = random.split(key)
                new_sample = random.normal(subkey)  # not numpyro.sample
                return step + 1, key, samples.at[step].set(new_sample)

            init_samples = jnp.zeros(n_steps)
            _, _, final_samples = numpyro.util.while_loop(
                cond_fun, body_fun, (0, key, init_samples)
            )
            return final_samples
    """

160

161

def fori_loop(lower: int, upper: int, body_fun: Callable, init_val: Any) -> Any:
    """
    For loop that works under JAX transformations (``numpyro.util.fori_loop``).

    Thin wrapper over ``jax.lax.fori_loop``. It calls ``body_fun(i, state)``
    for i from lower to upper-1. When control-flow primitives are disabled
    (see control_flow_prims_disabled), it runs as a plain Python loop.

    Args:
        lower: Starting index (inclusive)
        upper: Ending index (exclusive)
        body_fun: Function that takes (index, state) and returns new state
        init_val: Initial loop state

    Returns:
        Final loop state

    Usage:
        def mean_of_normals(key, n_samples):
            def body_fun(i, state):
                key, total = state
                key, subkey = random.split(key)
                return key, total + random.normal(subkey)

            _, final_total = numpyro.util.fori_loop(0, n_samples, body_fun, (key, 0.0))
            return final_total / n_samples
    """

187

```

188

189

### Memory-Efficient Utilities

190

191

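Helpers for mapping a function over large batches in memory-bounded chunks (`soft_vmap`), and for running a loop while collecting intermediate values with an optional progress bar (`fori_collect`).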

192

193

```python { .api }

194

def soft_vmap(fn: Callable, xs: ArrayLike, batch_ndims: int = 1,
              chunk_size: Optional[int] = None) -> ArrayLike:
    """
    Vectorized map over the leading ``batch_ndims`` axes of ``xs``, in chunks.

    ``fn`` is applied to ONE element at a time; it sees ``xs`` with the batch
    axes removed. The batch is split into chunks of ``chunk_size`` elements.
    Each chunk is processed with ``jax.vmap`` and the chunks are mapped over
    sequentially, so peak memory scales with ``chunk_size`` rather than with
    the whole batch.

    Args:
        fn: Function applied to a single (un-batched) element
        xs: Array or pytree of arrays sharing the same leading batch shape
        batch_ndims: Number of leading axes of xs to map over
        chunk_size: Elements per chunk. None means the whole batch in a
            single chunk (no memory saving); it is not auto-tuned.

    Returns:
        Outputs of fn, with the original batch shape as leading axes

    Usage:
        weight_matrix = jnp.ones((1000, 64))

        def expensive_computation(x):   # x has shape (1000,)
            return x @ weight_matrix    # -> shape (64,)

        large_data = jnp.ones((10000, 1000))

        # Map 100 rows at a time
        results = numpyro.util.soft_vmap(expensive_computation, large_data, chunk_size=100)
        # results.shape == (10000, 64)
    """

222

223

def fori_collect(lower: int, upper: int, body_fun: Callable, init_val: Any,
                 transform: Optional[Callable] = None, progbar: bool = True,
                 return_last_val: bool = False, collection_size: Optional[int] = None,
                 thinning: int = 1, **progbar_opts) -> Union[tuple, ArrayLike]:
    """
    Run ``body_fun`` ``upper`` times and collect the (transformed) states.

    Works like ``fori_loop``, except that ``body_fun`` takes only the loop
    state and returns the next state. It does not receive the iteration index
    and does not return a separate item. After each iteration,
    ``transform(state)`` is recorded. Iterations before ``lower`` run but are
    not collected (e.g. warmup), and after that every ``thinning``-th value
    is kept.

    Args:
        lower: Number of initial iterations whose values are not collected
        upper: Total number of iterations of body_fun
        body_fun: Callable ``state -> new_state`` (same structure, shapes, dtypes)
        init_val: Initial state
        transform: Callable applied to each state before it is collected;
            None stands for the identity (NumPyro's own default is
            ``numpyro.util.identity``)
        progbar: Whether to show a progress bar
        return_last_val: Also return the final loop state
        collection_size: Pre-allocated size of the collection; by default,
            the number of collected iterations
        thinning: Keep every ``thinning``-th collected value
        **progbar_opts: Additional progress bar options

    Returns:
        Collected values stacked along a new leading axis, or
        ``(collection, last_val)`` when return_last_val is True

    Usage:
        def mcmc_step(state):
            key, params = state
            key, subkey = random.split(key)
            return key, mcmc_kernel_step(subkey, params)

        init_state = (random.PRNGKey(0), init_params)
        # Run 1500 steps, skip the first 500, collect params of the rest
        samples = numpyro.util.fori_collect(
            500, 1500, mcmc_step, init_state, transform=lambda s: s[1]
        )
    """

261

```

262

263

### Model Validation and Debugging

264

265

Utilities for validating models and debugging probabilistic programs.

266

267

```python { .api }

268

def format_shapes(trace: dict, last_site: Optional[str] = None) -> str:
    """
    Format the shapes of the sites in a model trace, for debugging.

    Produces a readable table of a trace's sites and their shapes, which is
    handy for diagnosing broadcasting and plate mistakes.
    NOTE(review): recent NumPyro versions also accept keyword-only options
    (e.g. ``compute_log_prob``, ``title``); check the API docs of your version.

    Args:
        trace: Trace dictionary, e.g. from ``handlers.trace(...).get_trace(...)``
        last_site: Name of last site to include (None for all sites)

    Returns:
        Formatted string showing site shapes

    Usage:
        from numpyro import handlers

        def model():
            with numpyro.plate("batch", 10):
                x = numpyro.sample("x", dist.Normal(0, 1))      # shape (10,)
                with numpyro.plate("features", 5):
                    # The inner plate takes dim -2, so y has shape (5, 10)
                    y = numpyro.sample("y", dist.Normal(x, 1))

        # Sampling outside inference needs a PRNG key, hence handlers.seed
        model_trace = handlers.trace(handlers.seed(model, rng_seed=0)).get_trace()
        print(numpyro.util.format_shapes(model_trace))
    """

302

303

def check_model_guide_match(model_trace: dict, guide_trace: dict) -> None:
    """
    Check that a model trace and a guide trace are compatible for SVI.

    Structural mismatches are reported as warnings, not exceptions. Examples
    are a latent sample site of the model that the guide does not provide,
    or a guide site that is absent from the model and not marked auxiliary.
    NOTE(review): only disagreements such as mismatched shapes at shared
    sites appear to raise ValueError; confirm against numpyro.util for your
    version.

    Args:
        model_trace: Trace from model execution
        guide_trace: Trace from guide execution

    Raises:
        ValueError: On hard incompatibilities such as shape disagreements
            (missing or extra sites only produce warnings)

    Usage:
        import warnings
        from numpyro import handlers

        # Both functions must be seeded to execute their sample statements
        model_trace = handlers.trace(handlers.seed(model, rng_seed=0)).get_trace(data)
        guide_trace = handlers.trace(handlers.seed(guide, rng_seed=1)).get_trace(data)

        with warnings.catch_warnings():
            warnings.simplefilter("error")  # escalate mismatch warnings
            try:
                numpyro.util.check_model_guide_match(model_trace, guide_trace)
                print("✓ Model and guide are compatible")
            except (Warning, ValueError) as e:
                print(f"✗ Compatibility error: {e}")
    """

330

331

def validate_model(model: Callable, *model_args, **model_kwargs) -> dict:
    """
    Comprehensive model validation and structure analysis.

    NOTE(review): no ``validate_model`` could be found in NumPyro's public
    API (neither ``numpyro`` nor ``numpyro.util``). Confirm that this entry
    refers to a real function before relying on it. A documented way to
    check a model is to run it once under ``numpyro.validation_enabled()``
    with a seed handler; invalid distribution parameters then raise:

        with numpyro.validation_enabled():
            tr = handlers.trace(handlers.seed(my_model, rng_seed=0)).get_trace(x_data)

    Args:
        model: Model function to validate
        *model_args: Arguments to pass to model
        **model_kwargs: Keyword arguments to pass to model

    Returns:
        Dictionary containing validation results and model information

    Usage:
        def my_model(x, y=None):
            alpha = numpyro.sample("alpha", dist.Normal(0, 1))
            with numpyro.plate("data", len(x)):
                numpyro.sample("y", dist.Normal(alpha + x, 1), obs=y)

        x_data = jnp.linspace(0, 1, 100)
        validation = numpyro.validate_model(my_model, x_data)

        print(f"Number of sample sites: {len(validation['sample_sites'])}")
        print(f"Model structure: {validation['structure']}")
        print(f"Validation passed: {validation['is_valid']}")
    """

356

```

357

358

### Development and Performance Utilities

359

360

Helper functions for development and performance optimization.

361

362

```python { .api }

363

def maybe_jit(fn: Callable, *args, **kwargs) -> Callable:
    """
    ``jax.jit`` ``fn`` unless control-flow primitives are disabled.

    This is not a heuristic. It returns ``jax.jit(fn, *args, **kwargs)``
    normally, and ``fn`` itself when control_flow_prims_disabled() is True.

    Args:
        fn: Function to compile
        *args: Extra positional arguments forwarded to ``jax.jit``
        **kwargs: Extra keyword arguments forwarded to ``jax.jit``
            (e.g. ``static_argnums``)

    Returns:
        JIT-compiled or original function

    Usage:
        def expensive_computation(x):
            return jnp.sum(x ** 2)

        fn = numpyro.util.maybe_jit(expensive_computation)
        result = fn(large_array)

        # jit options are forwarded to jax.jit
        scale = numpyro.util.maybe_jit(lambda x, k: x * k, static_argnums=1)
    """

386

387

def progress_bar_factory(num_samples: int, num_chains: int = 1) -> Callable:
    """
    Build a progress-bar decorator for ``lax.fori_loop`` bodies.

    Args:
        num_samples: Total number of loop iterations
        num_chains: Number of parallel chains (presumably one bar per chain;
            confirm against numpyro.util)

    Returns:
        Decorator to apply to a ``body_fun(i, state)`` used with
        ``jax.lax.fori_loop``

    Usage:
        progress_bar = numpyro.util.progress_bar_factory(1000, num_chains=1)

        @progress_bar
        def sampling_step(i, state):
            # Custom sampling logic computing new_state from state
            return new_state

        # Progress is displayed as the loop runs
        final_state = jax.lax.fori_loop(0, 1000, sampling_step, init_state)
    """

410

411

def cached_by(outer_fn: Callable, *keys) -> Callable:
    """
    Decorator factory that caches a function under explicit ``keys``.

    The cache key is given directly through ``keys``; it is not computed by
    a key function. ``outer_fn`` identifies the function that owns the cache.
    NOTE(review): in numpyro.util the decorator appears to return the cached
    *function object* for repeated keys, with a small cache attached to
    ``outer_fn``. That lets ``jax.jit`` reuse its compilation. Confirm
    against the source of your NumPyro version.

    Args:
        outer_fn: Function whose attached cache stores the entries
        *keys: Hashable values identifying the cache entry

    Returns:
        Decorator producing the cached version of the decorated function

    Usage:
        def make_step(step_size):
            # Equal step sizes reuse the same inner function object,
            # so jax.jit does not recompile
            @numpyro.util.cached_by(make_step, step_size)
            def step(x):
                return x + step_size

            return jax.jit(step)

        f1 = make_step(0.1)  # compiled on first call
        f2 = make_step(0.1)  # same cached `step`; compilation reused
    """

435

436

def identity(x: Any, *args, **kwargs) -> Any:
    """
    Hand back ``x`` untouched, discarding any extra arguments.

    Convenient as a no-op default wherever a callable is expected.

    Args:
        x: Value to return
        *args: Extra positional arguments (ignored)
        **kwargs: Extra keyword arguments (ignored)

    Returns:
        The input value, unchanged
    """
    ...

450

451

def not_jax_tracer(x: Any) -> bool:
    """
    Check whether a value is concrete rather than a JAX tracer.

    Useful for logic that should only run on concrete values, as opposed to
    abstract values seen while tracing under ``jit``, ``vmap`` or ``lax``
    control flow.

    Args:
        x: Value to check

    Returns:
        True if x is not a JAX tracer, False otherwise

    Usage:
        def conditional_computation(x):
            if numpyro.util.not_jax_tracer(x):
                # Only runs with concrete values
                print(f"Concrete value: {x}")
            return x ** 2
    """

471

472

def is_prng_key(key: Any) -> bool:
    """
    Check that the input is a valid JAX PRNG key.

    Args:
        key: Potential PRNG key to validate

    Returns:
        True if key is a valid PRNG key

    Usage:
        from jax import random

        key = random.PRNGKey(0)
        if numpyro.util.is_prng_key(key):
            subkey = random.split(key)[0]
        else:
            raise ValueError("Invalid PRNG key")
    """

491

```

492

493

### Context Managers and Control

494

495

Utilities for context management and execution control.

496

497

```python { .api }

498

def optional(condition: bool, context_manager: Any) -> Any:
    """
    Enter a context manager only when a condition holds.

    Args:
        condition: Whether to apply the context manager
        context_manager: Context manager entered when condition is True

    Returns:
        Context manager that wraps context_manager or does nothing

    Usage:
        # Conditionally enable distribution validation
        use_validation = True

        with numpyro.util.optional(use_validation, numpyro.validation_enabled()):
            result = model()  # validated only if use_validation is True
    """

516

517

def control_flow_prims_disabled() -> bool:
    """
    Report whether NumPyro's control-flow primitives are disabled.

    When disabled, ``numpyro.util.cond``, ``while_loop`` and ``fori_loop``
    run as plain Python control flow instead of ``jax.lax`` primitives, and
    ``numpyro.util.maybe_jit`` returns functions uncompiled.

    Returns:
        True if control flow primitives (cond, while_loop, fori_loop) are disabled

    Usage:
        if numpyro.util.control_flow_prims_disabled():
            # Use alternative implementation without control flow
            result = alternative_implementation()
        else:
            result = numpyro.util.cond(pred, true_op, true_fn, false_op, false_fn)
    """

531

532

def nested_attrgetter(*collect_fields: str) -> Callable:
    """
    Build a getter for possibly nested (dot-separated) attributes.

    Like ``operator.attrgetter`` with "a.b.c" paths. Presumably it also
    mirrors attrgetter's return convention: a single value for one path and
    a tuple for several. Confirm against numpyro.util. Paths traverse
    attributes, not dict keys, so they do not reach into dict-valued fields
    such as ``SVIRunResult.params``.

    Args:
        *collect_fields: Dot-separated attribute paths to extract

    Returns:
        Function that extracts the specified fields from an object

    Usage:
        # Fetch nested fields from an HMC sampler state
        getter = numpyro.util.nested_attrgetter("adapt_state.step_size", "potential_energy")
        step_size, potential_energy = getter(hmc_state)
    """

550

551

def find_stack_level() -> int:
    """
    Compute the ``stacklevel`` to use when issuing a warning.

    Intended so that a warning raised from nested library calls is reported
    at the appropriate caller frame, even in deep call hierarchies.

    Returns:
        Stack level suitable for ``warnings.warn(..., stacklevel=...)``
    """
    ...

561

```

562

563

## Usage Examples

564

565

```python

566

import numpyro

567

import numpyro.distributions as dist

568

from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO

569

import jax.numpy as jnp

570

from jax import random

571

572

# JAX configuration for optimal performance

573

def setup_jax_environment(platform="gpu", host_device_count=4, rng_seed=42):
    """Configure JAX for NumPyro; call at the very start of the program.

    The platform and host-device settings are only honoured before JAX
    initializes its backend, so calling this later has no effect on them.

    Args:
        platform: JAX platform to select ('cpu', 'gpu' or 'tpu'). JAX does
            not fall back to CPU when the requested platform is missing; pass
            'cpu' on machines without a GPU.
        host_device_count: Number of logical CPU devices to expose, so that
            MCMC chains can run in parallel on CPU.
        rng_seed: Seed for Python's and NumPy's global RNGs. JAX randomness
            is driven by explicit PRNG keys and is not affected.
    """
    # `jax` is not imported at the top of this example; import it here
    import jax
    from numpyro.util import set_rng_seed

    # Enable 64-bit precision for numerical stability
    numpyro.enable_x64(True)

    # Select the backend; errors at first use if it is unavailable
    numpyro.set_platform(platform)

    # Set up multiple CPU devices for parallel chains
    numpyro.set_host_device_count(host_device_count)

    # Seed Python/NumPy RNGs for reproducibility of non-JAX helper code
    set_rng_seed(rng_seed)

    print(f"JAX platform: {jax.default_backend()}")
    print(f"JAX devices: {jax.device_count()}")
    print(f"64-bit enabled: {jax.config.jax_enable_x64}")

591

592

# Control flow in probabilistic models

593

def control_flow_example():
    """Example using JAX-compatible control flow.

    Returns:
        ``(adaptive_model, iterative_sampler)``. The first is a NumPyro model
        whose likelihood switches on its (scalar) input. The second runs a
        Gaussian random walk until its total leaves ``[-threshold, threshold]``.
    """
    from numpyro.contrib.control_flow import cond
    from numpyro.util import while_loop

    def adaptive_model(x):
        # Latent offset used by the "complex" branch. It is sampled outside
        # the conditional so that both branches contain the same sample sites.
        hidden = numpyro.sample("hidden", dist.Normal(0, 1))

        def simple_model(x):
            return numpyro.sample("y", dist.Normal(x, 0.1))

        def complex_model(x):
            return numpyro.sample("y", dist.Normal(x + hidden, 0.5))

        # contrib's cond(pred, true_fun, false_fun, operand) supports sample
        # sites inside branches; numpyro.util.cond (plain lax.cond) does not.
        is_complex = x > 0.5
        return cond(is_complex, complex_model, simple_model, x)

    # Iterative sampling with while loop
    def iterative_sampler(key, threshold=1.0):
        def cond_fun(state):
            _, _, total = state
            return jnp.abs(total) < threshold

        def body_fun(state):
            step, key, total = state
            key, subkey = random.split(key)
            # Draw with jax.random directly: numpyro.sample sites are not
            # supported inside a while_loop body.
            new_sample = random.normal(subkey)
            return step + 1, key, total + new_sample

        # Array-typed carry so the state's dtypes match on every iteration
        init_state = (jnp.asarray(0), key, jnp.asarray(0.0))
        _, _, final_total = while_loop(cond_fun, body_fun, init_state)
        return final_total

    return adaptive_model, iterative_sampler

628

629

# Memory-efficient processing

630

def large_scale_example():
    """Example of memory-efficient utilities for large datasets.

    Returns:
        ``(final_sum, chunk_sums)`` from ``progressive_computation``: the
        total over all per-row results and the per-chunk partial sums.
    """
    from jax import lax
    from numpyro.util import fori_collect, soft_vmap

    # Simulate large dataset
    n_data = 100000
    chunk = 1000
    x_large = random.normal(random.PRNGKey(0), (n_data, 50))

    def expensive_transform(x_row):
        # soft_vmap calls fn on ONE row (shape (50,)), not on a whole batch
        return jnp.sum(x_row ** 2)

    # Process in chunks to avoid memory issues
    results = soft_vmap(
        expensive_transform,
        x_large,
        chunk_size=chunk,  # Process 1000 samples at a time
    )

    print(f"Processed {n_data} samples in chunks")
    print(f"Result shape: {results.shape}")

    # Collect results with progress tracking
    def progressive_computation():
        n_chunks = n_data // chunk

        def compute_step(state):
            # fori_collect bodies receive only the state, so the chunk index
            # is carried inside it.
            i, running_sum, _ = state
            # i is traced inside the compiled body, so slice dynamically
            chunk_sum = jnp.sum(lax.dynamic_slice_in_dim(results, i * chunk, chunk))
            return i + 1, running_sum + chunk_sum, chunk_sum

        init_state = (jnp.asarray(0), jnp.asarray(0.0), jnp.asarray(0.0))

        # With return_last_val=True, fori_collect returns (collection, last_state)
        chunk_sums, last_state = fori_collect(
            0, n_chunks,
            compute_step,
            init_state,
            transform=lambda state: state[2],  # collect each chunk's sum
            progbar=True,
            return_last_val=True,
        )
        final_sum = last_state[1]
        return final_sum, chunk_sums

    return progressive_computation()

671

672

# Model validation workflow

673

def validation_workflow_example():
    """Comprehensive model validation example.

    Prints a report from four checks on a small linear-regression model:
    (1) one seeded run with distribution-argument validation enabled,
    (2) a table of site shapes, (3) model/guide compatibility, and
    (4) a short MCMC run in 32-bit and 64-bit precision.
    Returns early (None) if check 1 fails.
    """
    import warnings

    import jax
    from numpyro import handlers
    from numpyro.distributions import constraints
    from numpyro.util import check_model_guide_match, format_shapes

    def potentially_problematic_model(x, y=None):
        # Model with potential issues
        alpha = numpyro.sample("alpha", dist.Normal(0, 1))
        beta = numpyro.sample("beta", dist.Normal(0, 1))

        # Potential broadcasting issue
        with numpyro.plate("data", len(x)):
            mu = alpha + beta * x  # Check shapes here
            numpyro.sample("y", dist.Normal(mu, 1), obs=y)

    def guide(x, y=None):
        # Variational guide
        alpha_loc = numpyro.param("alpha_loc", 0.0)
        alpha_scale = numpyro.param("alpha_scale", 1.0, constraint=constraints.positive)
        beta_loc = numpyro.param("beta_loc", 0.0)
        beta_scale = numpyro.param("beta_scale", 1.0, constraint=constraints.positive)

        numpyro.sample("alpha", dist.Normal(alpha_loc, alpha_scale))
        numpyro.sample("beta", dist.Normal(beta_loc, beta_scale))

    # Generate test data
    x_test = jnp.linspace(0, 1, 100)
    y_test = 1.5 + 2.0 * x_test + 0.1 * random.normal(random.PRNGKey(0), (100,))

    # Outside of inference, sample statements need a PRNG key: seed() supplies it
    seeded_model = handlers.seed(potentially_problematic_model, rng_seed=0)
    seeded_guide = handlers.seed(guide, rng_seed=1)

    print("=== Model Validation Report ===")

    # 1. Run the model once with distribution-argument validation enabled
    try:
        with numpyro.validation_enabled():
            validation_trace = handlers.trace(seeded_model).get_trace(x_test, y_test)
        sample_sites = [
            name for name, site in validation_trace.items() if site["type"] == "sample"
        ]
        print("✓ Model structure validation passed")
        print(f" Sample sites: {len(sample_sites)}")

    except Exception as e:
        print(f"✗ Model validation failed: {e}")
        return

    # 2. Check model shapes
    model_trace = None
    try:
        model_trace = handlers.trace(seeded_model).get_trace(x_test, y_test)
        shape_info = format_shapes(model_trace)
        print("✓ Shape analysis:")
        print(shape_info)

    except Exception as e:
        print(f"✗ Shape analysis failed: {e}")

    # 3. Validate model-guide compatibility (needs the trace from step 2)
    if model_trace is None:
        print("✗ Model-guide compatibility skipped: no model trace")
    else:
        try:
            guide_trace = handlers.trace(seeded_guide).get_trace(x_test, y_test)
            # Missing/extra sites are reported as warnings; escalate them
            with warnings.catch_warnings():
                warnings.simplefilter("error")
                check_model_guide_match(model_trace, guide_trace)
            print("✓ Model-guide compatibility verified")

        except Exception as e:
            print(f"✗ Model-guide compatibility failed: {e}")

    # 4. Test with different JAX configurations. Arrays created before the
    # switch (x_test, y_test) keep the dtype they were created with.
    original_x64 = jax.config.jax_enable_x64

    try:
        for use_x64 in [False, True]:
            numpyro.enable_x64(use_x64)
            precision = "64-bit" if use_x64 else "32-bit"

            try:
                # Quick MCMC test
                mcmc = MCMC(NUTS(potentially_problematic_model),
                            num_warmup=100, num_samples=100, num_chains=2)
                mcmc.run(random.PRNGKey(0), x_test, y_test)
                print(f"✓ {precision} MCMC test passed")

            except Exception as e:
                print(f"✗ {precision} MCMC test failed: {e}")
    finally:
        # Restore original precision even if the loop is interrupted
        numpyro.enable_x64(original_x64)

752

753

# Performance optimization example

754

def performance_optimization_example():
    """Example of performance optimization utilities.

    Compares eager execution of the model against a ``jax.jit`` compilation
    cached per input shape, and against ``numpyro.util.maybe_jit``. JAX
    dispatches work asynchronously, so each timing waits for its result with
    ``jax.block_until_ready``. Compiled variants are called once before
    timing so that compilation is not counted.
    """
    import functools
    import time

    import jax
    from numpyro import handlers
    from numpyro.util import maybe_jit

    def expensive_model(x):
        # Model with expensive computations
        weights = numpyro.sample("weights", dist.Normal(0, 1).expand((100, 50)))

        # Expensive matrix operations: (n, 100) @ (100, 50) -> (n, 50)
        transformed = x @ weights
        result = numpyro.sample("result", dist.Normal(transformed, 0.1))
        return result

    def run_model(rng_key, x):
        # Seed per call with an explicit key. A seed handler built once and
        # then traced by jit would store traced keys on the handler.
        return handlers.seed(expensive_model, rng_seed=rng_key)(x)

    # Cache one compiled function per input shape
    @functools.lru_cache(maxsize=None)
    def compile_model(x_shape):
        return jax.jit(run_model)

    # Use maybe_jit for conditional optimization
    adaptive_model = maybe_jit(run_model)

    # Test data
    key = random.PRNGKey(1)
    x = random.normal(random.PRNGKey(0), (1000, 100))

    def timed(fn):
        start_time = time.perf_counter()
        jax.block_until_ready(fn(key, x))
        return time.perf_counter() - start_time

    print("Performance comparison:")

    # Time original (eager) model
    original_time = timed(run_model)
    print(f"Original model: {original_time:.3f}s")

    # Time cached/compiled model (warm-up call triggers compilation)
    compiled_fn = compile_model(x.shape)
    jax.block_until_ready(compiled_fn(key, x))
    cached_time = timed(compiled_fn)
    print(f"Cached/compiled: {cached_time:.3f}s")

    # Time adaptive model
    jax.block_until_ready(adaptive_model(key, x))
    adaptive_time = timed(adaptive_model)
    print(f"Adaptive JIT: {adaptive_time:.3f}s")

    # Guard against a zero denominator from coarse timers
    speedup = original_time / max(min(cached_time, adaptive_time), 1e-9)
    print(f"Speedup: {speedup:.1f}x")

803

```

804

805

## Types

806

807

```python { .api }

808

from typing import Optional, Union, Callable, Dict, Any, Tuple, ContextManager

809

from jax import Array

810

import jax.numpy as jnp

811

812

from typing import Literal

# Values accepted where JAX arrays are expected (arrays and Python scalars)
ArrayLike = Union[Array, jnp.ndarray, float, int]
# Platform names accepted by set_platform. Literal is required here: a Union
# of strings would be parsed as forward references to types named cpu/gpu/tpu.
Platform = Literal["cpu", "gpu", "tpu"]
# Extra keyword options forwarded to a progress bar
ProgressBarOptions = Dict[str, Any]

815

816

class ValidationResult:
    """Result from model validation.

    Declares field annotations only; no constructor or default values are
    defined here. NOTE(review): intended as the result of
    ``validate_model``, which could not be located in NumPyro's public API.
    Confirm before relying on this type.
    """
    # Whether validation passed (printed as "Validation passed" in the usage example)
    is_valid: bool
    # Sample sites found in the model; presumably keyed by site name — confirm
    sample_sites: Dict[str, Any]
    # param sites; presumably keyed by site name — confirm
    param_sites: Dict[str, Any]
    # deterministic sites; presumably keyed by site name — confirm
    deterministic_sites: Dict[str, Any]
    # presumably human-readable warning messages — confirm
    warnings: list
    # presumably human-readable error messages — confirm
    errors: list
    # Summary of the model structure (printed as "Model structure" in the usage example)
    structure: Dict[str, Any]

825

826

class TraceInfo:
    """Information about model trace structure.

    Declares field annotations only; no constructor or default values are
    defined here. NOTE(review): not referenced by any function on this page;
    the field meanings below are assumptions to confirm.
    """
    # presumably site name -> site record from the trace
    sites: Dict[str, Any]
    # presumably site name -> value shape
    shapes: Dict[str, tuple]
    # presumably the plates active while the model ran
    plate_stack: list
    # presumably site name -> names of sites it depends on
    dependencies: Dict[str, list]

832

833

# Control flow function types

834

# Predicate for while_loop: state -> scalar boolean
CondFun = Callable[[Any], bool]
# while_loop body: state -> new state
BodyFun = Callable[[Any], Any]
# cond branches: operand -> result (both branches must return matching types)
TrueFun = Callable[[Any], Any]
FalseFun = Callable[[Any], Any]

# Loop types
LoopState = Any
LoopIndex = int
# fori_loop body: (index, state) -> new state
ForBodyFun = Callable[[LoopIndex, LoopState], LoopState]
# fori_collect body: state -> new state. It receives no loop index and does
# not return a separate item; fori_collect's `transform` extracts the value
# to collect from each state.
CollectBodyFun = Callable[[LoopState], LoopState]

# Utility types
CacheKey = Any
CacheFun = Callable[..., CacheKey]
TransformFun = Optional[Callable[[Any], Any]]
ProgressBarFun = Callable[[Callable], Callable]

# Context manager types
ConditionalContext = Union[ContextManager, None]
OptionalContext = ContextManager

# Validation types
ModelFun = Callable[..., Any]
GuideFun = Callable[..., Any]
TraceDict = Dict[str, Any]
SiteDict = Dict[str, Any]

860

```