Python Optimal Transport Library providing solvers for optimization problems related to Optimal Transport for signal, image processing and machine learning
npx @tessl/cli install tessl/pypi-pot@0.9.0

A comprehensive Python library providing solvers for optimization problems related to Optimal Transport for signal, image processing, and machine learning. POT offers numerous algorithms, including linear OT with a network simplex solver, entropic regularization with Sinkhorn algorithms, Wasserstein barycenters, Gromov-Wasserstein distances, unbalanced and partial optimal transport variants, sliced Wasserstein distances, and stochastic solvers for large-scale problems.
pip install POT

import ot

Import specific functions:
from ot import emd, emd2, sinkhorn, sinkhorn2, gromov_wasserstein

Import submodules:
import ot.lp
import ot.bregman
import ot.gromov
import ot.unbalanced

Basic usage:
import ot
import numpy as np
# Define source and target distributions
a = np.array([0.6, 0.4])  # Source distribution (non-negative, sums to 1)
b = np.array([0.4, 0.6])  # Target distribution (must have the same total mass as a)
# Define cost matrix
M = np.array([[0.5, 2.0],
              [1.0, 0.5]])
# Compute optimal transport plan using exact solver
plan = ot.emd(a, b, M)
print("Transport plan:", plan)
# Compute transport cost
cost = ot.emd2(a, b, M)
print("Transport cost:", cost)
# Compute using entropic regularization (Sinkhorn)
reg = 0.1
plan_sinkhorn = ot.sinkhorn(a, b, M, reg)
cost_sinkhorn = ot.sinkhorn2(a, b, M, reg)
print("Sinkhorn plan:", plan_sinkhorn)
print("Sinkhorn cost:", cost_sinkhorn)

POT is organized into specialized modules covering different aspects of optimal transport:
- ot.lp: exact optimal transport solvers using network simplex and linear programming
- ot.bregman: entropic regularization methods, including Sinkhorn algorithms and variants
- ot.gromov: structured optimal transport for comparing metric spaces
- ot.unbalanced: methods for unbalanced optimal transport problems
- ot.da: transport-based methods for domain adaptation in machine learning
- ot.backend: multi-framework support (NumPy, PyTorch, JAX, TensorFlow, CuPy)

The library provides both high-level functions directly in the main ot module and specialized implementations in submodules, so users can choose the level of granularity appropriate for their application.
Exact optimal transport via the Earth Mover's Distance (EMD), computed with a network simplex solver, plus specialized 1D solvers and free-support barycenters.
def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
    """
    Solve the Earth Mover's Distance problem and return the optimal transport plan.
    Parameters:
    - a: array-like, source distribution (histogram)
    - b: array-like, target distribution (histogram); must have the same total mass as a
    - M: array-like, cost matrix
    - numItermax: int, maximum number of iterations
    - log: bool, return optimization log
    - center_dual: bool, center dual potentials
    - numThreads: int, number of threads for parallel computation
    Returns:
    - transport plan matrix or (plan, log) if log=True
    """
def emd2(a, b, M, processes=1, numItermax=100000, log=False, return_matrix=False, center_dual=True, numThreads=1):
    """
    Solve EMD and return the transport cost only.
    Returns:
    - transport cost (scalar) or (cost, log) if log=True
    """
def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1.0, dense=True, log=False):
    """
    Solve the 1D optimal transport problem.
    Parameters:
    - x_a, x_b: array-like, sample positions
    - a, b: array-like, sample weights (uniform if None)
    - metric: str, cost metric ('sqeuclidean', 'euclidean', 'cityblock', 'minkowski')
    - p: float, exponent for the Minkowski metric
    - dense: bool, return a dense transport matrix
    - log: bool, return optimization log
    """
def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True):
    """
    Compute the 1D Wasserstein loss between two distributions (the p-Wasserstein distance raised to the power p).
    Parameters:
    - u_values, v_values: array-like, sample positions
    - u_weights, v_weights: array-like, sample weights (uniform if None)
    - p: int, Wasserstein distance order
    - require_sort: bool, whether inputs need sorting
    """

Sinkhorn algorithm and variants for solving regularized optimal transport problems, including stabilized versions, epsilon-scaling, and specialized algorithms like Greenkhorn and Screenkhorn.
def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs):
    """
    Solve entropic regularized optimal transport with the Sinkhorn algorithm.
    Parameters:
    - a, b: array-like, source and target distributions
    - M: array-like, cost matrix
    - reg: float, regularization parameter (must be > 0)
    - method: str, algorithm variant ('sinkhorn', 'sinkhorn_log', 'sinkhorn_stabilized',
      'sinkhorn_epsilon_scaling', 'greenkhorn'); Screenkhorn is a separate function, ot.bregman.screenkhorn
    - numItermax: int, maximum iterations
    - stopThr: float, convergence threshold
    - verbose: bool, print information
    - log: bool, return optimization log
    Returns:
    - transport plan matrix or (plan, log) if log=True
    """
def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, stopThr=1e-4, verbose=False, log=False, **kwargs):
    """
    Compute the entropic regularized Wasserstein barycenter of distributions.
    Parameters:
    - A: array-like, shape (dim, n_hists), input distributions stored as columns
    - M: array-like, shape (dim, dim), ground cost matrix
    - reg: float, regularization parameter
    - weights: array-like, shape (n_hists,), weight of each input distribution (uniform if None)
    - method: str, algorithm ('sinkhorn', 'sinkhorn_log', etc.)
    Returns:
    - barycenter distribution or (barycenter, log) if log=True
    """
Structured optimal transport for comparing metric spaces, including fused variants, barycenters, entropic regularization, and advanced methods like partial and semi-relaxed formulations.
def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', armijo=False, log=False, **kwargs):
    """
    Compute the Gromov-Wasserstein transport plan between two metric spaces (gromov_wasserstein2 returns the GW distance value instead).
    Parameters:
    - C1, C2: array-like, intra-space cost (distance) matrices of the source and target spaces
    - p, q: array-like, source and target distributions
    - loss_fun: str or function, loss function ('square_loss', 'kl_loss')
    - armijo: bool, find the line-search step with Armijo search (otherwise a closed-form step is used)
    - log: bool, return optimization log
    Returns:
    - transport plan matrix or (plan, log) if log=True
    """
def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
    """
    Compute the Fused Gromov-Wasserstein transport plan, combining feature and structure costs.
    Parameters:
    - M: array-like, feature cost matrix between source and target samples
    - C1, C2: array-like, structure cost matrices
    - alpha: float in [0, 1], trade-off weight: (1 - alpha) scales the feature (Wasserstein) term and alpha scales the structure (Gromov-Wasserstein) term
    - Remaining parameters as in gromov_wasserstein
    Returns:
    - transport plan matrix or (plan, log) if log=True
    """
def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun='square_loss', max_iter=1000, tol=1e-9, verbose=False, log=False, **kwargs):
    """
    Compute the Gromov-Wasserstein barycenter of metric spaces.
    Parameters:
    - N: int, size of the barycenter space
    - Cs: list, cost matrices of input spaces
    - ps: list, distributions of input spaces
    - p: array-like, barycenter distribution
    - lambdas: array-like, barycenter weights
    - loss_fun: str or function, loss function
    Returns:
    - barycenter cost matrix or (barycenter, log) if log=True
    """

Methods for optimal transport between measures with different total masses, supporting various divergences and regularization approaches.
def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs):
    """
    Solve unbalanced optimal transport with KL relaxation of the marginals.
    Parameters:
    - a, b: array-like, source and target distributions
    - M: array-like, cost matrix
    - reg: float, entropic regularization parameter
    - reg_m: float or tuple, marginal relaxation parameter(s)
    - method: str, algorithm variant
    - Additional parameters as in sinkhorn
    Returns:
    - transport plan matrix or (plan, log) if log=True
    """
def barycenter_unbalanced(A, M, reg, reg_m, weights=None, numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs):
    """
    Compute the unbalanced Wasserstein barycenter.
    Parameters:
    - A: array-like, input distributions
    - M: array-like, cost matrix
    - reg: float, entropic regularization
    - reg_m: float, marginal relaxation
    - weights: array-like, barycenter weights
    Returns:
    - barycenter distribution or (barycenter, log) if log=True
    """

Essential utilities for optimal transport, including distance computation, distribution generation, timing functions, and array operations.
def dist(x1, x2=None, metric='sqeuclidean'):
    """
    Compute the distance matrix between samples.
    Parameters:
    - x1, x2: array-like, input samples (x2 defaults to x1)
    - metric: str, distance metric
    Returns:
    - distance matrix
    """
def unif(n, type_as=None):
    """
    Generate a uniform distribution.
    Parameters:
    - n: int, distribution size
    - type_as: array-like, reference for array type
    Returns:
    - uniform distribution array (each entry equal to 1/n)
    """
def tic():
    """Start timer for performance measurement."""
def toc(message="Elapsed time : {} s"):
    """End timer and print elapsed time."""
def toq():
    """End timer and return elapsed time."""

Efficient approximation methods using random projections for high-dimensional optimal transport, including spherical variants and max-sliced approaches.
def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, projections=None, seed=None, log=False):
    """
    Compute the sliced Wasserstein distance between empirical distributions.
    Parameters:
    - X_s, X_t: array-like, source and target samples
    - a, b: array-like, sample weights
    - n_projections: int, number of random projections
    - p: int, Wasserstein distance order
    - projections: array-like, custom projection directions
    - seed: int, random seed
    - log: bool, return detailed results
    Returns:
    - sliced Wasserstein distance or (distance, log) if log=True
    """
def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, projections=None, seed=None, log=False):
    """
    Compute the max-sliced Wasserstein distance: the largest 1D Wasserstein distance over the projection directions, instead of their average.
    """

Transport-based methods for machine learning domain adaptation, including label-regularized transport and various transport classes for different adaptation scenarios.
def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerItermax=200, stopInnerThr=1e-9, verbose=False, log=False):
    """
    Solve entropic optimal transport with a class-based (group-lasso l_p-l_1) label regularization, using a majorization-minimization (MM) algorithm.
    Parameters:
    - a: array-like, source distribution
    - labels_a: array-like, source labels
    - b: array-like, target distribution
    - M: array-like, cost matrix
    - reg: float, entropic regularization
    - eta: float, label regularization parameter
    - numItermax: int, outer (MM) iterations
    - numInnerItermax: int, inner (Sinkhorn) iterations
    - stopInnerThr: float, stopping threshold for the inner iterations
    Returns:
    - transport plan matrix or (plan, log) if log=True
    """
class SinkhornTransport:
    """
    Sinkhorn transport class for domain adaptation.
    Parameters:
    - reg_e: float, entropic regularization
    - max_iter: int, maximum iterations
    - tol: float, convergence tolerance
    - verbose: bool, print information
    - log: bool, keep optimization log
    """
    def fit(self, Xs=None, ys=None, Xt=None, yt=None):
        """Fit the transport from source samples Xs to target samples Xt (labels ys, yt are optional)."""
    def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
        """Transform source samples to the target domain."""

Methods for optimal transport with relaxed mass constraints, allowing only part of the mass to be transported between distributions.
def partial_wasserstein(a, b, M, m=None, numItermax=1000000, log=False, **kwargs):
    """
    Solve the partial optimal transport problem.
    Parameters:
    - a, b: array-like, source and target distributions
    - M: array-like, cost matrix
    - m: float, amount of mass to transport (default: min(sum(a), sum(b)))
    - numItermax: int, maximum iterations
    - log: bool, return optimization log
    Returns:
    - transport plan matrix or (plan, log) if log=True
    """
def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, stopThr=1e-9, verbose=False, log=False):
    """
    Solve entropic regularized partial optimal transport.
    """

Multi-framework backend system enabling computation with NumPy, PyTorch, JAX, TensorFlow, and CuPy for flexible deployment and GPU acceleration.
def get_backend(*args):
    """
    Get the appropriate backend for the input arrays.
    Parameters:
    - args: arrays to determine backend from
    Returns:
    - backend instance
    """
def to_numpy(*args):
    """
    Convert arrays to NumPy format.
    Parameters:
    - args: arrays to convert
    Returns:
    - numpy arrays
    """
class Backend:
    """Base backend class defining the array operations interface."""
class NumpyBackend(Backend):
    """NumPy backend implementation."""
class TorchBackend(Backend):
    """PyTorch backend implementation."""

Specialized algorithms including smooth optimal transport, stochastic solvers for large-scale problems, low-rank methods, and Gaussian optimal transport.
def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=1e-3, rank=10, numItermax=100, stopThr=1e-5, log=False):
    """
    Solve optimal transport using the low-rank Sinkhorn algorithm.
    Parameters:
    - X_s, X_t: array-like, source and target samples
    - a, b: array-like, sample weights
    - reg: float, regularization parameter
    - rank: int, rank constraint
    - numItermax: int, maximum iterations
    - stopThr: float, convergence threshold
    - log: bool, return optimization log
    Returns:
    - transport plan or (plan, log) if log=True
    """

Smooth optimal transport with dual and semi-dual formulations supporting KL divergence, L2 regularization, and sparsity constraints for regularized transport solutions.
def smooth_ot_dual(a, b, C, regul, method='L-BFGS-B', numItermax=500, log=False, **kwargs):
    """
    Solve smooth optimal transport using the dual formulation.
    Parameters:
    - a, b: array-like, source and target distributions
    - C: array-like, cost matrix
    - regul: Regularization, regularization instance (NegEntropy, SquaredL2, SparsityConstrained)
    - method: str, optimization method
    - numItermax: int, maximum iterations
    - log: bool, return optimization log
    Returns:
    - transport plan matrix or (plan, log) if log=True
    """
def smooth_ot_semi_dual(a, b, C, regul, method='L-BFGS-B', numItermax=500, log=False, **kwargs):
    """
    Solve smooth optimal transport using the semi-dual formulation.
    """

Stochastic algorithms for large-scale optimal transport using SAG and SGD methods, enabling efficient computation for problems with many samples.
def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None, random_state=None):
    """
    Solve the semi-dual of entropic regularized OT with the Stochastic Average Gradient (SAG) algorithm.
    Parameters:
    - a, b: array-like, source and target distributions
    - M: array-like, cost matrix
    - reg: float, regularization parameter
    - numItermax: int, maximum iterations
    - lr: float, learning rate
    - random_state: int, random seed
    Returns:
    - dual variable v, shape (n_target,); solve_semi_dual_entropic(a, b, M, reg, method='SAG') returns the transport plan
    """
def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax=10000, lr=0.1, log=False):
    """
    Solve entropic regularized OT using SGD on the dual formulation.
    """

Algorithms for computing optimal transport regularization paths, exploring the full range from unregularized to highly regularized solutions.
def regularization_path(a, b, C, reg=1e-4, itmax=50000):
    """
    Compute the regularization path for optimal transport.
    Parameters:
    - a, b: array-like, source and target distributions
    - C: array-like, cost matrix
    - reg: float, final regularization parameter
    - itmax: int, maximum iterations
    Returns:
    - gamma_list: list of regularization parameters
    - Pi_list: list of corresponding transport plans
    """
def fully_relaxed_path(a, b, C, reg=1e-4, itmax=50000):
    """
    Compute the fully relaxed regularization path.
    """

High-level unified interface providing automatic algorithm selection and a consistent API across different problem types and scales.
def solve(M, a=None, b=None, reg=None, reg_type='KL', unbalanced=None, unbalanced_type='KL', n_threads=1, max_iter=None, plan_init=None, potentials_init=None, tol=None, verbose=False):
    """
    General optimal transport solver that selects the algorithm from the problem type (exact, regularized, unbalanced).
    Parameters:
    - M: array-like, cost matrix (first positional argument)
    - a, b: array-like, source and target distributions (uniform if None)
    - reg: float, regularization strength (None for exact OT)
    - reg_type: str, regularization type (e.g. 'KL', 'L2')
    - unbalanced: float, marginal relaxation strength (None for balanced OT)
    - unbalanced_type: str, marginal penalty type (e.g. 'KL', 'L2', 'TV')
    - max_iter: int, maximum iterations
    - tol: float, convergence threshold
    Returns:
    - OTResult object exposing attributes such as plan, value and potentials
    """
def solve_gromov(C1, C2, p=None, q=None, M=None, alpha=0.0, reg=None, method='auto', **kwargs):
    """
    General Gromov-Wasserstein solver with automatic method selection.
    """

Weak optimal transport, which charges each source sample for the distance to the barycenter of the target samples it is transported to, rather than for every individual displacement, preserving local structure for shape-matching applications.
def weak_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, G0=None, **kwargs):
    """
    Solve the weak optimal transport problem between empirical distributions.
    Parameters:
    - Xa, Xb: array-like, source and target samples
    - a, b: array-like, source and target distributions
    - verbose: bool, print optimization information
    - log: bool, return optimization log
    - G0: array-like, initial transport plan
    Returns:
    - transport plan matrix or (plan, log) if log=True
    """

Factored optimal transport, which routes mass through an intermediate distribution with a small number of support points, yielding a low-rank plan for efficient large-scale computation.
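def factored_optimal_transport(Xa, Xb, a=None, b=None, reg=0.0, r=100, X0=None, stopThr=1e-7, numItermax=100, verbose=False, log=False, **kwargs):
    """
    Solve factored optimal transport: couple Xa and Xb through an intermediate distribution with r support points.
    Parameters:
    - Xa, Xb: array-like, source and target samples
    - a, b: array-like, distributions
    - reg: float, entropic regularization of the subproblems (0 for exact OT)
    - r: int, number of support points of the intermediate distribution
    - X0: array-like, initial support of the intermediate distribution
    - verbose: bool, print information
    - log: bool, return optimization log
    Returns:
    - Ga: transport plan between the source and the intermediate distribution
    - Gb: transport plan between the intermediate and the target distribution
    - X: support of the intermediate distribution
    - plus a log dictionary if log=True
    """

# Type aliases used in this documentation (descriptive only); framework types are
# quoted so the snippet runs without PyTorch, JAX, TensorFlow, or CuPy installed
from typing import Dict, List, Tuple, Union

import numpy

# Common array types accepted by POT functions
ArrayLike = Union[numpy.ndarray, List, Tuple]
# Backend-specific array types
BackendArray = Union[numpy.ndarray, "torch.Tensor", "jax.numpy.ndarray", "tensorflow.Tensor", "cupy.ndarray"]
# Log dictionary returned by functions with log=True
LogDict = Dict[str, Union[float, int, List, numpy.ndarray]]
# Transport plan matrix type
TransportPlan = numpy.ndarray  # Shape: (n_samples_source, n_samples_target)
# Cost matrix type
CostMatrix = numpy.ndarray  # Shape: (n_samples_source, n_samples_target)
# Distribution vector type
Distribution = numpy.ndarray  # Shape: (n_samples,), non-negative, typically sums to 1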