# POT: Python Optimal Transport — a library providing solvers for optimization problems related to Optimal Transport for signal, image processing and machine learning
—
Quality
Pending
Does it follow best practices?
Impact
Pending
No eval scenarios have been run
This document covers advanced optimal transport methods provided by several specialized modules in POT, including smooth optimal transport, stochastic solvers, low-rank methods, Gaussian optimal transport, and other cutting-edge techniques for specific applications and computational scenarios.
## ot.smooth

The smooth module provides algorithms for smooth optimal transport using various regularization schemes beyond entropic regularization.
def ot.smooth.smooth_ot_dual(a, b, C, regul, method='L-BFGS-B', numItermax=500, log=False, **kwargs):
"""
Solve smooth optimal transport using dual formulation.
Uses smooth regularization functions to create differentiable optimal transport
problems that can be solved efficiently with gradient-based methods.
Parameters:
- a: array-like, shape (n_samples_source,)
Source distribution.
- b: array-like, shape (n_samples_target,)
Target distribution.
- C: array-like, shape (n_samples_source, n_samples_target)
Cost matrix.
- regul: Regularization object
Regularization class instance (NegEntropy, SquaredL2, etc.).
- method: str, default='L-BFGS-B'
Optimization method for dual problem.
- numItermax: int, default=500
Maximum optimization iterations.
- log: bool, default=False
Return optimization log.
- kwargs: dict
Additional arguments for optimizer.
Returns:
- optimal_plan: ndarray, shape (n_samples_source, n_samples_target)
Smooth optimal transport plan.
- log: dict (if log=True)
Contains 'dual_variables': (f, g), 'obj_value': objective value.
"""
def ot.smooth.smooth_ot_semi_dual(a, b, C, regul, method='L-BFGS-B', numItermax=500, log=False, **kwargs):
"""
Solve smooth optimal transport using semi-dual formulation.
Alternative formulation that can be more efficient for certain regularizers
by optimizing over a single dual variable.
Parameters:
- a, b, C, regul: Same as smooth_ot_dual()
- method: str, default='L-BFGS-B'
- numItermax: int, default=500
- log: bool, default=False
Returns:
- optimal_plan: ndarray
- log: dict (if log=True)
"""class ot.smooth.Regularization:
"""
Base class for smooth regularization functions.
Defines the interface for regularization functions used in smooth
optimal transport. All regularizers must implement the required methods.
"""
def delta_Omega(self, ot_pot):
"""Compute regularization function value."""
raise NotImplementedError()
def grad_delta_Omega(self, ot_pot):
"""Compute gradient of regularization."""
raise NotImplementedError()
def max_Omega(self, a, b):
"""Compute maximum value of conjugate."""
raise NotImplementedError()
class ot.smooth.NegEntropy(Regularization):
"""
Negative entropy regularization: -∑ᵢⱼ Pᵢⱼ log(Pᵢⱼ).
The classic entropic regularization used in Sinkhorn algorithms,
implemented in the smooth OT framework.
Parameters:
- reg: float
Regularization strength parameter.
"""
def __init__(self, reg):
self.reg = reg
class ot.smooth.SquaredL2(Regularization):
"""
Squared L2 regularization: ½‖P‖²_F.
Quadratic regularization that promotes smooth transport plans
without the sparsity-inducing effects of entropic regularization.
Parameters:
- reg: float
Regularization strength parameter.
"""
def __init__(self, reg):
self.reg = reg
class ot.smooth.SparsityConstrained(Regularization):
"""
Sparsity-constrained regularization for sparse optimal transport.
Promotes solutions with a limited number of non-zero entries,
useful for interpretable transport or computational efficiency.
Parameters:
- reg: float
Regularization parameter.
- max_nz: int
Maximum number of non-zero entries allowed.
"""
def __init__(self, reg, max_nz):
self.reg = reg
self.max_nz = max_nz

def ot.smooth.dual_obj_grad(alpha, beta, a, b, C, regul):
"""
Compute dual objective value and gradient.
Parameters:
- alpha: array-like, shape (n_source,)
Source dual variables.
- beta: array-like, shape (n_target,)
Target dual variables.
- a, b: array-like
Source and target distributions.
- C: array-like
Cost matrix.
- regul: Regularization
Regularization object.
Returns:
- objective: float
Dual objective value.
- gradient: tuple
Gradients with respect to (alpha, beta).
"""
def ot.smooth.solve_dual(a, b, C, regul, method='L-BFGS-B', tol=1e-3, max_iter=500, verbose=False, log=False):
"""
Solve dual smooth optimal transport problem.
Parameters:
- a, b, C, regul: Standard OT parameters
- method: str, default='L-BFGS-B'
- tol: float, default=1e-3
- max_iter: int, default=500
- verbose: bool, default=False
- log: bool, default=False
Returns:
- dual_variables: tuple
Optimal dual variables (alpha, beta).
- log: dict (if log=True)
"""
def ot.smooth.get_plan_from_dual(alpha, beta, C, regul):
"""
Recover primal transport plan from dual variables.
Parameters:
- alpha, beta: array-like
Dual variables.
- C: array-like
Cost matrix.
- regul: Regularization
Regularization object.
Returns:
- transport_plan: ndarray
Primal optimal transport plan.
"""ot.stochastic)Stochastic methods for large-scale optimal transport using sampling and stochastic gradient approaches.
def ot.stochastic.sag_entropic_transport(a, b, M, reg, batch_size=1, numItermax=300, lr=None, log=False):
"""
Stochastic Average Gradient (SAG) for entropic optimal transport.
Uses SAG optimization to solve large-scale entropic OT problems efficiently
by maintaining gradient averages and reducing variance.
Parameters:
- a: array-like, shape (n_source,)
Source distribution.
- b: array-like, shape (n_target,)
Target distribution.
- M: array-like, shape (n_source, n_target)
Cost matrix.
- reg: float
Entropic regularization parameter.
- batch_size: int, default=1
Mini-batch size for stochastic updates.
- numItermax: int, default=300
Maximum iterations.
- lr: float, optional
Learning rate. If None, uses adaptive schedule.
- log: bool, default=False
Return optimization log.
Returns:
- transport_plan: ndarray
Approximate optimal transport plan.
- log: dict (if log=True)
Contains 'err': convergence errors, 'lr': learning rates used.
"""
def ot.stochastic.averaged_sgd_entropic_transport(a, b, M, reg, batch_size=1, numItermax=300, lr=None, log=False):
"""
Averaged Stochastic Gradient Descent for entropic OT.
Uses averaged SGD with Polyak-Ruppert averaging for improved convergence
in stochastic optimization of entropic transport.
Parameters: Same as sag_entropic_transport()
Returns:
- transport_plan: ndarray
- log: dict (if log=True)
"""
def ot.stochastic.sgd_entropic_regularization(a, b, M, reg, batch_size=1, numItermax=300, lr=None, log=False):
"""
Standard SGD for entropic regularized optimal transport.
Basic stochastic gradient descent implementation for large-scale
entropic optimal transport problems.
Parameters: Same as sag_entropic_transport()
Returns:
- transport_plan: ndarray
- log: dict (if log=True)
"""def ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size=1, numItermax=300, lr=None, log=False):
"""
Solve dual entropic optimal transport using stochastic methods.
Optimizes the dual formulation of entropic OT using stochastic gradients
for computational efficiency on large problems.
Parameters:
- a, b, M, reg: Standard OT parameters
- batch_size: int, default=1
- numItermax: int, default=300
- lr: float, optional
- log: bool, default=False
Returns:
- dual_variables: tuple
Optimal dual potentials (f, g).
- log: dict (if log=True)
"""
def ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, batch_size=1, numItermax=300, lr=None, log=False):
"""
Solve semi-dual entropic OT stochastically.
Parameters: Same as solve_dual_entropic()
Returns:
- dual_variable: ndarray
Semi-dual potential.
- log: dict (if log=True)
"""ot.lowrank)Low-rank approximation methods for efficient large-scale optimal transport.
def ot.lowrank.lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=1e-3, rank=10, numItermax=100, stopThr=1e-5, log=False):
"""
Low-rank Sinkhorn algorithm for scalable optimal transport.
Approximates the optimal transport plan using low-rank matrix factorization,
significantly reducing memory and computational requirements for large problems.
Parameters:
- X_s: array-like, shape (n_source, d)
Source samples.
- X_t: array-like, shape (n_target, d)
Target samples.
- a: array-like, shape (n_source,), optional
Source weights. Default is uniform.
- b: array-like, shape (n_target,), optional
Target weights. Default is uniform.
- reg: float, default=1e-3
Entropic regularization parameter.
- rank: int, default=10
Rank constraint for low-rank approximation.
- numItermax: int, default=100
Maximum iterations.
- stopThr: float, default=1e-5
Convergence threshold.
- log: bool, default=False
Return optimization log.
Returns:
- transport_plan: ndarray, shape (n_source, n_target)
Low-rank approximated transport plan.
- log: dict (if log=True)
Contains 'err': errors, 'Q': left factor, 'R': right factor.
Example:
# Large-scale problem with low-rank approximation
X_s = np.random.randn(1000, 50)
X_t = np.random.randn(800, 50)
plan = ot.lowrank.lowrank_sinkhorn(X_s, X_t, rank=20, reg=0.1)
"""
def ot.lowrank.compute_lr_sqeuclidean_matrix(X_s, X_t, reg, rank):
"""
Compute low-rank approximation of squared Euclidean cost matrix.
Creates efficient low-rank representation of the cost matrix without
explicitly computing the full n×m matrix.
Parameters:
- X_s: array-like, shape (n_source, d)
Source coordinates.
- X_t: array-like, shape (n_target, d)
Target coordinates.
- reg: float
Regularization for numerical stability.
- rank: int
Target rank for approximation.
Returns:
- lr_matrix: tuple
Low-rank factors (Q, R) such that cost ≈ Q @ R.T.
"""ot.gaussian)Specialized methods for optimal transport between Gaussian distributions with closed-form solutions.
def ot.gaussian.bures_wasserstein_mapping(ms, mt, Cs, Ct, log=False):
"""
Compute Bures-Wasserstein mapping between Gaussian distributions.
For Gaussian distributions, the optimal transport map has a closed form
involving the means and covariance matrices.
Parameters:
- ms: array-like, shape (d,)
Mean of source Gaussian distribution.
- mt: array-like, shape (d,)
Mean of target Gaussian distribution.
- Cs: array-like, shape (d, d)
Covariance matrix of source distribution.
- Ct: array-like, shape (d, d)
Covariance matrix of target distribution.
- log: bool, default=False
Return additional information.
Returns:
- A: ndarray, shape (d, d)
Linear transformation matrix.
- b: ndarray, shape (d,)
Translation vector.
- log: dict (if log=True)
Contains intermediate computation details.
Note: The optimal map is T(x) = A @ x + b
"""
def ot.gaussian.bures_wasserstein_distance(ms, mt, Cs, Ct):
"""
Compute Bures-Wasserstein distance between Gaussian distributions.
Closed-form computation of 2-Wasserstein distance for Gaussians:
W₂²(μₛ,μₜ) = ‖mₛ-mₜ‖² + Tr(Cₛ + Cₜ - 2(Cₛ^{1/2}CₜCₛ^{1/2})^{1/2})
Parameters:
- ms, mt: array-like
Means of source and target Gaussians.
- Cs, Ct: array-like
Covariance matrices.
Returns:
- distance: float
Bures-Wasserstein distance.
"""
def ot.gaussian.bures_wasserstein_barycenter(Ms, Cs, weights=None, numItermax=100, stopThr=1e-6, verbose=False, log=False):
"""
Compute Bures-Wasserstein barycenter of Gaussian distributions.
Finds the Gaussian barycenter that minimizes the sum of Bures-Wasserstein
distances to input Gaussians.
Parameters:
- Ms: list of arrays
Means of input Gaussian distributions.
- Cs: list of arrays
Covariance matrices of input distributions.
- weights: array-like, optional
Barycenter combination weights. Default is uniform.
- numItermax: int, default=100
Maximum iterations for covariance optimization.
- stopThr: float, default=1e-6
Convergence threshold.
- verbose: bool, default=False
- log: bool, default=False
Returns:
- mean_barycenter: ndarray
Mean of barycenter Gaussian.
- cov_barycenter: ndarray
Covariance of barycenter Gaussian.
- log: dict (if log=True)
"""def ot.gaussian.empirical_bures_wasserstein_mapping(X_s, X_t, log=False):
"""
Compute empirical Bures-Wasserstein mapping from sample data.
Estimates Gaussian parameters from samples and computes the optimal
transport mapping between the fitted Gaussians.
Parameters:
- X_s: array-like, shape (n_source, d)
Source samples.
- X_t: array-like, shape (n_target, d)
Target samples.
- log: bool, default=False
Returns:
- A: ndarray, shape (d, d)
Linear mapping matrix.
- b: ndarray, shape (d,)
Translation vector.
- log: dict (if log=True)
"""
def ot.gaussian.empirical_bures_wasserstein_distance(X_s, X_t):
"""
Compute empirical Bures-Wasserstein distance from samples.
Parameters:
- X_s, X_t: array-like
Source and target samples.
Returns:
- distance: float
Empirical Bures-Wasserstein distance.
"""def ot.gaussian.gaussian_gromov_wasserstein_distance(Cs, Ct, Ys, Yt):
"""
Compute Gaussian Gromov-Wasserstein distance.
Extension of Gromov-Wasserstein to Gaussian distributions by comparing
both covariance structure and feature alignment.
Parameters:
- Cs: array-like, shape (d, d)
Source covariance matrix.
- Ct: array-like, shape (d, d)
Target covariance matrix.
- Ys: array-like, shape (d, p)
Source feature matrix.
- Yt: array-like, shape (d, q)
Target feature matrix.
Returns:
- distance: float
Gaussian GW distance.
"""
def ot.gaussian.gaussian_gromov_wasserstein_mapping(Cs, Ct, Ys, Yt):
"""
Compute Gaussian GW mapping.
Parameters: Same as gaussian_gromov_wasserstein_distance()
Returns:
- mapping: ndarray
Optimal linear mapping for Gaussian GW.
"""ot.weak)def ot.weak.weak_optimal_transport(Xs, Xt, a=None, b=None, reg=1e-3, k=10, numItermax=100, stopThr=1e-5, verbose=False, log=False):
"""
Compute weak optimal transport between point clouds.
Relaxes the optimal transport problem to allow for more flexible matching
by introducing slack variables and weak constraints.
Parameters:
- Xs: array-like, shape (n_source, d)
Source point cloud.
- Xt: array-like, shape (n_target, d)
Target point cloud.
- a, b: array-like, optional
Source and target weights.
- reg: float, default=1e-3
Regularization parameter.
- k: int, default=10
Number of weak transport components.
- numItermax: int, default=100
- stopThr: float, default=1e-5
- verbose: bool, default=False
- log: bool, default=False
Returns:
- weak_plan: ndarray
Weak optimal transport plan.
- log: dict (if log=True)
"""ot.factored)def ot.factored.factored_optimal_transport(X_s, X_t, a=None, b=None, r=10, reg=1e-3, numItermax=100, stopThr=1e-5, log=False):
"""
Compute factored optimal transport for dimension reduction.
Decomposes high-dimensional optimal transport into low-dimensional factors
for computational efficiency and interpretability.
Parameters:
- X_s: array-like, shape (n_source, d)
Source samples in high-dimensional space.
- X_t: array-like, shape (n_target, d)
Target samples.
- a, b: array-like, optional
Sample weights.
- r: int, default=10
Factorization rank (reduced dimension).
- reg: float, default=1e-3
- numItermax: int, default=100
- stopThr: float, default=1e-5
- log: bool, default=False
Returns:
- factors: tuple
Factorization (U, V) such that transport ≈ U @ V.T.
- log: dict (if log=True)
"""import ot
import numpy as np
# Generate sample data
np.random.seed(42)
n_source, n_target = 50, 60
a = ot.unif(n_source)
b = ot.unif(n_target)
C = np.random.rand(n_source, n_target)
# Different regularization types
regularizers = {
'NegEntropy': ot.smooth.NegEntropy(reg=0.1),
'SquaredL2': ot.smooth.SquaredL2(reg=0.05),
}
for name, regul in regularizers.items():
print(f"\n{name} Regularization:")
# Solve smooth OT
plan, log = ot.smooth.smooth_ot_dual(a, b, C, regul, log=True)
print(f"Objective value: {log['obj_value']:.6f}")
print(f"Transport cost: {np.sum(plan * C):.6f}")# Large-scale problem
n_large = 5000
a_large = ot.unif(n_large)
b_large = ot.unif(n_large)
M_large = np.random.rand(n_large, n_large)
# Stochastic methods for efficiency
batch_size = 32
methods = {
'SAG': ot.stochastic.sag_entropic_transport,
'Averaged SGD': ot.stochastic.averaged_sgd_entropic_transport,
'Standard SGD': ot.stochastic.sgd_entropic_regularization,
}
for name, method in methods.items():
print(f"\n{name} Method:")
plan, log = method(a_large, b_large, M_large, reg=0.1,
batch_size=batch_size, numItermax=100, log=True)
print(f"Final error: {log['err'][-1]:.2e}")
print(f"Converged in {len(log['err'])} iterations")# High-dimensional problem
d = 100
n_s, n_t = 1000, 800
X_s = np.random.randn(n_s, d)
X_t = np.random.randn(n_t, d) + 1
# Compare full vs low-rank
print("Full vs Low-Rank Sinkhorn:")
# Low-rank approximation
ranks = [5, 10, 20, 50]
for rank in ranks:
plan_lr, log_lr = ot.lowrank.lowrank_sinkhorn(
X_s, X_t, reg=0.1, rank=rank, log=True
)
print(f"Rank {rank}: Final error {log_lr['err'][-1]:.2e}")
# For comparison, compute a smaller full problem
X_s_small = X_s[:100]
X_t_small = X_t[:100]
plan_full = ot.sinkhorn(ot.unif(100), ot.unif(100),
ot.dist(X_s_small, X_t_small), reg=0.1)
print(f"Full Sinkhorn (100x100): completed")# Gaussian distributions
d = 5
ms = np.random.randn(d)
mt = np.random.randn(d)
# Random positive definite covariances
A_s = np.random.randn(d, d)
Cs = A_s @ A_s.T + 0.1 * np.eye(d)
A_t = np.random.randn(d, d)
Ct = A_t @ A_t.T + 0.1 * np.eye(d)
# Closed-form Gaussian OT
bw_dist = ot.gaussian.bures_wasserstein_distance(ms, mt, Cs, Ct)
A, b, log = ot.gaussian.bures_wasserstein_mapping(ms, mt, Cs, Ct, log=True)
print(f"Bures-Wasserstein distance: {bw_dist:.4f}")
print(f"Linear map matrix shape: {A.shape}")
print(f"Translation vector shape: {b.shape}")
# Empirical version with samples
n_samples = 500
X_s_gauss = np.random.multivariate_normal(ms, Cs, n_samples)
X_t_gauss = np.random.multivariate_normal(mt, Ct, n_samples)
emp_dist = ot.gaussian.empirical_bures_wasserstein_distance(X_s_gauss, X_t_gauss)
print(f"Empirical BW distance: {emp_dist:.4f}")# Multiple Gaussian distributions
n_gaussians = 4
means = [np.random.randn(d) for _ in range(n_gaussians)]
covs = []
for _ in range(n_gaussians):
A = np.random.randn(d, d)
cov = A @ A.T + 0.1 * np.eye(d)
covs.append(cov)
# Compute Gaussian barycenter
weights = ot.unif(n_gaussians)
mean_bary, cov_bary, log_bary = ot.gaussian.bures_wasserstein_barycenter(
means, covs, weights=weights, log=True
)
print(f"Barycenter mean: {mean_bary}")
print(f"Barycenter converged in {len(log_bary['err'])} iterations")import time
# Compare methods on medium-scale problem
n = 200
X_s = np.random.randn(n, 10)
X_t = np.random.randn(n, 10)
a = ot.unif(n)
b = ot.unif(n)
M = ot.dist(X_s, X_t)
methods_to_compare = [
("Standard Sinkhorn", lambda: ot.sinkhorn(a, b, M, reg=0.1)),
("Low-rank (rank=10)", lambda: ot.lowrank.lowrank_sinkhorn(X_s, X_t, reg=0.1, rank=10)),
("Gaussian BW", lambda: ot.gaussian.empirical_bures_wasserstein_mapping(X_s, X_t)),
]
print("Performance Comparison:")
for name, method in methods_to_compare:
start_time = time.time()
result = method()
elapsed = time.time() - start_time
print(f"{name}: {elapsed:.4f} seconds")# Study effect of different regularizations
test_regularizations = [
('Entropy λ=0.01', ot.smooth.NegEntropy(0.01)),
('Entropy λ=0.1', ot.smooth.NegEntropy(0.1)),
('L2 λ=0.01', ot.smooth.SquaredL2(0.01)),
('L2 λ=0.1', ot.smooth.SquaredL2(0.1)),
]
costs = []
sparsities = []
for name, regul in test_regularizations:
plan = ot.smooth.smooth_ot_dual(a, b, C, regul)
cost = np.sum(plan * C)
sparsity = np.mean(plan < 1e-6) # Fraction of near-zero entries
costs.append(cost)
sparsities.append(sparsity)
print(f"{name}: Cost={cost:.4f}, Sparsity={sparsity:.2f}")
print(f"\nCost range: [{min(costs):.4f}, {max(costs):.4f}]")
print(f"Sparsity range: [{min(sparsities):.2f}, {max(sparsities):.2f}]")These advanced methods extend POT's capabilities to specialized scenarios requiring specific computational properties, theoretical guarantees, or application-specific optimizations.
Install with Tessl CLI
npx tessl i tessl/pypi-potdocs