PyMC Variational Inference

PyMC provides comprehensive variational inference methods for fast approximate Bayesian inference. Variational methods are particularly useful for large datasets and complex models where MCMC sampling may be computationally prohibitive.
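
A minimal end-to-end sketch (with hypothetical synthetic data) showing the typical workflow: define a model, fit an approximation, and draw posterior samples:

import numpy as np
import pymc as pm

# Hypothetical data for illustration
rng = np.random.default_rng(0)
data = rng.normal(loc=1.0, scale=2.0, size=500)

with pm.Model() as intro_model:
    mu = pm.Normal('mu', 0, 10)
    sigma = pm.HalfNormal('sigma', 5)
    pm.Normal('obs', mu=mu, sigma=sigma, observed=data)

    approx = pm.fit(n=20000)       # mean-field ADVI by default
    idata = approx.sample(1000)    # posterior samples (InferenceData)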

Main Variational Interface

Primary Fitting Function

import pymc as pm

def fit(n=10000, method='advi', model=None, random_seed=None,
        start=None, inf_kwargs=None, **kwargs):
    """
    Fit a variational approximation to the posterior.
    
    Parameters:
    - n (int): Number of optimization iterations (default: 10000)
    - method (str): Inference method ('advi', 'fullrank_advi', 'svgd', 'asvgd')
    - model: PyMC model (default: current context model)
    - random_seed (int): Random seed for reproducibility
    - start (dict): Starting parameter values
    - inf_kwargs (dict): Keyword arguments passed to the Inference constructor
    - **kwargs: Passed on to Inference.fit (e.g. obj_optimizer, callbacks, progressbar)
    
    Returns:
    - approximation: Fitted variational approximation object
    """

# Basic variational inference
with pm.Model() as model:
    # Define model...
    approx = pm.fit(n=50000)

# Advanced configuration
approx = pm.fit(
    n=100000,
    method='fullrank_advi',
    obj_optimizer=pm.adam(learning_rate=0.01),
    callbacks=[pm.callbacks.CheckParametersConvergence()],
    progressbar=True
)

Automatic Differentiation Variational Inference (ADVI)

Mean-Field ADVI

The default variational inference method using mean-field approximation:

from pymc.variational import ADVI

class ADVI:
    """
    Automatic Differentiation Variational Inference with mean-field approximation.
    
    Parameters:
    - model: PyMC model
    - random_seed (int): Random seed
    - start (dict): Initial parameter values
    
    Methods:
    - fit: Optimize variational parameters
    - sample: Draw samples from approximation
    """
    
    def __init__(self, model=None, random_seed=None, start=None):
        pass
    
    def fit(self, n=10000, score=None, callbacks=None, progressbar=True, **kwargs):
        """
        Fit the variational approximation.
        
        Parameters:
        - n (int): Number of optimization steps
        - score (bool): Evaluate and record the loss at each iteration
        - callbacks (list): Callback functions
        - progressbar (bool): Show progress bar
        - **kwargs: e.g. obj_optimizer to choose the optimization algorithm
        
        Returns:
        - approximation: Fitted approximation
        """
        pass

# Explicit ADVI usage
with pm.Model() as model:
    # Model definition...
    
    # Create ADVI inference object
    inference = pm.ADVI()
    
    # Fit approximation
    approx = inference.fit(n=50000, obj_optimizer=pm.adam(learning_rate=0.01))
    
    # Draw samples from approximation
    trace = approx.sample(2000)

Full-Rank ADVI

ADVI with full covariance structure:

from pymc.variational import FullRankADVI

class FullRankADVI:
    """
    Full-rank ADVI with correlated posterior approximation.
    
    Parameters:
    - model: PyMC model
    - random_seed (int): Random seed
    """

# Full-rank approximation for capturing correlations
with pm.Model() as model:
    # Model with correlated parameters...
    
    inference = pm.FullRankADVI()
    approx = inference.fit(n=75000)
    
    # Full covariance matrix available
    cov_matrix = approx.cov.eval()
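
A quick way to check whether the extra covariance structure is worth the added variational parameters is to compare the final tracked losses of the two families. A rough heuristic sketch, assuming the same model and iteration budget:

# Lower final loss (negative ELBO) indicates a better fit
mf_approx = pm.fit(n=30000, method='advi', model=correlated_model)
fr_approx = pm.fit(n=30000, method='fullrank_advi', model=correlated_model)
print('mean-field loss:', mf_approx.hist[-1])
print('full-rank loss: ', fr_approx.hist[-1])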

Stein Variational Gradient Descent

Standard SVGD

Particle-based variational inference:

from pymc.variational import SVGD

class SVGD:
    """
    Stein Variational Gradient Descent.
    
    Parameters:
    - n_particles (int): Number of particles (default: 100)
    - jitter (float): Noise sd used to initialize the particles (default: 1)
    - model: PyMC model
    """
    
    def __init__(self, n_particles=100, jitter=1, model=None):
        pass

# SVGD for complex posteriors
with pm.Model() as complex_model:
    # Complex model definition...
    
    inference = pm.SVGD(n_particles=200)
    approx = inference.fit(n=20000)
    
    # Particles represent the posterior
    particles = approx.sample(1000)
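
Because the particle ensemble is not constrained to a parametric family, SVGD can retain multiple posterior modes that mean-field ADVI would collapse into one. A hedged sketch with a hypothetical bimodal posterior (here `inf_kwargs` passes `n_particles` through to the SVGD constructor):

import numpy as np

# Hypothetical bimodal data: two well-separated clusters
rng = np.random.default_rng(1)
data = np.concatenate([rng.normal(-2, 0.5, 100), rng.normal(2, 0.5, 100)])

with pm.Model() as bimodal_model:
    mu = pm.Normal('mu', 0, 5)
    # Mixture symmetric in mu, so the posterior over mu has two modes
    pm.NormalMixture('obs', w=[0.5, 0.5],
                     mu=pm.math.stack([mu, -mu]), sigma=0.5,
                     observed=data)

    approx = pm.fit(n=20000, method='svgd',
                    inf_kwargs={'n_particles': 200})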

Amortized SVGD

from pymc.variational import ASVGD

class ASVGD:
    """
    Amortized Stein Variational Gradient Descent.
    
    Parameters:
    - approx: Parametric approximation to train (default: FullRank)
    - kernel: Kernel for the Stein operator
    """

# ASVGD trains a parametric approximation with SVGD-style updates;
# mini-batching comes from the model itself (pm.Minibatch / total_size),
# not from the inference object
with pm.Model() as large_model:
    # Model with mini-batched data...
    
    inference = pm.ASVGD()
    approx = inference.fit(n=30000)

Variational Approximations

Mean-Field Approximation

Independent normal distributions for each parameter:

from pymc.variational.approximations import MeanField

class MeanField:
    """
    Mean-field approximation with independent normal distributions.
    
    Parameters:
    - model: PyMC model
    
    Methods:
    - sample: Draw samples from the approximation
    """
    
    def sample(self, draws=500, return_inferencedata=True, **kwargs):
        """
        Sample from the mean-field approximation.
        
        Parameters:
        - draws (int): Number of samples to draw
        - return_inferencedata (bool): Return ArviZ InferenceData (default: True)
        
        Returns:
        - samples: Posterior samples from the approximation
        """
        pass

# Access approximation directly
with pm.Model() as model:
    # Model definition...
    
    # Create mean-field approximation
    mean_field = pm.MeanField()
    
    # Fit using KL divergence minimization
    approx = pm.KLqp(mean_field).fit(n=50000)
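
The fitted mean-field parameters live in the flattened, transformed parameter space. A small sketch of inspecting them on the `approx` object fitted above:

# Inspect the fitted variational parameters (flattened, transformed space)
posterior_means = approx.mean.eval()
posterior_stds = approx.std.eval()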

Full-Rank Approximation

Multivariate normal with full covariance:

from pymc.variational.approximations import FullRank

class FullRank:
    """
    Full-rank multivariate normal approximation.
    
    Parameters:
    - model: PyMC model
    
    Attributes:
    - cov: Covariance matrix
    - mean: Mean vector
    """

# Full-rank for capturing parameter correlations
with pm.Model() as correlated_model:
    # Model with strong parameter correlations...
    
    full_rank = pm.FullRank()
    approx = pm.KLqp(full_rank).fit(n=75000)
    
    # Access covariance structure
    import numpy as np
    posterior_cov = approx.cov.eval()
    sd = np.sqrt(np.diag(posterior_cov))
    posterior_corr = posterior_cov / np.outer(sd, sd)

Empirical Approximation

Empirical distribution from particle samples:

from pymc.variational.approximations import Empirical

class Empirical:
    """
    Empirical approximation built from particle samples.
    
    Parameters:
    - trace: MultiTrace of samples used to build the approximation
    - size (int): Number of particles when initialized randomly
    """

# Empirical approximation from existing MCMC samples
# (SVGD uses an Empirical approximation internally)
with pm.Model() as model:
    # Model definition...
    
    trace = pm.sample(1000, return_inferencedata=False)
    empirical = pm.Empirical(trace)
    samples = empirical.sample(1000)

Optimization Algorithms

Adam Optimizer

def adam(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
    """
    Adam optimizer for variational inference. Called with only
    hyperparameters, it returns a callable suitable for obj_optimizer.
    
    Parameters:
    - learning_rate (float): Step size
    - beta1 (float): Exponential decay rate for 1st moment estimates
    - beta2 (float): Exponential decay rate for 2nd moment estimates
    - epsilon (float): Small constant for numerical stability
    
    Returns:
    - optimizer: Partially applied Adam update function
    """

# Custom Adam configuration
optimizer = pm.adam(
    learning_rate=0.005,
    beta1=0.95,
    beta2=0.999
)

approx = pm.fit(n=50000, obj_optimizer=optimizer)

Other Optimizers

# Stochastic Gradient Descent
sgd_optimizer = pm.sgd(learning_rate=0.01)

# AdaGrad
adagrad_optimizer = pm.adagrad(learning_rate=0.1)

# AdaGrad with a moving window (the default family used by pm.fit)
adagrad_window_optimizer = pm.adagrad_window(learning_rate=0.1)

# RMSprop
rmsprop_optimizer = pm.rmsprop(learning_rate=0.001, rho=0.9)

# Adamax
adamax_optimizer = pm.adamax(learning_rate=0.002)

# AdaDelta
adadelta_optimizer = pm.adadelta(learning_rate=1.0, rho=0.95)

Advanced Variational Methods

Custom Inference Classes

from pymc.variational.inference import KLqp, Inference

class KLqp(Inference):
    """
    Kullback-Leibler divergence minimization.
    
    Parameters:
    - approx: Variational approximation
    - beta (float): Regularization parameter
    """
    
    def __init__(self, approx, beta=1.0):
        pass

# Custom inference setup
with pm.Model() as model:
    # Model definition...
    
    # Custom approximation
    custom_approx = pm.MeanField()
    
    # KL divergence inference
    inference = pm.KLqp(custom_approx, beta=0.9)
    approx = inference.fit(n=40000)

Implicit Gradient Methods

from pymc.variational.inference import ImplicitGradient

class ImplicitGradient(Inference):
    """
    Particle-based inference with implicit (kernelized Stein) gradients.
    
    Parameters:
    - approx: Variational approximation
    - estimator: Gradient estimator (default: KSD, the kernelized Stein discrepancy)
    - kernel: Kernel function used by the estimator
    """

# ImplicitGradient is mainly a building block: SVGD and ASVGD are the
# concrete particle-based methods built on it, and are the usual choice
# for posteriors with difficult geometry
with pm.Model() as difficult_model:
    # Model with complex geometry...
    
    approx = pm.fit(n=60000, method='svgd')

Variational Groups and Structured Approximations

Grouping Variables

from pymc.variational.opvi import Group, Approximation

class Group:
    """
    Group variables so that each group gets its own approximation family.
    
    Parameters:
    - group (list): Variables in the group (None captures all remaining variables)
    - vfam (str): Variational family for the group, e.g. 'mean_field' or 'full_rank'
    """

# Give hyperparameters a cheap family and the remaining variables a richer one
with pm.Model() as hierarchical_model:
    # Hierarchical model...
    
    # Group 1: hyperparameters (mean-field)
    hyper_group = Group([mu_alpha, sigma_alpha], vfam='mean_field')
    
    # Group 2: all remaining variables (full-rank)
    rest_group = Group(None, vfam='full_rank')
    
    # Combined approximation over both groups
    approximation = Approximation([hyper_group, rest_group])
    approx = pm.KLqp(approximation).fit(n=50000)

Callbacks and Monitoring

Built-in Callbacks

# Parameter convergence monitoring (checked every 100 iterations by default)
convergence_cb = pm.callbacks.CheckParametersConvergence(tolerance=0.01)

# Stricter check on relative parameter change, evaluated less frequently
strict_cb = pm.callbacks.CheckParametersConvergence(every=500, tolerance=0.001, diff='relative')

# Custom callback function
def custom_callback(approx, loss_history, i):
    if i % 1000 == 0:
        current_loss = loss_history[-1]
        print(f"Iteration {i}: Loss = {current_loss:.4f}")

# Use callbacks during fitting
approx = pm.fit(
    n=50000,
    callbacks=[convergence_cb, custom_callback],
    progressbar=True
)
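
To record quantities over the course of optimization, the Tracker callback (from pymc.variational.callbacks) stores the result of arbitrary callables at every iteration. A sketch tracking the approximation's mean and std, assuming a model in scope:

with model:
    inference = pm.ADVI()
    tracker = pm.callbacks.Tracker(
        mean=inference.approx.mean.eval,  # callable, evaluated every iteration
        std=inference.approx.std.eval,
    )
    approx = inference.fit(n=20000, callbacks=[tracker])

# tracker['mean'] and tracker['std'] hold the per-iteration histories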

Sampling from Approximations

Drawing Samples

def sample_approx(approx, draws=100, include_transformed=True):
    """
    Draw samples from a fitted variational approximation.
    
    Parameters:
    - approx: Fitted approximation
    - draws (int): Number of samples
    - include_transformed (bool): Also include transformed variables
    
    Returns:
    - trace: Samples from the approximation
    """

# The method form on the approximation object is the most common route
posterior = approx.sample(5000)

# Equivalent functional form
samples = pm.sample_approx(approx, draws=5000)
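
Since approx.sample returns an InferenceData object in recent PyMC versions, the usual ArviZ tooling applies directly:

import arviz as az

idata = approx.sample(2000)    # InferenceData
print(az.summary(idata))       # posterior means, sds, HDIs
az.plot_posterior(idata)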

Integration with MCMC

# Use a variational approximation to initialize MCMC
with pm.Model() as model:
    # Model definition...
    
    # Simplest route: let pm.sample run ADVI internally for initialization
    mcmc_trace = pm.sample(init='advi+adapt_diag', tune=1000, draws=2000)
    
    # Or fit explicitly first and inspect the approximation before sampling
    approx = pm.fit(n=30000)
    vi_idata = approx.sample(1000)

Model Comparison and Diagnostics

ELBO Monitoring

# Track the optimization objective during fitting
with pm.Model() as model:
    # Model definition...
    
    # Fit with loss tracking
    approx = pm.fit(n=50000, progressbar=True)
    
    # approx.hist records the tracked loss (the negative ELBO)
    loss_history = approx.hist
    
    # Plot convergence
    import matplotlib.pyplot as plt
    plt.plot(loss_history)
    plt.xlabel('Iteration')
    plt.ylabel('Loss (negative ELBO)')
    plt.title('Variational Inference Convergence')

Approximation Quality Assessment

# Compare the VI approximation with MCMC samples (treated as ground truth)
def assess_approximation_quality(approx, mcmc_idata, var_names):
    """Compare VI approximation with MCMC samples."""
    vi_idata = approx.sample(5000)
    
    for var in var_names:
        vi_draws = vi_idata.posterior[var].values
        mcmc_draws = mcmc_idata.posterior[var].values
        
        print(f"{var}:")
        print(f"  VI:   mean={vi_draws.mean():.3f}, std={vi_draws.std():.3f}")
        print(f"  MCMC: mean={mcmc_draws.mean():.3f}, std={mcmc_draws.std():.3f}")

# Usage
assess_approximation_quality(approx, mcmc_trace, ['alpha', 'beta', 'sigma'])

Large-Scale Variational Inference

Mini-batch Variational Inference

# Mini-batch VI for large datasets
with pm.Model() as large_scale_model:
    # Draw aligned mini-batches from both arrays jointly
    X_mb, y_mb = pm.Minibatch(X_large, y_large, batch_size=256)
    
    # Model with mini-batched data
    p = X_large.shape[1]
    alpha = pm.Normal('alpha', 0, 1)
    beta = pm.Normal('beta', 0, 1, shape=p)
    mu = alpha + pm.math.dot(X_mb, beta)
    
    # total_size rescales the likelihood so each mini-batch
    # stands in for the full dataset
    y_obs = pm.Normal('y_obs', mu=mu, sigma=1, observed=y_mb,
                      total_size=X_large.shape[0])
    
    # Variational inference with mini-batches
    approx = pm.fit(n=100000, method='advi')

Parallel Variational Inference

# Multiple independent VI runs (shown serially; each run is independent,
# so they can be farmed out to separate processes)
import multiprocessing as mp

with pm.Model() as model:
    # Model definition...
    
    # Independent restarts guard against bad local optima
    n_runs = mp.cpu_count()
    approximations = []
    
    for run in range(n_runs):
        approx_run = pm.fit(
            n=25000,
            random_seed=run,
            progressbar=False
        )
        approximations.append(approx_run)
    
    # Combine approximations (ensemble)
    ensemble_samples = []
    for approx in approximations:
        samples = approx.sample(1000)
        ensemble_samples.append(samples)

Usage Patterns and Best Practices

Hierarchical Models with VI

# Efficient VI for hierarchical models
with pm.Model() as hierarchical_vi:
    # Hyperparameters
    mu_mu = pm.Normal('mu_mu', 0, 10)
    sigma_mu = pm.HalfNormal('sigma_mu', 5)
    
    # Group parameters (non-centered parameterization)
    mu_raw = pm.Normal('mu_raw', 0, 1, shape=n_groups)
    mu = pm.Deterministic('mu', mu_mu + sigma_mu * mu_raw)
    
    # Likelihood
    y_obs = pm.Normal('y_obs', mu=mu[group_idx], sigma=1, observed=data)
    
    # VI works well with non-centered parameterization
    approx = pm.fit(n=50000, method='advi')

Model Selection with Variational Methods

# Compare models using variational inference
models_vi = {}
approximations = {}

for model_name, model in candidate_models.items():
    with model:
        approx = pm.fit(n=40000)
        approximations[model_name] = approx
        
        # Store the final loss (negative ELBO) for comparison
        models_vi[model_name] = {
            'loss': approx.hist[-1],
            'n_params': len(model.free_RVs),
            'approximation': approx
        }

# Select the model with the lowest final loss, i.e. the highest ELBO
# (a heuristic only: the ELBO is merely a lower bound on the model evidence)
best_model = min(models_vi.keys(), key=lambda k: models_vi[k]['loss'])

PyMC's variational inference framework provides efficient approximate inference methods suitable for large-scale Bayesian modeling, offering significant computational advantages over MCMC while maintaining reasonable approximation quality for many practical applications.
