CtrlK
BlogDocsLog inGet started
Tessl Logo

tessl/pypi-numpyro

Lightweight probabilistic programming library with NumPy backend for Pyro, powered by JAX for automatic differentiation and JIT compilation.

Pending
Overview
Eval results
Files

docs/utilities.md

Utilities

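NumPyro provides essential utility functions for JAX configuration, control flow primitives, model validation, and development helpers. These utilities support efficient probabilistic programming with hardware acceleration, memory management, and debugging. Note that only `enable_x64`, `set_platform`, and `set_host_device_count` are re-exported from the top-level `numpyro` namespace. Most of the other helpers on this page are imported from `numpyro.util` (for example `from numpyro.util import soft_vmap`).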

Capabilities

JAX Configuration

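Functions for configuring JAX behavior and hardware acceleration. `set_platform` and `set_host_device_count` only take effect if they are called at the start of a program, before JAX initializes its backends.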

def enable_x64(use_x64: bool = True) -> None:
    """
    Toggle 64-bit (double) precision as JAX's default numeric width.

    JAX defaults to 32-bit types for speed. Switching to 64-bit trades some
    performance for extra numerical accuracy.

    Args:
        use_x64: ``True`` for 64-bit defaults, ``False`` for 32-bit (default: True)

    Notes:
        The setting governs arrays created after the call; arrays that already
        exist keep their dtype, so prefer calling this near program start.
        NOTE(review): in NumPyro, passing ``False`` appears to defer to the
        ``JAX_ENABLE_X64`` environment variable rather than force 32-bit —
        confirm against the installed version.

    Usage:
        # Turn on double precision for numerically sensitive models
        numpyro.enable_x64(True)

        # Go back to 32-bit (faster, less precise)
        numpyro.enable_x64(False)

        # Inspect the current setting
        import jax
        print(f"Current precision: {jax.config.jax_enable_x64}")
    """

def set_platform(platform: Optional[str] = None) -> None:
    """
    Select the JAX platform (backend) used for computations.

    This updates JAX's ``jax_platform_name`` setting. It only takes effect
    when called at the beginning of a program, before any JAX computation
    has initialized a backend.

    Args:
        platform: One of ``'cpu'``, ``'gpu'`` or ``'tpu'``. ``None`` does NOT
            auto-detect: it uses the ``JAX_PLATFORM_NAME`` environment
            variable if set, and otherwise falls back to ``'cpu'``.

    Usage:
        # Force CPU computation
        numpyro.set_platform('cpu')

        # Use the GPU (requires a GPU-enabled jaxlib and a visible device)
        numpyro.set_platform('gpu')

        # Use JAX_PLATFORM_NAME if set, otherwise CPU
        numpyro.set_platform(None)

        # Check the active backend
        import jax
        print(f"Current platform: {jax.default_backend()}")
    """

def set_host_device_count(n: int) -> None:
    """
    Tell XLA to expose ``n`` host (CPU) devices.

    By default XLA treats all CPU cores as a single device. Exposing several
    host devices lets NumPyro run MCMC chains in parallel on CPU.

    Only CPU devices are affected. The setting reaches XLA through the
    ``XLA_FLAGS`` environment variable, so it must be called at the very start
    of a program, before JAX initializes its backends.

    Args:
        n: Number of host (CPU) devices to expose

    Usage:
        # Call before any other JAX work
        numpyro.set_host_device_count(4)

        # Then run MCMC with as many chains as devices
        mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000, num_chains=4)
        mcmc.run(rng_key, data)  # chains run in parallel across 4 CPU devices
    """

def set_rng_seed(rng_seed: Optional[int] = None) -> None:
    """
    Seed Python's and NumPy's global random number generators.

    JAX has no global random state. Randomness in NumPyro comes from explicit
    PRNG keys (``jax.random.PRNGKey``, ``numpyro.handlers.seed``,
    ``mcmc.run(rng_key, ...)``). This helper therefore does not make JAX
    sampling reproducible on its own; it only fixes host-side code that uses
    ``random`` or ``numpy.random``.

    Args:
        rng_seed: Seed value; ``None`` reseeds from system entropy

    Usage:
        # Not re-exported at the top level; import it from numpyro.util
        from numpyro.util import set_rng_seed

        set_rng_seed(42)

        # JAX-side reproducibility comes from explicit keys
        rng_key = jax.random.PRNGKey(42)
    """

Control Flow Primitives

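JAX-compatible control flow functions for probabilistic programs. They wrap `jax.lax` primitives, so their bodies should be pure array computations. To place `numpyro.sample` statements inside branches or loops, use `numpyro.contrib.control_flow` (`cond`, `scan`) instead.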

def cond(pred: ArrayLike, true_operand: Any, true_fun: Callable, 
         false_operand: Any, false_fun: Callable) -> Any:
    """
    JAX-compatible conditional execution primitive.

    Thin wrapper around ``jax.lax.cond`` (operand-per-branch form) that works
    under JIT compilation and automatic differentiation. When NumPyro's
    control-flow primitives are disabled, a plain Python ``if`` is used.

    Both branches must return values with the same structure, shapes and
    dtypes. The branches are traced by JAX, so they should be pure array
    computations. ``numpyro.sample`` statements inside them are not supported
    by this wrapper; use ``numpyro.contrib.control_flow.cond`` for that.

    Args:
        pred: Boolean scalar selecting the branch
        true_operand: Operand passed to true_fun if pred is True
        true_fun: Function to call if pred is True
        false_operand: Operand passed to false_fun if pred is False
        false_fun: Function to call if pred is False

    Returns:
        Result of the executed branch

    Usage:
        from numpyro.util import cond

        def soft_threshold(x, t):
            # Branch on a traced value without Python control flow
            return cond(jnp.abs(x) > t,
                        x, lambda v: v - jnp.sign(v) * t,
                        x, jnp.zeros_like)

        # Models with sample sites in the branches need NumPyro's own cond,
        # whose signature is cond(pred, true_fun, false_fun, operand):
        from numpyro.contrib.control_flow import cond as model_cond

        def model(x):
            def high_noise_model(x):
                return numpyro.sample("y", dist.Normal(x, 2.0))

            def low_noise_model(x):
                return numpyro.sample("y", dist.Normal(x, 0.1))

            return model_cond(x > 0.5, high_noise_model, low_noise_model, x)
    """

def while_loop(cond_fun: Callable, body_fun: Callable, init_val: Any) -> Any:
    """
    JAX-compatible while loop primitive.

    Repeatedly applies ``body_fun`` to the loop state while ``cond_fun``
    returns True, like ``jax.lax.while_loop``. The state returned by
    ``body_fun`` must keep the structure, shapes and dtypes of ``init_val``.
    The body is traced, so it should be a pure array computation. Sample
    sites cannot be created inside it, and a site name built from the traced
    loop counter would not be a concrete string.

    Args:
        cond_fun: Function that takes loop state and returns a boolean scalar
        body_fun: Function that takes loop state and returns new state
        init_val: Initial loop state

    Returns:
        Final loop state

    Usage:
        from jax import random
        from numpyro.util import while_loop

        def iterative_sampler(key, n_steps):
            def cond_fun(state):
                step, _, _ = state
                return step < n_steps

            def body_fun(state):
                step, key, samples = state
                key, subkey = random.split(key)
                # Plain JAX draw using the fresh subkey
                new_sample = random.normal(subkey)
                return step + 1, key, samples.at[step].set(new_sample)

            init_samples = jnp.zeros(n_steps)  # n_steps must be a Python int
            _, _, final_samples = while_loop(
                cond_fun, body_fun, (0, key, init_samples)
            )
            return final_samples
    """

def fori_loop(lower: int, upper: int, body_fun: Callable, init_val: Any) -> Any:
    """
    JAX-compatible for loop primitive.

    Calls ``body_fun(i, state)`` for ``i`` from ``lower`` to ``upper - 1``,
    like ``jax.lax.fori_loop``. The returned state must keep the structure,
    shapes and dtypes of ``init_val``.

    Args:
        lower: Starting index (inclusive)
        upper: Ending index (exclusive)
        body_fun: Function that takes (index, state) and returns new state
        init_val: Initial loop state

    Returns:
        Final loop state

    Usage:
        from jax import random
        from numpyro.util import fori_loop

        def accumulate_samples(key, n_samples):
            def body_fun(i, state):
                key, total = state
                key, subkey = random.split(key)
                return key, total + random.normal(subkey)

            # A typed zero keeps the carried dtype stable across iterations
            _, final_total = fori_loop(0, n_samples, body_fun, (key, jnp.zeros(())))
            return final_total / n_samples
    """

Memory-Efficient Utilities

Functions for managing memory usage in large-scale computations.

def soft_vmap(fn: Callable, xs: ArrayLike, batch_ndims: int = 1, 
             chunk_size: Optional[int] = None) -> ArrayLike:
    """
    Memory-bounded vectorized map over the leading batch dimension(s) of ``xs``.

    ``fn`` is written for a single element: it receives one slice of ``xs``
    with the ``batch_ndims`` leading dimensions removed. Internally the batch
    is split into chunks of ``chunk_size`` elements, and ``jax.vmap(fn)`` runs
    on one chunk at a time. Peak memory therefore scales with ``chunk_size``
    rather than with the full batch.

    Args:
        fn: Per-element function to map
        xs: Input array(s) sharing the same leading batch shape
        batch_ndims: Number of leading dimensions of ``xs`` to map over
        chunk_size: Elements per chunk (None processes the whole batch at once)

    Returns:
        Per-element results with the batch dimensions of ``xs`` as leading dims

    Usage:
        from numpyro.util import soft_vmap

        weight_matrix = jnp.ones((1000, 64))

        def project(row):
            # row has shape (1000,): a single element of the batch
            return row @ weight_matrix  # shape (64,)

        large_data = jnp.ones((10000, 1000))

        # Only 100 rows are vectorized at a time
        results = soft_vmap(project, large_data, chunk_size=100)
        # results.shape == (10000, 64)
    """

def fori_collect(lower: int, upper: int, body_fun: Callable, init_val: Any,
                transform: Optional[Callable] = None, progbar: bool = True,
                return_last_val: bool = False, collection_size: Optional[int] = None,
                thinning: int = 1, **progbar_opts) -> Union[tuple, ArrayLike]:
    """
    Run a loop ``upper`` times and collect (transformed) intermediate states.

    Works like ``jax.lax.fori_loop``, with two differences. ``body_fun`` maps
    the state to the next state and does not receive the loop index. After
    each iteration, ``transform(state)`` is stored. The first ``lower``
    iterations are run but not collected (e.g. warmup). A progress bar is
    optional; ``progbar=False`` can be faster when collecting many values.

    Args:
        lower: Number of initial iterations to run without collecting
        upper: Total number of iterations to run
        body_fun: ``state -> new_state``; the new state must keep the
            structure, shapes and dtypes of ``init_val``
        init_val: Initial state (any pytree of arrays)
        transform: Applied to each state before it is collected; omit it to
            collect the states unchanged
        progbar: Whether to show a progress bar
        return_last_val: If True, also return the final state
        collection_size: Size of the returned collection; defaults to
            ``(upper - lower) // thinning``
        thinning: Keep every ``thinning``-th collected value (default: 1)
        **progbar_opts: Additional progress bar options

    Returns:
        The collected values stacked along a new leading axis, or
        ``(collection, last_state)`` when ``return_last_val=True``

    Usage:
        from jax import random
        from numpyro.util import fori_collect

        # A toy random walk; the state is (key, position)
        def step(state):
            key, position = state
            key, subkey = random.split(key)
            return key, position + random.normal(subkey)

        init_state = (random.PRNGKey(0), jnp.zeros(()))

        # Run 1000 steps, skip the first 100, collect the remaining positions
        positions = fori_collect(100, 1000, step, init_state,
                                 transform=lambda state: state[1])
        # positions.shape == (900,)
    """

Model Validation and Debugging

Utilities for validating models and debugging probabilistic programs.

def format_shapes(trace: dict, last_site: Optional[str] = None) -> str:
    """
    Format the shapes of the sites in a model trace for debugging.

    Produces a readable summary of the shapes of each site in ``trace``,
    which helps diagnose broadcasting and plate issues.

    Args:
        trace: Trace dict, e.g. from ``trace(seed(model, 0)).get_trace(...)``
        last_site: Name of last site to include (None for all sites)

    Returns:
        Formatted multi-line string

    Usage:
        from numpyro.handlers import seed, trace
        from numpyro.util import format_shapes

        def model():
            with numpyro.plate("batch", 10):                # allocated dim -1
                x = numpyro.sample("x", dist.Normal(0, 1))   # batch shape (10,)
                with numpyro.plate("features", 5):          # allocated dim -2
                    # x broadcasts against the plates: batch shape (5, 10)
                    y = numpyro.sample("y", dist.Normal(x, 1))

        # seed() supplies the PRNG keys the sample sites need. get_trace()
        # returns the trace dict; calling the traced model directly would
        # return the model's own return value instead.
        model_trace = trace(seed(model, rng_seed=0)).get_trace()
        print(format_shapes(model_trace))
    """

def check_model_guide_match(model_trace: dict, guide_trace: dict) -> None:
    """
    Check that a model trace and a guide trace have compatible structure.

    Verifies that the sample sites of the model and the guide line up, so
    that the guide provides a distribution for every latent model site as
    SVI requires, and that shared sites agree in shape.

    Args:
        model_trace: Trace from model execution
        guide_trace: Trace from guide execution

    Raises:
        ValueError: If model and guide are incompatible.
            NOTE(review): some mismatches (e.g. a site missing from the
            guide) may only be reported as warnings — confirm against the
            installed NumPyro version.

    Usage:
        from numpyro.handlers import seed, trace
        from numpyro.util import check_model_guide_match

        # seed() supplies the PRNG keys that the sample sites need
        model_trace = trace(seed(model, rng_seed=0)).get_trace(data)
        guide_trace = trace(seed(guide, rng_seed=1)).get_trace(data)

        try:
            check_model_guide_match(model_trace, guide_trace)
            print("✓ Model and guide are compatible")
        except ValueError as e:
            print(f"✗ Compatibility error: {e}")
    """

def validate_model(model: Callable, *model_args, **model_kwargs) -> dict:
    """
    Comprehensive model validation and structure analysis.

    NOTE(review): ``validate_model`` does not appear among the helpers that
    NumPyro exports from ``numpyro`` or ``numpyro.util``. Confirm that it
    exists in your installed version before relying on it. The trace-based
    tools ``format_shapes`` and ``check_model_guide_match`` cover similar
    checks.

    Args:
        model: Model function to validate
        *model_args: Arguments to pass to model
        **model_kwargs: Keyword arguments to pass to model

    Returns:
        Dictionary of validation results and model information; the example
        below reads the keys 'sample_sites', 'structure' and 'is_valid'

    Usage:
        def my_model(x, y=None):
            alpha = numpyro.sample("alpha", dist.Normal(0, 1))
            with numpyro.plate("data", len(x)):
                numpyro.sample("y", dist.Normal(alpha + x, 1), obs=y)

        x_data = jnp.linspace(0, 1, 100)
        validation = numpyro.validate_model(my_model, x_data)

        print(f"Number of sample sites: {len(validation['sample_sites'])}")
        print(f"Model structure: {validation['structure']}")
        print(f"Validation passed: {validation['is_valid']}")
    """

Development and Performance Utilities

Helper functions for development and performance optimization.

def maybe_jit(fn: Callable, *args, **kwargs) -> Callable:
    """
    JIT-compile ``fn`` unless NumPyro's control-flow primitives are disabled.

    Normally this returns ``jax.jit(fn, *args, **kwargs)``. When control-flow
    primitives are disabled (useful for debugging with plain Python
    execution), ``fn`` is returned unchanged. It does not analyse ``fn`` to
    decide whether compilation would be beneficial.

    Args:
        fn: Function to compile
        *args: Positional options forwarded to ``jax.jit`` (not call arguments)
        **kwargs: Keyword options forwarded to ``jax.jit`` (e.g. ``static_argnums``)

    Returns:
        The JIT-compiled function, or ``fn`` itself when control flow is disabled

    Usage:
        from numpyro.util import maybe_jit

        def expensive_computation(x):
            return jnp.sum(x ** 2)

        optimized_fn = maybe_jit(expensive_computation)
        result = optimized_fn(large_array)
    """

def progress_bar_factory(num_samples: int, num_chains: int = 1) -> Callable:
    """
    Build a decorator that adds a progress bar to a ``fori_loop`` body.

    The returned decorator wraps a loop body with signature ``(i, state)``
    and reports progress while the loop runs.

    Args:
        num_samples: Total number of loop iterations
        num_chains: Number of parallel chains reporting progress.
            NOTE(review): with more than one chain the decorated body is
            presumably mapped across chains — confirm before using >1.

    Returns:
        Decorator to apply to a ``fori_loop`` body function

    Usage:
        from jax import lax
        from numpyro.util import progress_bar_factory

        progress_bar = progress_bar_factory(1000, num_chains=1)

        @progress_bar
        def sampling_step(i, state):
            # Custom sampling logic
            return new_state

        # Progress is reported as the loop runs
        final_state = lax.fori_loop(0, 1000, sampling_step, init_state)
    """

def cached_by(outer_fn: Callable, *keys) -> Callable:
    """
    Reuse a previously created function object for identical cache keys.

    ``cached_by(outer_fn, *keys)`` returns a decorator, and the cache is
    stored on ``outer_fn``. It is keyed by the *values* in ``keys``, not by a
    key function:
      - The first function decorated with a given key is stored and returned.
      - Later decorations with equal keys return the stored function instead
        of the new one.

    What is cached is the function object itself, not its results. The point
    is to keep function identity stable, so ``jax.jit`` can reuse compiled
    code for closures that are re-created on every call.

    NOTE(review): NumPyro bounds this cache to a small number of entries —
    confirm the limit for the installed version.

    Args:
        outer_fn: Object that holds the cache (typically the enclosing function)
        *keys: Values identifying the closure (e.g. the callables it captures)

    Returns:
        Decorator that returns the cached function for ``keys``

    Usage:
        from jax import jit
        from numpyro.util import cached_by

        def run(body_fn, init):
            # Without caching, `step` would be a new object on every call
            # and jit would retrace it each time.
            @cached_by(run, body_fn)
            def step(x):
                return body_fn(x)

            return jit(step)(init)
    """

def identity(x: Any, *args, **kwargs) -> Any:
    """
    Pass ``x`` through untouched, discarding any extra arguments.

    Convenient as a no-op stand-in wherever a callable is expected, such as
    a default transform.

    Args:
        x: Value to pass through
        *args: Extra positional arguments (ignored)
        **kwargs: Extra keyword arguments (ignored)

    Returns:
        ``x`` unchanged
    """

def not_jax_tracer(x: Any) -> bool:
    """
    Return True when ``x`` is a concrete value rather than a JAX tracer.

    Under ``jit``, ``vmap``, ``grad`` and similar transformations, array
    arguments are replaced by abstract tracers whose values are unknown at
    trace time. This check lets code take Python-level branches (printing,
    validation) only when values are concrete.

    Args:
        x: Value to check

    Returns:
        True if x is not a JAX tracer, False otherwise

    Usage:
        from numpyro.util import not_jax_tracer

        def conditional_computation(x):
            if not_jax_tracer(x):
                # Only runs with concrete values (e.g. outside jit)
                print(f"Concrete value: {x}")
            return x ** 2
    """

def is_prng_key(key: Any) -> bool:
    """
    Check whether ``key`` is a valid JAX PRNG key.

    Args:
        key: Candidate PRNG key

    Returns:
        True if key is a valid PRNG key, False otherwise

    Usage:
        from jax import random
        from numpyro.util import is_prng_key

        key = random.PRNGKey(0)
        if is_prng_key(key):
            subkey = random.split(key)[0]
        else:
            raise ValueError("Invalid PRNG key")
    """

Context Managers and Control

Utilities for context management and execution control.

def optional(condition: bool, context_manager: Any) -> Any:
    """
    Enter ``context_manager`` only when ``condition`` is True.

    Args:
        condition: Whether to apply the context manager
        context_manager: Context manager to apply if condition is True

    Returns:
        A context manager that enters ``context_manager`` if ``condition``
        is True and otherwise does nothing

    Usage:
        from numpyro.util import optional

        use_validation = True

        with optional(use_validation, numpyro.validation_enabled()):
            result = model()  # Validation applied only if use_validation=True
    """

def control_flow_prims_disabled() -> bool:
    """
    Report whether NumPyro's control-flow primitives are disabled.

    When they are disabled, helpers such as ``cond``, ``while_loop`` and
    ``fori_loop`` run as plain Python control flow, which helps debugging.
    NOTE(review): NumPyro appears to read this switch from the
    ``NUMPYRO_DISABLE_CONTROL_FLOW_PRIM`` environment variable — confirm
    for the installed version.

    Returns:
        True if control flow primitives (cond, while_loop) are disabled

    Usage:
        from numpyro.util import cond, control_flow_prims_disabled

        if control_flow_prims_disabled():
            # Use alternative implementation without control flow
            result = alternative_implementation()
        else:
            result = cond(pred, true_op, true_fn, false_op, false_fn)
    """

def nested_attrgetter(*collect_fields: str) -> Callable:
    """
    Build a getter that extracts several, possibly nested, attributes.

    Each field is a dot-separated attribute path such as
    ``"adapt_state.step_size"``, resolved with successive ``getattr`` calls.
    It therefore works on objects with attributes (e.g. namedtuple states),
    not on dictionary keys such as the entries of ``svi_result.params``.

    Args:
        *collect_fields: Dot-separated attribute paths to extract

    Returns:
        Function that takes an object and returns the value at each path,
        in the order given

    Usage:
        from numpyro.util import nested_attrgetter

        getter = nested_attrgetter("z", "adapt_state.step_size")

        # Apply to the sampler state left by MCMC (e.g. NUTS)
        mcmc.run(rng_key, data)
        z, step_size = getter(mcmc.last_state)
    """

def find_stack_level() -> int:
    """
    Compute a ``stacklevel`` value for ``warnings.warn``.

    In deep call chains, this lets a warning be attributed to a meaningful
    caller frame rather than to the internal helper that raised it.

    Returns:
        Stack level suitable for ``warnings.warn(..., stacklevel=...)``
    """

Usage Examples

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
import jax.numpy as jnp
from jax import random

# JAX configuration for optimal performance
def setup_jax_environment(platform='gpu', num_host_devices=4, rng_seed=42):
    """Configure JAX for NumPyro.

    Call this at the very start of a program. The platform and host-device
    settings only take effect before JAX initializes its backends.

    Args:
        platform: Backend to select ('cpu', 'gpu' or 'tpu'). 'gpu' requires a
            GPU-enabled jaxlib and a visible device. There is no automatic
            fallback to CPU when the requested backend is unavailable.
        num_host_devices: Number of CPU devices to expose for parallel chains.
            This only affects the CPU backend.
        rng_seed: Seed for Python's and NumPy's global RNGs. JAX randomness
            still comes from explicit PRNG keys.
    """
    # `jax` is used for the report below but is not imported at module level.
    import jax
    # set_rng_seed is not re-exported from the top-level numpyro namespace.
    from numpyro.util import set_rng_seed

    # Enable 64-bit precision for numerical stability
    numpyro.enable_x64(True)

    # Select the backend (must happen before JAX initializes it)
    numpyro.set_platform(platform)

    # Expose multiple CPU devices for parallel chains (CPU backend only)
    numpyro.set_host_device_count(num_host_devices)

    # Seed host-side (Python/NumPy) randomness for reproducibility
    set_rng_seed(rng_seed)

    print(f"JAX platform: {jax.default_backend()}")
    print(f"JAX devices: {jax.device_count()}")
    print(f"64-bit enabled: {jax.config.jax_enable_x64}")

# Control flow in probabilistic models
def control_flow_example():
    """Example of NumPyro-compatible control flow.

    Returns:
        (adaptive_model, iterative_sampler):
          - adaptive_model: a model whose likelihood branches on ``x``.
          - iterative_sampler: adds standard-normal draws to a running total
            until its absolute value reaches ``threshold``.
    """
    # Sample sites inside traced branches need NumPyro's own cond, with the
    # signature cond(pred, true_fun, false_fun, operand). Pure array loops
    # can use the numpyro.util wrapper around lax.while_loop.
    from numpyro.contrib.control_flow import cond
    from numpyro.util import while_loop

    def adaptive_model(x):
        # Draw the shared latent outside the branches so that both branches
        # create the same set of sample sites ("y" only).
        hidden = numpyro.sample("hidden", dist.Normal(0, 1))

        def simple_model(operand):
            x, _ = operand
            return numpyro.sample("y", dist.Normal(x, 0.1))

        def complex_model(operand):
            x, hidden = operand
            return numpyro.sample("y", dist.Normal(x + hidden, 0.5))

        # Model switches behavior based on input
        is_complex = x > 0.5
        return cond(is_complex, complex_model, simple_model, (x, hidden))

    # Iterative sampling with while loop
    def iterative_sampler(key, threshold=1.0):
        def cond_fun(state):
            _, _, total = state
            return jnp.abs(total) < threshold

        def body_fun(state):
            step, key, total = state
            key, subkey = random.split(key)
            # Plain JAX draw. A sample site cannot be created here, because
            # its name would depend on the traced loop counter `step`.
            new_sample = random.normal(subkey)
            return step + 1, key, total + new_sample

        # A typed zero keeps the carried dtype stable across iterations
        _, _, final_total = while_loop(cond_fun, body_fun, (0, key, jnp.zeros(())))
        return final_total

    return adaptive_model, iterative_sampler

# Memory-efficient processing
def large_scale_example(n_data=100000, chunk_size=1000):
    """Example of memory-efficient utilities for large datasets.

    Args:
        n_data: Number of rows in the simulated dataset.
        chunk_size: Rows vectorized at once by soft_vmap, and rows reduced
            per step of the progressive computation.

    Returns:
        (final_sum, chunk_sums):
          - final_sum: total of the per-row results covered by complete chunks.
          - chunk_sums: per-chunk sums collected at each step.
    """
    # These helpers are not re-exported from the top-level numpyro namespace.
    from numpyro.util import fori_collect, soft_vmap

    # Simulate large dataset
    x_large = random.normal(random.PRNGKey(0), (n_data, 50))

    def expensive_transform(x_row):
        # soft_vmap maps over the leading axis, so each call sees a single
        # row of shape (50,). Reduce over its only axis.
        return jnp.sum(x_row ** 2)

    # Process in chunks to avoid memory issues
    results = soft_vmap(expensive_transform, x_large, chunk_size=chunk_size)

    print(f"Processed {n_data} samples in chunks")
    print(f"Result shape: {results.shape}")

    # Collect results with progress tracking
    def progressive_computation():
        num_chunks = n_data // chunk_size
        # Reshape up front: slicing with a traced loop index is not allowed
        # inside the compiled loop, but indexing the leading axis is.
        chunked = results[: num_chunks * chunk_size].reshape(num_chunks, chunk_size)

        def compute_step(state):
            # fori_collect bodies map state -> state and receive no loop
            # index, so the chunk index is carried in the state.
            i, current_sum, _ = state
            new_value = jnp.sum(chunked[i])
            return i + 1, current_sum + new_value, new_value

        init_state = (jnp.array(0), jnp.zeros(()), jnp.zeros(()))

        # Collect the per-chunk sum after each step. With
        # return_last_val=True the result is (collection, last_state).
        chunk_sums, last_state = fori_collect(
            0, num_chunks,
            compute_step,
            init_state,
            transform=lambda state: state[2],
            progbar=True,
            return_last_val=True,
        )

        return last_state[1], chunk_sums

    return progressive_computation()

# Model validation workflow
def validation_workflow_example():
    """Model validation walkthrough.

    Steps: trace the model, inspect site shapes, check model/guide
    compatibility, then run a short MCMC at 32-bit and 64-bit precision.
    The caller's precision setting is restored afterwards, even on error.
    """
    # Names not imported at module level in this example file.
    import jax
    from numpyro.distributions import constraints
    from numpyro.handlers import seed, trace
    from numpyro.util import check_model_guide_match, format_shapes

    def potentially_problematic_model(x, y=None):
        # Model with potential issues
        alpha = numpyro.sample("alpha", dist.Normal(0, 1))
        beta = numpyro.sample("beta", dist.Normal(0, 1))

        # Potential broadcasting issue
        with numpyro.plate("data", len(x)):
            mu = alpha + beta * x  # Check shapes here
            numpyro.sample("y", dist.Normal(mu, 1), obs=y)

    def guide(x, y=None):
        # Variational guide
        alpha_loc = numpyro.param("alpha_loc", 0.0)
        alpha_scale = numpyro.param("alpha_scale", 1.0, constraint=constraints.positive)
        beta_loc = numpyro.param("beta_loc", 0.0)
        beta_scale = numpyro.param("beta_scale", 1.0, constraint=constraints.positive)

        numpyro.sample("alpha", dist.Normal(alpha_loc, alpha_scale))
        numpyro.sample("beta", dist.Normal(beta_loc, beta_scale))

    # Generate test data
    x_test = jnp.linspace(0, 1, 100)
    y_test = 1.5 + 2.0 * x_test + 0.1 * random.normal(random.PRNGKey(0), (100,))

    print("=== Model Validation Report ===")

    # 1. Trace the model to validate its structure. seed() supplies the PRNG
    #    keys the sample sites need; without it tracing fails immediately.
    try:
        model_trace = trace(seed(potentially_problematic_model, rng_seed=0)).get_trace(
            x_test, y_test
        )
    except Exception as e:
        print(f"✗ Model validation failed: {e}")
        return

    sample_sites = [name for name, site in model_trace.items() if site["type"] == "sample"]
    print("✓ Model structure validation passed")
    print(f"  Sample sites: {len(sample_sites)}")

    # 2. Check model shapes
    try:
        shape_info = format_shapes(model_trace)
        print("✓ Shape analysis:")
        print(shape_info)
    except Exception as e:
        print(f"✗ Shape analysis failed: {e}")

    # 3. Validate model-guide compatibility
    try:
        guide_trace = trace(seed(guide, rng_seed=1)).get_trace(x_test, y_test)
        check_model_guide_match(model_trace, guide_trace)
        print("✓ Model-guide compatibility verified")
    except Exception as e:
        print(f"✗ Model-guide compatibility failed: {e}")

    # 4. Test with different JAX configurations
    original_x64 = jax.config.jax_enable_x64
    try:
        for use_x64 in [False, True]:
            numpyro.enable_x64(use_x64)
            precision = "64-bit" if use_x64 else "32-bit"

            try:
                # Quick MCMC test
                mcmc = MCMC(NUTS(potentially_problematic_model),
                            num_warmup=100, num_samples=100, num_chains=2)
                mcmc.run(random.PRNGKey(0), x_test, y_test)
                print(f"✓ {precision} MCMC test passed")
            except Exception as e:
                print(f"✗ {precision} MCMC test failed: {e}")
    finally:
        # Restore original precision even if something above raised
        numpyro.enable_x64(original_x64)

# Performance optimization example
def performance_optimization_example():
    """Compare eager, cached-and-compiled, and maybe_jit-compiled model runs.

    Compiled variants are warmed up first, so compilation time is excluded.
    Each measurement waits for JAX's asynchronous dispatch to finish.
    """
    # Names not imported at module level in this example file.
    import functools
    import time

    from jax import jit
    from numpyro.handlers import seed
    from numpyro.util import maybe_jit

    def expensive_model(x):
        # Model with expensive computations
        weights = numpyro.sample("weights", dist.Normal(0, 1).expand((100, 50)))

        # (n, 100) @ (100, 50) -> (n, 50)
        transformed = x @ weights
        result = numpyro.sample("result", dist.Normal(transformed, 0.1))
        return result

    def run_model(x, rng_key):
        # Outside an inference algorithm, sample sites need a seed handler
        # to obtain PRNG keys.
        return seed(expensive_model, rng_seed=rng_key)(x)

    # One compiled function per input shape. functools.lru_cache keys on the
    # call arguments. numpyro.util.cached_by instead reuses function objects
    # under explicitly supplied keys, which is not what is needed here.
    @functools.lru_cache(maxsize=8)
    def compile_model(x_shape):
        return jit(run_model)

    # Use maybe_jit for conditional optimization (jit unless control-flow
    # primitives are disabled)
    adaptive_model = maybe_jit(run_model)

    # Test data
    x = random.normal(random.PRNGKey(0), (1000, 100))
    rng_key = random.PRNGKey(1)

    def timed(fn):
        start = time.perf_counter()
        # JAX dispatches asynchronously; wait for the result before stopping
        # the clock.
        fn(x, rng_key).block_until_ready()
        return time.perf_counter() - start

    print("Performance comparison:")

    # Time original (eager) model
    original_time = timed(run_model)
    print(f"Original model: {original_time:.3f}s")

    # Time cached/compiled model (warm-up call excludes compile time)
    compiled_fn = compile_model(x.shape)
    compiled_fn(x, rng_key).block_until_ready()
    cached_time = timed(compiled_fn)
    print(f"Cached/compiled: {cached_time:.3f}s")

    # Time adaptive model (warm-up call excludes compile time)
    adaptive_model(x, rng_key).block_until_ready()
    adaptive_time = timed(adaptive_model)
    print(f"Adaptive JIT: {adaptive_time:.3f}s")

    # Guard against a zero-duration measurement
    speedup = original_time / max(min(cached_time, adaptive_time), 1e-9)
    print(f"Speedup: {speedup:.1f}x")

Types

from typing import Optional, Union, Callable, Dict, Any, Tuple, ContextManager, Literal
from jax import Array
import jax.numpy as jnp

# Values JAX accepts where an array is expected
ArrayLike = Union[Array, jnp.ndarray, float, int]
# Backend names accepted by set_platform. Literal, not Union: these are string
# values, and Union would turn them into unresolvable forward references.
Platform = Literal["cpu", "gpu", "tpu"]
# Extra keyword options forwarded to a progress bar
ProgressBarOptions = Dict[str, Any]

class ValidationResult:
    """Outcome of validating a model (field annotations only, no constructor)."""

    # overall pass/fail flag
    is_valid: bool
    # per-site info for sample sites (presumably keyed by site name)
    sample_sites: Dict[str, Any]
    # per-site info for param sites (presumably keyed by site name)
    param_sites: Dict[str, Any] 
    # per-site info for deterministic sites (presumably keyed by site name)
    deterministic_sites: Dict[str, Any]
    # non-fatal issues found during validation
    warnings: list
    # fatal issues found during validation
    errors: list
    # summary of the model's structure
    structure: Dict[str, Any]

class TraceInfo:
    """Summary of a model trace's structure (field annotations only)."""

    # raw site records
    sites: Dict[str, Any]
    # shape per site
    shapes: Dict[str, tuple]
    # plates active during the trace
    plate_stack: list
    # per-site dependency lists
    dependencies: Dict[str, list]

# Control flow function types
CondFun = Callable[[Any], bool]      # loop state -> keep looping?
BodyFun = Callable[[Any], Any]       # loop state -> new loop state
TrueFun = Callable[[Any], Any]       # operand -> branch result
FalseFun = Callable[[Any], Any]      # operand -> branch result

# Loop types
LoopState = Any
LoopIndex = int
ForBodyFun = Callable[[LoopIndex, LoopState], LoopState]  # (i, state) -> new state
# fori_collect bodies map state -> state and receive no loop index. The
# collected value is transform(state), not a second return value.
CollectBodyFun = Callable[[LoopState], LoopState]

# Utility types
CacheKey = Any
CacheFun = Callable[..., CacheKey]
TransformFun = Optional[Callable[[Any], Any]]
ProgressBarFun = Callable[[Callable], Callable]

# Context manager types
ConditionalContext = Optional[ContextManager]
OptionalContext = ContextManager

# Validation types
ModelFun = Callable[..., Any]
GuideFun = Callable[..., Any]
TraceDict = Dict[str, Any]
SiteDict = Dict[str, Any]

Install with Tessl CLI

npx tessl i tessl/pypi-numpyro

docs

diagnostics.md

distributions.md

handlers.md

index.md

inference.md

optimization.md

primitives.md

utilities.md

tile.json