Python Optimal Transport Library providing solvers for optimization problems related to Optimal Transport for signal, image processing and machine learning
High-level unified interface for optimal transport solvers, providing a consistent API across different problem types and algorithms. These solvers automatically select appropriate methods based on problem characteristics and user preferences.
Unified solver for standard optimal transport problems with automatic algorithm selection.
def solve(a, b, M, reg=None, reg_type='entropy', method='auto',
          numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs):
    """
    General optimal transport solver with automatic method selection.

    This function provides a unified interface to various OT solvers, automatically
    selecting the most appropriate algorithm based on problem size, regularization,
    and other parameters.

    Parameters:
    - a: array-like, source distribution
    - b: array-like, target distribution
    - M: array-like, cost matrix
    - reg: float, regularization parameter (None for exact transport)
    - reg_type: str, regularization type ('entropy', 'l2', 'kl', 'tv')
    - method: str, solver method ('auto', 'emd', 'sinkhorn', 'sinkhorn_log',
      'sinkhorn_stabilized', 'sinkhorn_epsilon_scaling', 'smooth')
    - numItermax: int, maximum number of iterations
    - stopThr: float, convergence threshold
    - verbose: bool, print solver information
    - log: bool, return optimization log

    Returns:
    - transport plan matrix or (plan, log) if log=True
    """

Unified solver for Gromov-Wasserstein problems and variants.
def solve_gromov(C1, C2, p=None, q=None, M=None, alpha=0.0, reg=None,
                 reg_type='entropy', method='auto', loss_fun='square_loss',
                 armijo=False, numItermax=1000, stopThr=1e-6, verbose=False,
                 log=False, **kwargs):
    """
    General Gromov-Wasserstein solver with automatic method selection.

    Solves Gromov-Wasserstein and Fused Gromov-Wasserstein problems using
    appropriate algorithms based on regularization and problem characteristics.

    Parameters:
    - C1: array-like, cost matrix for source space
    - C2: array-like, cost matrix for target space
    - p: array-like, source distribution (uniform if None)
    - q: array-like, target distribution (uniform if None)
    - M: array-like, feature cost matrix (for Fused GW, None for pure GW)
    - alpha: float, trade-off parameter between structure and features (0=pure GW, 1=pure Wasserstein)
    - reg: float, regularization parameter (None for exact)
    - reg_type: str, regularization type ('entropy', 'l2')
    - method: str, solver method ('auto', 'conditional_gradient', 'proximal_point', 'frank_wolfe')
    - loss_fun: str or callable, loss function ('square_loss', 'kl_loss')
    - armijo: bool, use Armijo line search
    - numItermax: int, maximum iterations
    - stopThr: float, convergence threshold
    - verbose: bool, print information
    - log: bool, return optimization log

    Returns:
    - transport plan matrix or (plan, log) if log=True
    """

Solver for large-scale problems using sampling approaches.
def solve_sample(X_s, X_t, a=None, b=None, method='gromov_wasserstein_samples',
                 reg=None, numItermax=1000, stopThr=1e-6, verbose=False,
                 log=False, **kwargs):
    """
    Solve optimal transport using sampling-based methods.

    Efficient solver for large-scale problems using sampling techniques
    to approximate optimal transport distances and plans.

    Parameters:
    - X_s: array-like, source samples (n_samples_s, n_features)
    - X_t: array-like, target samples (n_samples_t, n_features)
    - a: array-like, source weights (uniform if None)
    - b: array-like, target weights (uniform if None)
    - method: str, sampling method ('gromov_wasserstein_samples', 'sliced_wasserstein',
      'max_sliced_wasserstein')
    - reg: float, regularization parameter
    - numItermax: int, maximum iterations
    - stopThr: float, convergence threshold
    - verbose: bool, print information
    - log: bool, return optimization log

    Returns:
    - transport plan or distance depending on method, or (result, log) if log=True
    """

The unified solvers use intelligent method selection based on problem characteristics:
Standard OT (solve):
Gromov-Wasserstein (solve_gromov):
Sampling-based (solve_sample):
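As a rough illustration of what such a selection rule could look like for the standard solver, the sketch below maps the `reg` and `reg_type` settings to a solver family. The helper name `choose_method` is hypothetical, and the mapping is inferred only from the parameter descriptions on this page (no `reg` means exact transport via EMD, `'entropy'` is Sinkhorn-type, `'l2'` is smooth OT). It is not taken from the library's source and may not match its actual behaviour.

```python
def choose_method(reg=None, reg_type="entropy"):
    """Hypothetical helper: pick a solver family from the regularization settings.

    The mapping mirrors the parameter descriptions on this page only.
    """
    if reg is None:
        # No regularization requested: exact transport (EMD).
        return "emd"
    if reg <= 0:
        raise ValueError("reg must be positive when given")
    if reg_type == "entropy":
        # Entropic regularization: Sinkhorn-type solver.
        return "sinkhorn"
    if reg_type == "l2":
        # L2 regularization: smooth OT solver.
        return "smooth"
    # This sketch assumes no default for the other regularization types.
    raise ValueError(f"no default solver assumed for reg_type={reg_type!r}")


print(choose_method())                          # emd
print(choose_method(reg=0.1))                   # sinkhorn
print(choose_method(reg=0.01, reg_type="l2"))   # smooth
```

Keeping the rule in one small function makes the choice easy to inspect and to override with an explicit `method` argument.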
All unified solvers support common configuration parameters:
# Regularization types
reg_type = 'entropy' # Entropic regularization (Sinkhorn-type)
reg_type = 'l2' # L2 regularization (smooth OT)
reg_type = 'kl' # KL divergence regularization
reg_type = 'tv' # Total variation regularization
# Method selection
method = 'auto' # Automatic method selection
method = 'exact' # Force exact methods when possible
method = 'regularized' # Force regularized methods
method = 'fast' # Prioritize speed over accuracy
# Convergence control
stopThr = 1e-6 # Convergence threshold
numItermax = 1000 # Maximum iterations
verbose = True # Print solver progress
log = True # Return detailed optimization log

import ot
import numpy as np
# Create distributions
n, m = 100, 120
a = ot.utils.unif(n)
b = ot.utils.unif(m)
X = np.random.randn(n, 2)
Y = np.random.randn(m, 2)
M = ot.dist(X, Y)
# Solve with automatic method selection
plan = ot.solve(a, b, M, reg=0.1, method='auto', verbose=True)
# Solve exact transport (automatically uses EMD)
plan_exact = ot.solve(a, b, M, method='exact')
# Solve with specific regularization
plan_l2 = ot.solve(a, b, M, reg=0.01, reg_type='l2', method='smooth')

# Create structured data
n_s, n_t = 50, 60
C1 = ot.dist(np.random.randn(n_s, 2)) # Source structure
C2 = ot.dist(np.random.randn(n_t, 2)) # Target structure
# Pure Gromov-Wasserstein
plan_gw = ot.solve_gromov(C1, C2, reg=0.1, method='auto')
# Fused Gromov-Wasserstein with features
X_s = np.random.randn(n_s, 3)
X_t = np.random.randn(n_t, 3)
M_features = ot.dist(X_s, X_t)
plan_fgw = ot.solve_gromov(
    C1, C2, M=M_features, alpha=0.5,
    reg=0.1, method='auto', verbose=True
)

# Large-scale problem with sampling
n_large = 10000
X_s_large = np.random.randn(n_large, 100)
X_t_large = np.random.randn(n_large, 100)
# Use sampling-based solver
result = ot.solve_sample(
    X_s_large, X_t_large,
    method='sliced_wasserstein',
    numItermax=50,
    verbose=True,
    log=True
)
distance, log_dict = result
print(f"Sliced Wasserstein distance: {distance}")

# Automatic backend detection and GPU acceleration
import torch
# PyTorch tensors (automatically detected)
a_torch = torch.ones(100) / 100
b_torch = torch.ones(120) / 120
M_torch = torch.randn(100, 120)
# Solver automatically uses PyTorch backend
plan_torch = ot.solve(a_torch, b_torch, M_torch, reg=0.1, method='auto')
# Use the JAX backend by passing JAX arrays; the backend is inferred from the input type
import jax.numpy as jnp
plan_jax = ot.solve(jnp.asarray(a), jnp.asarray(b), jnp.asarray(M), reg=0.1, method='sinkhorn')

import ot
from ot import solve, solve_gromov, solve_sample
from ot.solvers import solve, solve_gromov, solve_sample
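For readers who want to see what the entropic option (`reg_type='entropy'`) computes, here is a minimal, dependency-free sketch of Sinkhorn scaling on plain Python lists. The function name `sinkhorn_plan` is hypothetical. The parameter names `reg`, `numItermax` and `stopThr` are borrowed from this page for readability. The code is an illustration of the general technique under those assumptions, not the library's implementation, and it omits the log-domain stabilization a production solver would need for small `reg`.

```python
import math


def sinkhorn_plan(a, b, M, reg, numItermax=1000, stopThr=1e-6):
    """Entropic OT by Sinkhorn scaling on plain Python lists (illustrative)."""
    n, m = len(a), len(b)
    # Gibbs kernel K[i][j] = exp(-M[i][j] / reg)
    K = [[math.exp(-M[i][j] / reg) for j in range(m)] for i in range(n)]
    u = [1.0] * n
    v = [1.0] * m
    for _ in range(numItermax):
        # Rescale rows so the plan's row sums match a
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        # Rescale columns so the plan's column sums match b
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
        # Column sums are now exact; stop once the row marginals are close too
        err = sum(
            abs(u[i] * sum(K[i][j] * v[j] for j in range(m)) - a[i])
            for i in range(n)
        )
        if err < stopThr:
            break
    # Transport plan: diag(u) @ K @ diag(v)
    return [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]


a = [0.5, 0.5]                 # source weights
b = [0.25, 0.75]               # target weights
M = [[0.0, 1.0], [1.0, 0.0]]   # cost matrix
plan = sinkhorn_plan(a, b, M, reg=0.5)
row_sums = [sum(row) for row in plan]
col_sums = [sum(plan[i][j] for i in range(len(a))) for j in range(len(b))]
print(row_sums, col_sums)
```

The printed row and column sums should be close to `a` and `b` respectively, which is a quick way to check that a returned plan is a valid coupling.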