
tessl/pypi-pot

Python Optimal Transport Library providing solvers for optimization problems related to Optimal Transport for signal, image processing and machine learning


Unified Solvers

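High-level unified interface for optimal transport solvers, providing a consistent API across problem types and algorithms. Unless a method is requested explicitly, each function selects the underlying algorithm from the problem specification (whether a regularization weight is given, the regularization type, and whether the marginals are relaxed). Each returns an OTResult object exposing the transport plan, the objective value and, where available, the dual potentials.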

Capabilities

General Optimal Transport Solver

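Unified solver for discrete optimal transport problems (exact or regularized, balanced or unbalanced) with automatic algorithm selection.

def solve(M, a=None, b=None, reg=None, c=None, reg_type='KL', unbalanced=None, unbalanced_type='KL', method=None, n_threads=1, max_iter=None, plan_init=None, potentials_init=None, tol=None, verbose=False, grad='autodiff'):
    """
    Solve the discrete optimal transport problem and return an OTResult object.

    The algorithm is selected from the problem specification: exact OT when reg
    is None, a Sinkhorn-type solver for entropic/KL regularization, a smooth OT
    solver for L2 regularization, and unbalanced solvers when unbalanced is set.

    Parameters:
    - M: array-like, shape (dim_a, dim_b), cost matrix
    - a: array-like, shape (dim_a,), source weights (uniform if None)
    - b: array-like, shape (dim_b,), target weights (uniform if None)
    - reg: float, regularization weight (None for exact OT)
    - c: array-like, shape (dim_a, dim_b), reference measure of the regularization (a b^T if None)
    - reg_type: str, regularization type ('KL', 'entropy', 'L2'); ignored if reg is None
    - unbalanced: float or pair of floats, marginal relaxation weight (None for balanced OT)
    - unbalanced_type: str, marginal penalty ('KL', 'L2', 'TV')
    - method: str, solver to use when several are available (None for automatic selection)
    - n_threads: int, number of OpenMP threads for the exact solver
    - max_iter: int, maximum number of iterations (None uses each solver's default)
    - plan_init, potentials_init: optional initial plan and dual potentials
    - tol: float, convergence tolerance (None uses each solver's default)
    - verbose: bool, print solver information
    - grad: str, how gradients are computed through the Sinkhorn solver ('autodiff' by default)

    Returns:
    - res: OTResult with res.plan (transport plan), res.value (optimal objective),
      res.value_linear (linear cost <T, M>) and res.potentials (dual potentials)
    """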

Gromov-Wasserstein Solver

Unified solver for Gromov-Wasserstein (GW) and Fused Gromov-Wasserstein (FGW) problems.

def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, alpha=0.5, reg=None, reg_type='entropy', unbalanced=None, unbalanced_type='KL', n_threads=1, method=None, max_iter=None, plan_init=None, tol=None, verbose=False):
    """
    Solve the discrete (Fused) Gromov-Wasserstein problem and return an OTResult object.

    Pure GW is solved when M is None and Fused GW when a feature cost matrix M is
    given. Without reg a conditional-gradient (Frank-Wolfe) solver is used; with
    reg, the entropic problem is solved.

    Parameters:
    - Ca: array-like, shape (dim_a, dim_a), structure matrix of the source space
    - Cb: array-like, shape (dim_b, dim_b), structure matrix of the target space
    - M: array-like, shape (dim_a, dim_b), feature cost matrix for Fused GW (None for pure GW)
    - a: array-like, source weights (uniform if None)
    - b: array-like, target weights (uniform if None)
    - loss: str, loss comparing structure entries ('L2' or 'KL')
    - symmetric: bool, whether Ca and Cb are symmetric (None to check automatically)
    - alpha: float in [0, 1], weight of the structure term in Fused GW, whose objective is
      (1 - alpha) <T, M> + alpha * GW (1=pure GW, 0=pure Wasserstein); unused when M is None
    - reg: float, regularization weight (None for the unregularized problem)
    - reg_type: str, regularization type ('entropy')
    - unbalanced, unbalanced_type: marginal relaxation (None for the balanced problem)
    - n_threads, method, max_iter, plan_init, tol, verbose: as in solve

    Returns:
    - res: OTResult with res.plan, res.value, res.value_linear (feature term)
      and res.value_quad (structure term)
    """

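Sample-based Solver

Solver for optimal transport between empirical distributions given directly as samples, including solvers for large-scale problems.

def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, c=None, reg_type='KL', unbalanced=None, unbalanced_type='KL', lazy=False, batch_size=None, method=None, n_threads=1, max_iter=None, plan_init=None, rank=100, scaling=0.95, potentials_init=None, X_init=None, tol=None, verbose=False, grad='autodiff'):
    """
    Solve the discrete optimal transport problem between two sets of samples
    and return an OTResult object.

    By default the cost matrix M = ot.dist(X_a, X_b, metric=metric) is computed
    and the problem is solved as in solve. The lazy option and the method
    argument select solvers that avoid storing the full cost matrix or plan.

    Parameters:
    - X_a: array-like, shape (n_samples_a, dim), source samples
    - X_b: array-like, shape (n_samples_b, dim), target samples
    - a: array-like, source weights (uniform if None)
    - b: array-like, target weights (uniform if None)
    - metric: str, ground metric used to build the cost matrix ('sqeuclidean' by default)
    - reg, c, reg_type, unbalanced, unbalanced_type: as in solve
    - lazy: bool, solve the entropic problem with a lazy solver that never stores the full cost matrix
    - batch_size: int, batch size for lazy solvers
    - method: str, large-scale solver ('1d', 'gaussian', 'factored', 'lowrank', or a geomloss-based lazy solver); None for the default
    - rank: int, rank of the plan for 'factored' and 'lowrank'
    - scaling: float, epsilon-scaling factor for geomloss-based solvers
    - X_init: array-like, initial support for 'factored'
    - n_threads, max_iter, plan_init, potentials_init, tol, verbose, grad: as in solve

    Returns:
    - res: OTResult; for lazy and low-rank solvers the plan is available as res.lazy_plan
    """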

Solver Configuration

Automatic Method Selection

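When method is None, each solver picks its algorithm from the problem specification (reg, reg_type, unbalanced and, for solve_gromov, whether M is given), not from the problem size:

Standard OT (solve):

  • reg=None (or 0): exact OT with the network simplex solver, as in ot.emd2
  • reg with reg_type 'KL' or 'entropy': Sinkhorn iterations in the log domain
  • reg with reg_type 'L2': smooth OT solved through its dual
  • unbalanced set: the marginal constraints become penalties and a dedicated unbalanced solver is used (for example unbalanced Sinkhorn when reg_type and unbalanced_type are both 'KL')

Gromov-Wasserstein (solve_gromov):

  • reg=None: conditional gradient (Frank-Wolfe) solver, as in ot.gromov.gromov_wasserstein
  • reg set: entropic Gromov-Wasserstein
  • M given: Fused Gromov-Wasserstein, weighted by alpha

Sample-based (solve_sample):

  • Default: builds the cost matrix with ot.dist(X_a, X_b, metric) and dispatches as solve does
  • lazy=True: lazy entropic solver that never stores the full cost matrix
  • method='1d', 'gaussian', 'factored' or 'lowrank': exact 1D, Gaussian (Bures-Wasserstein), factored and low-rank solvers
  • GPU acceleration: inherited from the backend of the input arrays
  • Sliced Wasserstein distances are computed by ot.sliced_wasserstein_distance and ot.max_sliced_wasserstein_distance, not by solve_sample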
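In POT, the choice made by solve depends on reg, reg_type and unbalanced rather than on the size of the problem. The hypothetical function below (choose_solver is not part of POT) condenses that decision into a few branches and returns a descriptive label instead of calling a solver. It simplifies the unbalanced cases, which POT splits further by unbalanced_type.

```python
def choose_solver(reg=None, reg_type="KL", unbalanced=None):
    """Hypothetical, simplified summary of how ot.solve picks an algorithm.

    Returns a short label rather than calling POT.
    """
    if reg is None or reg == 0:
        # No regularization: exact OT, or exact unbalanced OT if marginals are relaxed
        return "exact" if unbalanced is None else "exact unbalanced"
    if unbalanced is not None:
        return "regularized unbalanced"
    if isinstance(reg_type, tuple):
        # User-supplied (function, gradient) regularizer: generic conditional gradient
        return "generic conditional gradient"
    if reg_type.lower() in ("kl", "entropy"):
        return "sinkhorn (log domain)"
    if reg_type.lower() == "l2":
        return "smooth OT (dual)"
    raise ValueError(f"unknown reg_type: {reg_type!r}")
```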

Common Parameters

The three solvers share most of their configuration parameters:

# Regularization (reg_type is ignored when reg is None)
reg = 0.1                # Regularization weight; None gives the unregularized problem
reg_type = 'KL'          # KL divergence to a b^T (Sinkhorn-type, default in solve)
reg_type = 'entropy'     # Negative entropy (Sinkhorn-type, default in solve_gromov)
reg_type = 'L2'          # Squared L2 norm (smooth OT)

# Marginal relaxation (unbalanced OT)
unbalanced = None        # None keeps the marginal constraints exact
unbalanced_type = 'KL'   # Marginal penalty: 'KL', 'L2' or 'TV'

# Method selection
method = None            # None: automatic selection from the parameters above

# Convergence control
tol = None               # Tolerance; None uses each solver's default
max_iter = None          # Maximum iterations; None uses each solver's default
verbose = True           # Print solver progress
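For reference, the regularization and relaxation parameters map onto terms of the standard regularized OT problem. The formulation below is a sketch in common notation (reg as the weight, lambda for the value of unbalanced, D for the divergence named by unbalanced_type); POT's exact conventions for each term may differ slightly.

```latex
% Balanced, regularized OT (reg_type selects R)
\min_{T \ge 0}\; \langle T, M \rangle_F + \mathrm{reg}\, R(T)
\quad \text{s.t.} \quad T\mathbf{1} = a,\; T^\top\mathbf{1} = b

% reg_type = 'KL', 'entropy', 'L2'
R_{\mathrm{KL}}(T) = \sum_{i,j} \Big( T_{ij}\log\frac{T_{ij}}{a_i b_j} - T_{ij} + a_i b_j \Big),\qquad
R_{\mathrm{entropy}}(T) = \sum_{i,j} T_{ij}\log T_{ij},\qquad
R_{\mathrm{L2}}(T) = \tfrac{1}{2}\sum_{i,j} T_{ij}^2

% Unbalanced OT (unbalanced = \lambda): the marginal constraints become penalties
\min_{T \ge 0}\; \langle T, M \rangle_F + \mathrm{reg}\, R(T)
+ \lambda\, D(T\mathbf{1}, a) + \lambda\, D(T^\top\mathbf{1}, b)
```

On the feasible set of the balanced problem, the constraints T 1 = a and T^T 1 = b fix both sum T_ij and sum T_ij log(a_i b_j), so R_KL and R_entropy differ only by a constant there. They therefore yield the same optimal plan and differ only in the reported objective value.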

Usage Examples

Basic Optimal Transport

import ot
import numpy as np

# Create distributions
n, m = 100, 120
a = ot.utils.unif(n)
b = ot.utils.unif(m)
X = np.random.randn(n, 2)
Y = np.random.randn(m, 2)
M = ot.dist(X, Y)  # squared Euclidean cost matrix

# Entropic OT: reg is set and reg_type defaults to 'KL', so a Sinkhorn-type solver is used
res = ot.solve(M, a, b, reg=0.1, verbose=True)
plan = res.plan
print(f"Regularized OT value: {res.value}")

# Exact OT (reg=None): network simplex, as in ot.emd
plan_exact = ot.solve(M, a, b).plan

# L2-regularized (smooth) OT
plan_l2 = ot.solve(M, a, b, reg=0.01, reg_type='L2').plan
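A useful sanity check for the exact solver on tiny inputs: between two uniform distributions with the same number of points, some optimal plan is a permutation matrix divided by n (Birkhoff-von Neumann theorem), so enumerating permutations gives the exact optimum. exact_ot_uniform_bruteforce is a hypothetical helper, practical only for very small n.

```python
import itertools

import numpy as np


def exact_ot_uniform_bruteforce(M):
    """Exact OT between two uniform distributions with the same number of points.

    Hypothetical helper for tiny inputs only: it enumerates all n! permutations.
    """
    n = M.shape[0]
    if M.shape != (n, n):
        raise ValueError("expected a square cost matrix")
    rows = np.arange(n)
    best_cols, best_cost = None, np.inf
    for perm in itertools.permutations(range(n)):
        cols = np.array(perm)
        cost = M[rows, cols].sum() / n  # <T, M> for T = permutation matrix / n
        if cost < best_cost:
            best_cols, best_cost = cols, cost
    plan = np.zeros((n, n))
    plan[rows, best_cols] = 1.0 / n
    return plan, best_cost
```

With a = b = ot.utils.unif(n), ot.solve(M, a, b).value should equal the returned cost on such inputs, which makes this a convenient cross-check.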

Gromov-Wasserstein Problems

# Create structured data: intra-domain distance matrices, scaled to [0, 1]
n_s, n_t = 50, 60
C1 = ot.dist(np.random.randn(n_s, 2))  # Source structure
C2 = ot.dist(np.random.randn(n_t, 2))  # Target structure
C1 /= C1.max()
C2 /= C2.max()

# Pure Gromov-Wasserstein: unregularized (conditional gradient) and entropic
plan_gw = ot.solve_gromov(C1, C2).plan
plan_gw_reg = ot.solve_gromov(C1, C2, reg=0.1).plan

# Fused Gromov-Wasserstein with node features
X_s = np.random.randn(n_s, 3)
X_t = np.random.randn(n_t, 3)
M_features = ot.dist(X_s, X_t)
M_features /= M_features.max()

res_fgw = ot.solve_gromov(
    C1, C2, M=M_features, alpha=0.5,
    reg=0.1, verbose=True
)
plan_fgw = res_fgw.plan

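Large-Scale Problems

# Large-scale problem: 10,000 samples in 100 dimensions
n_large = 10000
X_s_large = np.random.randn(n_large, 100)
X_t_large = np.random.randn(n_large, 100)

# Sliced Wasserstein distance (a dedicated function, not a solve_sample method)
distance = ot.sliced_wasserstein_distance(
    X_s_large, X_t_large,
    n_projections=50,
    seed=0
)
print(f"Sliced Wasserstein distance: {distance}")

# OT between Gaussian fits of the samples (Bures-Wasserstein),
# without building the 10000 x 10000 cost matrix
res_gauss = ot.solve_sample(X_s_large, X_t_large, method='gaussian')
print(f"Gaussian OT value: {res_gauss.value}")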

Backend Integration

# The backend is inferred from the input arrays; there is no global backend switch
import torch

# PyTorch tensors (backend detected automatically)
a_torch = torch.ones(100) / 100
b_torch = torch.ones(120) / 120
M_torch = torch.rand(100, 120)  # nonnegative cost matrix

# Solved with the PyTorch backend; tensors on a CUDA device keep the computation on the GPU
res_torch = ot.solve(M_torch, a_torch, b_torch, reg=0.1)
plan_torch = res_torch.plan  # torch.Tensor

# JAX: pass JAX arrays to use the JAX backend
import jax.numpy as jnp

plan_jax = ot.solve(jnp.asarray(M), jnp.asarray(a), jnp.asarray(b), reg=0.1).plan

Import Statements

import ot
from ot import solve, solve_gromov, solve_sample
from ot.solvers import solve, solve_gromov, solve_sample

Install with Tessl CLI

npx tessl i tessl/pypi-pot
