Python Optimal Transport (POT) library, providing solvers for optimization problems related to optimal transport in signal processing, image processing, and machine learning.
Smooth optimal transport solvers for regularized OT problems, based on dual and semi-dual formulations. Supported regularizers are negative entropy (KL divergence), squared L2, and sparsity-constrained squared L2. The squared L2 and sparsity-constrained regularizers can produce sparse transport plans, whereas negative entropy produces dense, smooth plans.

Solve optimal transport using the dual formulation with any of these regularizers.
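def smooth_ot_dual(a, b, M, reg, reg_type='l2', method='L-BFGS-B', stopThr=1e-9, numItermax=500, verbose=False, log=False, max_nz=None):
    """
    Solve smooth optimal transport using the dual formulation and return the transport plan.
    The Regularization instance is built internally from reg and reg_type.
    Parameters:
    - a: array-like, source distribution
    - b: array-like, target distribution
    - M: array-like, cost matrix
    - reg: float, regularization strength
    - reg_type: str, regularization type: 'kl' (negative entropy), 'l2' (squared L2, default), or 'sparsity_constrained'
    - method: str, scipy.optimize.minimize method ('L-BFGS-B', 'SLSQP', etc.)
    - stopThr: float, stopping tolerance
    - numItermax: int, maximum number of iterations
    - verbose: bool, print optimization information
    - log: bool, also return an optimization log
    - max_nz: int or None, maximum number of non-zero entries per column of the plan (used only with 'sparsity_constrained')
    Returns:
    - transport plan matrix, or (plan, log) if log=True
    """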
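To make the dual formulation concrete, here is a minimal NumPy sketch restricted to the NegEntropy regularizer. It is not POT's implementation: the helper names `logsumexp` and `negentropy_dual_bcd` are hypothetical, and instead of scipy's L-BFGS-B it uses exact block-coordinate ascent. That substitution is possible only because, for negative entropy, setting each gradient block to zero has a closed-form solution (the log-domain Sinkhorn update). The plan is then read off as T_ij = exp((alpha_i + beta_j - C_ij) / gamma - 1).

```python
import numpy as np


def logsumexp(Z, axis):
    # Numerically stable log(sum(exp(Z))) along one axis
    zmax = Z.max(axis=axis, keepdims=True)
    out = zmax + np.log(np.exp(Z - zmax).sum(axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)


def negentropy_dual_bcd(a, b, C, gamma, n_iter=1000):
    """Hypothetical helper: maximize the NegEntropy smooth-OT dual
        D(alpha, beta) = alpha.a + beta.b
                         - gamma * sum_ij exp((alpha_i + beta_j - C_ij) / gamma - 1)
    by exact block-coordinate ascent (not L-BFGS-B)."""
    alpha = np.zeros(len(a))
    beta = np.zeros(len(b))
    for _ in range(n_iter):
        # dD/dalpha = a - T.sum(axis=1) = 0 is solved exactly for alpha
        alpha = gamma * (np.log(a) - logsumexp((beta[None, :] - C) / gamma - 1, axis=1))
        # dD/dbeta = b - T.sum(axis=0) = 0 is solved exactly for beta
        beta = gamma * (np.log(b) - logsumexp((alpha[:, None] - C) / gamma - 1, axis=0))
    # Plan = gradient of delta_Omega evaluated at alpha_i + beta_j - C_ij
    T = np.exp((alpha[:, None] + beta[None, :] - C) / gamma - 1)
    dual_value = alpha @ a + beta @ b - gamma * T.sum()
    return alpha, beta, T, dual_value


rng = np.random.default_rng(0)
a = np.full(4, 1 / 4)
b = np.full(3, 1 / 3)
C = rng.random((4, 3))
alpha, beta, T, dual_value = negentropy_dual_bcd(a, b, C, gamma=0.5)
print(T.sum(axis=1), T.sum(axis=0))  # close to a and b
```

At convergence both marginals of T match a and b, and the dual value equals the primal cost sum(T * C) + gamma * sum(T * log T). A generic solver such as L-BFGS-B only needs the objective and its gradient, so the same routine applies to every Regularization subclass.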
def solve_dual(a, b, C, regul, method='L-BFGS-B', tol=1e-3, max_iter=500, verbose=False):
    """
    Solve the dual smooth optimal transport problem with scipy.optimize.minimize.
    Parameters:
    - a, b: array-like, source and target distributions
    - C: array-like, cost matrix
    - regul: Regularization, regularization instance (NegEntropy, SquaredL2, SparsityConstrained)
    - method: str, scipy optimization method
    - tol: float, optimization tolerance
    - max_iter: int, maximum iterations
    - verbose: bool, print optimization info
    Returns:
    - alpha, beta: dual variables
    """

Solve optimal transport using the semi-dual formulation, which optimizes over the source dual variable alpha only.
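def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method='L-BFGS-B', stopThr=1e-9, numItermax=500, verbose=False, log=False):
    """
    Solve smooth optimal transport using the semi-dual formulation and return the transport plan.
    The Regularization instance is built internally from reg and reg_type.
    Parameters:
    - a: array-like, source distribution
    - b: array-like, target distribution
    - M: array-like, cost matrix
    - reg: float, regularization strength
    - reg_type: str, regularization type: 'kl' (negative entropy) or 'l2' (squared L2, default)
    - method: str, scipy.optimize.minimize method
    - stopThr: float, stopping tolerance
    - numItermax: int, maximum number of iterations
    - verbose: bool, print optimization information
    - log: bool, also return an optimization log
    Returns:
    - transport plan matrix, or (plan, log) if log=True
    """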
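The semi-dual eliminates beta and leaves a single variable alpha of length len(a). The NumPy sketch below, again restricted to NegEntropy and using hypothetical helper names, evaluates the semi-dual objective, its gradient a - T.sum(axis=1), and the plan whose column j is b_j times a softmax. Plain fixed-step gradient ascent stands in for the library's L-BFGS-B; the step size gamma assumes b sums to one, which makes the gradient (1/gamma)-Lipschitz.

```python
import numpy as np


def negentropy_semi_dual(alpha, a, b, C, gamma):
    """Hypothetical helper: semi-dual objective, gradient and plan for NegEntropy,
        S(alpha) = alpha.a - sum_j b_j * max_Omega_j(alpha - C[:, j]),
        max_Omega_j(x) = gamma * logsumexp(x / gamma) - gamma * log(b_j)."""
    X = (alpha[:, None] - C) / gamma
    Xmax = X.max(axis=0)
    E = np.exp(X - Xmax)  # shifted per column for numerical stability
    lse = Xmax + np.log(E.sum(axis=0))
    obj = alpha @ a - b @ (gamma * (lse - np.log(b)))
    # Gradient of max_Omega_j is a column softmax; scaling column j by b_j gives the plan
    T = E / E.sum(axis=0) * b
    grad = a - T.sum(axis=1)
    return obj, grad, T


def semi_dual_gradient_ascent(a, b, C, gamma, n_iter=2000):
    """Hypothetical helper: fixed-step gradient ascent on S(alpha) with step gamma."""
    alpha = np.zeros(len(a))
    for _ in range(n_iter):
        _, grad, _ = negentropy_semi_dual(alpha, a, b, C, gamma)
        alpha = alpha + gamma * grad
    return alpha


rng = np.random.default_rng(0)
a = np.full(4, 1 / 4)
b = np.full(3, 1 / 3)
C = rng.random((4, 3))
alpha = semi_dual_gradient_ascent(a, b, C, gamma=1.0)
obj, grad, T = negentropy_semi_dual(alpha, a, b, C, gamma=1.0)
```

Because each column of T is b_j times a probability vector, the column marginals equal b at every iterate; only the row marginals depend on convergence of alpha.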
def solve_semi_dual(a, b, C, regul, method='L-BFGS-B', tol=1e-3, max_iter=500, verbose=False):
    """
    Solve the semi-dual smooth optimal transport problem.
    Parameters:
    - a, b: array-like, source and target distributions
    - C: array-like, cost matrix
    - regul: Regularization, regularization instance
    - method: str, scipy optimization method
    - tol: float, optimization tolerance
    - max_iter: int, maximum iterations
    - verbose: bool, print optimization info
    Returns:
    - alpha: dual variable for the source
    """

Recover transport plans from the dual variables computed by the dual and semi-dual solvers.
def get_plan_from_dual(alpha, beta, C, regul):
    """
    Recover the transport plan from dual variables.
    Parameters:
    - alpha: array-like, dual variable for the source
    - beta: array-like, dual variable for the target
    - C: array-like, cost matrix
    - regul: Regularization, regularization instance
    Returns:
    - transport plan matrix
    """
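For SquaredL2, the gradient of delta_Omega has the closed form max(x, 0) / gamma, so recovering a plan from dual variables amounts to thresholding alpha_i + beta_j - C_ij. The sketch below assumes Omega(t) = gamma / 2 * ||t||^2; the helper name `plan_from_dual_l2` is hypothetical, and the potentials are hand-picked to show where zeros come from.

```python
import numpy as np


def plan_from_dual_l2(alpha, beta, C, gamma):
    """Hypothetical helper: plan recovery for Omega(t) = gamma / 2 * ||t||^2.
    The gradient of delta_Omega(x) = sup_{y >= 0} y.x - Omega(y) is max(x, 0) / gamma,
    applied entrywise to X[i, j] = alpha_i + beta_j - C_ij."""
    X = alpha[:, None] + beta[None, :] - C
    return np.maximum(X, 0.0) / gamma


alpha = np.array([0.0, 1.0])
beta = np.array([0.0, 0.5])
C = np.array([[0.2, 0.0],
              [2.0, 1.0]])
T = plan_from_dual_l2(alpha, beta, C, gamma=1.0)  # T == [[0.0, 0.5], [0.0, 0.5]]
```

Every entry where alpha_i + beta_j does not exceed C_ij is exactly zero, which is why squared-L2 plans can be sparse, while NegEntropy plans, built from exp(.), are strictly positive.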
def get_plan_from_semi_dual(alpha, b, C, regul):
    """
    Recover the transport plan from the semi-dual variable.
    Parameters:
    - alpha: array-like, dual variable for the source
    - b: array-like, target distribution
    - C: array-like, cost matrix
    - regul: Regularization, regularization instance
    Returns:
    - transport plan matrix
    """

Support functions for smooth optimal transport, including simplex projections.
def projection_simplex(V, z=1, axis=None):
    """
    Euclidean projection of V onto the simplex {y >= 0, sum(y) = z}.
    Parameters:
    - V: array-like, input array to project
    - z: float or array, total mass of the simplex
    - axis: None or int, projection axis (None: flatten, 1: row-wise, 0: column-wise)
    Returns:
    - projected array with same shape as V
    """

Base and concrete regularization classes for the different smooth optimal transport formulations.
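class Regularization:
    """
    Base class for regularization in smooth optimal transport.
    Methods:
    - delta_Omega(X): for each column x of X, compute delta_Omega(x) = sup_{y >= 0} y^T x - Omega(y); returns values and gradients
    - max_Omega(X, b): for each column x of X, compute max_Omega_j(x) = sup_{y in simplex} y^T x - Omega(b[j] y) / b[j]; returns values and gradients
    - Omega(T): compute the regularization value of a transport plan T
    """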
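delta_Omega and max_Omega are convex conjugates of the regularizer Omega, and the dual and semi-dual solvers maximize objectives built from them. The following is a sketch of those definitions in the standard smooth-OT formulation, stated as an assumption rather than quoted from the library: c_j is column j of the cost matrix C, t_j is column j of the transport plan, n = len(a), and Delta^n is the probability simplex in R^n.

```latex
\begin{gathered}
\delta_\Omega(x) = \sup_{y \ge 0} \; y^\top x - \Omega(y), \qquad
\max\nolimits_{\Omega_j}(x) = \sup_{y \in \Delta^n} \; y^\top x - \frac{1}{b_j}\,\Omega(b_j y), \\
D(\alpha, \beta) = \alpha^\top a + \beta^\top b - \sum_{j} \delta_\Omega(\alpha + \beta_j \mathbf{1} - c_j), \qquad
S(\alpha) = \alpha^\top a - \sum_{j} b_j \max\nolimits_{\Omega_j}(\alpha - c_j), \\
t_j = \nabla \delta_\Omega(\alpha + \beta_j \mathbf{1} - c_j) \ \text{(dual)}, \qquad
t_j = b_j \, \nabla \max\nolimits_{\Omega_j}(\alpha - c_j) \ \text{(semi-dual)}.
\end{gathered}
```

For example, with Omega(t) = gamma * sum t log t (NegEntropy), these give delta_Omega(x) = gamma * sum_i exp(x_i / gamma - 1) and max_Omega_j(x) = gamma * logsumexp(x / gamma) - gamma * log(b_j).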
class NegEntropy(Regularization):
    """
    Negative entropy regularization (KL divergence).
    Parameters:
    - gamma: float, regularization strength
    """
class SquaredL2(Regularization):
    """
    Squared L2 norm regularization.
    Parameters:
    - gamma: float, regularization strength
    """
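For SquaredL2, max_Omega (used by the semi-dual) requires a Euclidean projection onto the simplex, which is the role of projection_simplex. The NumPy sketch below is illustrative only: the helper names are hypothetical, the projection is the standard sort-and-threshold algorithm applied column by column (the axis=0 case), and the max_Omega formula assumes Omega(t) = gamma / 2 * ||t||^2.

```python
import numpy as np


def project_columns_to_simplex(V, z=1.0):
    """Hypothetical helper: project each column V[:, j] onto {y >= 0, sum(y) = z}
    with the sort-and-threshold algorithm (column-wise, like axis=0)."""
    n = V.shape[0]
    U = -np.sort(-V, axis=0)                # each column sorted in decreasing order
    cssv = np.cumsum(U, axis=0) - z         # partial sums minus the target mass
    ind = np.arange(1, n + 1)[:, None]
    rho = np.count_nonzero(U - cssv / ind > 0, axis=0)   # support size per column
    theta = cssv[rho - 1, np.arange(V.shape[1])] / rho   # per-column threshold
    return np.maximum(V - theta, 0.0)


def squared_l2_max_omega(X, b, gamma):
    """Hypothetical helper: max_Omega_j for Omega(t) = gamma / 2 * ||t||^2, i.e. the max
    over the simplex of y.x - gamma * b_j / 2 * ||y||^2 for each column x = X[:, j].
    The maximizer is the simplex projection of x / (gamma * b_j)."""
    G = project_columns_to_simplex(X / (gamma * b))
    val = np.sum(X * G, axis=0) - 0.5 * gamma * b * np.sum(G * G, axis=0)
    return val, G


V = np.array([[0.5], [0.3], [-0.2]])
print(project_columns_to_simplex(V).ravel())  # approximately [0.6, 0.4, 0.0]
```

Completing the square shows why the projection appears: y.x - gamma * b_j / 2 * ||y||^2 equals -gamma * b_j / 2 * ||y - x / (gamma * b_j)||^2 plus a constant, so the maximizer over the simplex is the projection of x / (gamma * b_j).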
class SparsityConstrained(Regularization):
    """
    Squared L2 regularization with a sparsity constraint, for sparse transport plans.
    Parameters:
    - max_nz: int, maximum number of non-zero entries in each column of the transport plan
    - gamma: float, regularization strength
    """

import numpy as np
import ot
import ot.smooth

# Create source and target samples, uniform weights, and a cost matrix
n, m = 100, 80
a = ot.utils.unif(n)
b = ot.utils.unif(m)
X = np.random.randn(n, 2)
Y = np.random.randn(m, 2)
C = ot.dist(X, Y)  # squared Euclidean cost by default

# Negative entropy (KL) regularization with the dual and semi-dual solvers
plan_dual = ot.smooth.smooth_ot_dual(a, b, C, reg=0.1, reg_type='kl')
plan_semi_dual = ot.smooth.smooth_ot_semi_dual(a, b, C, reg=0.1, reg_type='kl')

# Squared L2 regularization (typically gives a sparse plan)
plan_l2 = ot.smooth.smooth_ot_dual(a, b, C, reg=0.01, reg_type='l2')

# Sparsity-constrained regularization: at most max_nz non-zeros per column of the plan
plan_sparse = ot.smooth.smooth_ot_dual(a, b, C, reg=0.1, reg_type='sparsity_constrained', max_nz=2)

import ot.smooth
from ot.smooth import smooth_ot_dual, smooth_ot_semi_dual, NegEntropy, SquaredL2