Python Optimal Transport Library providing solvers for optimization problems related to Optimal Transport for signal, image processing and machine learning
The ot.lp module provides exact optimal transport solvers using linear programming approaches, primarily based on the network simplex algorithm. These methods compute the exact solution to the optimal transport problem without regularization, making them suitable for applications requiring precise transport plans.
def ot.lp.emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1, check_marginals=True):
"""
Solve the Earth Mover's Distance problem and return the optimal transport plan.
Computes the exact optimal transport plan between two discrete distributions using
the network simplex algorithm. This is the reference implementation for computing
optimal transport without regularization.
Parameters:
- a: array-like, shape (n_samples_source,)
Source distribution (histogram). Must be non-negative and have the same total mass as b (typically both sum to 1).
- b: array-like, shape (n_samples_target,)
Target distribution (histogram). Must be non-negative and have the same total mass as a.
- M: array-like, shape (n_samples_source, n_samples_target)
Ground cost matrix between source and target samples.
- numItermax: int, default=100000
Maximum number of iterations for the network simplex algorithm.
- log: bool, default=False
If True, returns optimization information including number of iterations,
convergence status, and dual variables.
- center_dual: bool, default=True
Whether to center the dual potentials with center_ot_dual.
- numThreads: int, default=1
Number of threads for parallel computation (only effective when POT is compiled with OpenMP).
- check_marginals: bool, default=True
Whether to check that a and b have the same total mass.
Returns:
- transport_plan: ndarray, shape (n_samples_source, n_samples_target)
Optimal transport plan matrix. Entry (i,j) represents the amount of mass
transported from source sample i to target sample j.
- log: dict (if log=True)
Dictionary containing optimization information:
- 'cost': float, optimal transport cost
- 'u': ndarray, source dual variables
- 'v': ndarray, target dual variables
- 'warning': str, convergence warning if any
- 'result_code': int, solver result code
"""
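For small problems, the linear program behind ot.lp.emd can be written out explicitly and passed to a generic LP solver, which makes a convenient cross-check. The sketch below is not the network simplex that ot.lp.emd uses: it assumes SciPy >= 1.6 for the HiGHS method of scipy.optimize.linprog, and exact_ot_lp is a hypothetical helper. Its returned cost is the quantity ot.lp.emd2 reports.

```python
import numpy as np
from scipy.optimize import linprog


def exact_ot_lp(a, b, M):
    """Solve min <G, M> s.t. G @ 1 = a, G.T @ 1 = b, G >= 0 with a generic LP solver.

    Hypothetical reference helper: uses SciPy's HiGHS solver, not the
    network simplex that ot.lp.emd relies on.
    """
    n, m = M.shape
    # Row-marginal constraints on the row-major flattened plan: sum_j G[i, j] = a[i]
    A_rows = np.kron(np.eye(n), np.ones((1, m)))
    # Column-marginal constraints: sum_i G[i, j] = b[j]
    A_cols = np.kron(np.ones((1, n)), np.eye(m))
    A_eq = np.vstack([A_rows, A_cols])
    b_eq = np.concatenate([a, b])
    res = linprog(M.ravel(), A_eq=A_eq, b_eq=b_eq, bounds=(0, None), method="highs")
    return res.x.reshape(n, m), res.fun


a = np.array([0.5, 0.5])
b = np.array([0.3, 0.7])
M = np.array([[0.0, 1.0], [1.0, 0.0]])
G, cost = exact_ot_lp(a, b, M)
print(G)
print(cost)
```

On this example the optimal plan keeps as much mass as possible on the zero-cost diagonal, so only 0.2 units of mass travel at cost 1.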
def ot.lp.emd2(a, b, M, processes=1, numItermax=100000, log=False, return_matrix=False, center_dual=True, numThreads=1, check_marginals=True):
"""
Solve EMD and return the optimal transport cost instead of the plan.
The same network simplex solver as emd is run, so the main saving is not having
to handle the full plan; use it when only the distance is needed.
Parameters:
- a: array-like, shape (n_samples_source,)
Source distribution (histogram).
- b: array-like, shape (n_samples_target,)
Target distribution (histogram).
- M: array-like, shape (n_samples_source, n_samples_target)
Ground cost matrix.
- processes: int, default=1
Number of processes for parallel computation (deprecated).
- numItermax: int, default=100000
Maximum number of iterations.
- log: bool, default=False
Return optimization log.
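- return_matrix: bool, default=False
If True, the transport plan is also returned, stored in the log dictionary under the key 'G'.
- center_dual: bool, default=True
Center dual potentials.
- numThreads: int, default=1
Number of threads for computation.
- check_marginals: bool, default=True
Whether to check that a and b have the same total mass.
Returns:
- cost: float
Optimal transport cost <G*, M>, where G* is the optimal plan. With a distance matrix M this is the 1-Wasserstein distance; with squared distances (the ot.dist default) it is the squared 2-Wasserstein distance.
- log: dict (if log=True or return_matrix=True)
Optimization information; contains the plan under 'G' when return_matrix=True.
"""
def ot.lp.emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1.0, dense=True, log=False):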
"""
Solve the 1D optimal transport problem with a specialized, efficient algorithm.
In 1D, when the cost is a convex function of |x - y| (as for all supported metrics),
the optimal plan is the monotone coupling obtained by sorting both samples. This
takes O((n + m) log(n + m)) time instead of solving a general linear program.
Parameters:
- x_a: array-like, shape (n_samples_source,)
Source sample positions on the real line.
- x_b: array-like, shape (n_samples_target,)
Target sample positions on the real line.
- a: array-like, shape (n_samples_source,), optional
Source sample weights. If None, assumes uniform weights.
- b: array-like, shape (n_samples_target,), optional
Target sample weights. If None, assumes uniform weights.
- metric: str, default='sqeuclidean'
Distance metric to use. Options: 'sqeuclidean', 'euclidean',
'cityblock', 'minkowski'.
- p: float, default=1.0
Exponent for the Minkowski metric (when metric='minkowski').
- dense: bool, default=True
If True, return the plan as a dense ndarray; otherwise return a sparse matrix (scipy.sparse coo_matrix format).
- log: bool, default=False
Return optimization information.
Returns:
- transport_plan: ndarray, shape (n_samples_source, n_samples_target)
Optimal 1D transport plan matrix.
- log: dict (if log=True)
Contains 'cost' and other optimization details.
"""
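The sorting argument can be made concrete with the north-west-corner rule: walk through both sorted samples and greedily move as much mass as possible between the current leftmost atoms. A minimal sketch, assuming a cost |x - y|**p with p >= 1 (p=2 matches metric='sqeuclidean', p=1 matches 'euclidean' or 'cityblock'); monotone_coupling_1d is a hypothetical helper, not POT's implementation.

```python
import numpy as np


def monotone_coupling_1d(x_a, x_b, a, b, p=2.0):
    """North-west-corner rule on sorted samples (hypothetical helper).

    Returns the monotone plan, optimal for costs |x - y|**p with p >= 1,
    together with its cost.
    """
    ia, ib = np.argsort(x_a), np.argsort(x_b)
    a_left, b_left = a[ia].astype(float), b[ib].astype(float)
    G = np.zeros((len(x_a), len(x_b)))
    i = j = 0
    while i < len(a_left) and j < len(b_left):
        # Move as much mass as possible between the current leftmost atoms
        mass = min(a_left[i], b_left[j])
        G[ia[i], ib[j]] += mass
        a_left[i] -= mass
        b_left[j] -= mass
        # Advance past whichever atom has been emptied
        if a_left[i] <= 1e-15:
            i += 1
        else:
            j += 1
    cost = np.sum(G * np.abs(x_a[:, None] - x_b[None, :]) ** p)
    return G, cost


x_a = np.array([0.0, 1.0, 2.0])
x_b = np.array([0.5, 1.5])
G, cost = monotone_coupling_1d(x_a, x_b, np.full(3, 1 / 3), np.full(2, 1 / 2))
print(G)
print(cost)
```

Every source atom is at distance 0.5 from the target atoms it is matched with, so the squared-distance cost is 0.25.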
def ot.lp.emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1.0, dense=True, log=False):
"""
Compute 1D optimal transport cost only.
Parameters: Same as emd_1d
Returns:
- cost: float
1D optimal transport cost.
- log: dict (if log=True)
"""
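In 1D the optimal cost can also be written with quantile functions (inverse CDFs): W_p^p = ∫_0^1 |F_a^{-1}(t) - F_b^{-1}(t)|^p dt. Both quantile functions are piecewise constant, so the integral reduces to a finite sum over the merged breakpoints of the two cumulative weight vectors. A minimal sketch; wasserstein_1d_pp is a hypothetical helper that illustrates the formula, not POT's code.

```python
import numpy as np


def wasserstein_1d_pp(u_values, v_values, u_weights, v_weights, p=1):
    """Return W_p**p between two weighted 1D samples using quantile functions.

    Hypothetical helper illustrating the closed form, not POT's source code.
    """
    iu, iv = np.argsort(u_values), np.argsort(v_values)
    u, v = u_values[iu], v_values[iv]
    # Normalized cumulative weights: the breakpoints of each quantile function
    cu = np.cumsum(u_weights[iu]) / np.sum(u_weights)
    cv = np.cumsum(v_weights[iv]) / np.sum(v_weights)
    # Merge breakpoints; both quantile functions are constant between them
    t = np.concatenate([[0.0], np.sort(np.concatenate([cu, cv]))])
    dt = np.diff(t)
    mid = t[:-1] + dt / 2
    # Quantile F^{-1}(s): first atom whose cumulative weight reaches s
    qu = u[np.minimum(np.searchsorted(cu, mid), len(u) - 1)]
    qv = v[np.minimum(np.searchsorted(cv, mid), len(v) - 1)]
    return np.sum(dt * np.abs(qu - qv) ** p)


x_a = np.array([0.0, 1.0, 2.0])
x_b = np.array([0.5, 1.5])
w22 = wasserstein_1d_pp(x_a, x_b, np.full(3, 1 / 3), np.full(2, 1 / 2), p=2)
print(w22)
```

Taking the p-th root of the result gives W_p itself.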
def ot.lp.wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True):
"""
Compute 1D Wasserstein distance between two distributions.
Computes the distance in closed form from the quantile functions (inverse
CDFs) of the two distributions, so no linear program is solved.
Parameters:
- u_values: array-like, shape (n_u,)
Sample positions for first distribution.
- v_values: array-like, shape (n_v,)
Sample positions for second distribution.
- u_weights: array-like, shape (n_u,), optional
Sample weights for first distribution. Default is uniform.
- v_weights: array-like, shape (n_v,), optional
Sample weights for second distribution. Default is uniform.
- p: int, default=1
Order of the Wasserstein distance (1 or 2 typically used).
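- require_sort: bool, default=True
Whether to sort the values (and their weights) internally. Set to False only if both inputs are already sorted.
Returns:
- distance: float
1D Wasserstein distance of order p raised to the power p, i.e. W_p^p (take the p-th root to obtain W_p).
"""
def ot.lp.barycenter(A, M, weights=None, solver='interior-point', verbose=False, log=False):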
"""
Compute Wasserstein barycenter using linear programming.
Solves the fixed-support barycenter problem as a single linear program: the
barycenter (Fréchet mean in Wasserstein space) lives on the same n_samples bins
as the input distributions, and one transport plan per input is optimized jointly with it.
Parameters:
- A: array-like, shape (n_samples, n_distributions)
Matrix where each column represents a discrete distribution.
- M: array-like, shape (n_samples, n_samples)
Ground cost matrix between barycenter support points.
- weights: array-like, shape (n_distributions,), optional
Weights for the barycenter computation. Default is uniform.
- solver: str, default='interior-point'
Method passed to scipy.optimize.linprog.
- verbose: bool, default=False
Print solver information.
- log: bool, default=False
Return optimization log.
Returns:
- barycenter: ndarray, shape (n_samples,)
Wasserstein barycenter distribution.
- log: dict (if log=True)
Optimization information including convergence details.
"""
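The fixed-support formulation can be sketched explicitly for small inputs: each plan G_k must have row sums equal to the barycenter x and column sums equal to the k-th input, and the objective is the weighted sum of transport costs. A minimal sketch, assuming SciPy's HiGHS solver; lp_barycenter is a hypothetical helper, and the exact LP that POT assembles may differ in its details.

```python
import numpy as np
from scipy.optimize import linprog


def lp_barycenter(A, M, weights=None):
    """Fixed-support Wasserstein barycenter as one LP (hypothetical helper).

    Variables: one n x n plan G_k per input column A[:, k] (row i = barycenter
    bin, column j = input bin), followed by the barycenter x itself.
    """
    n, K = A.shape
    w = np.full(K, 1.0 / K) if weights is None else np.asarray(weights, dtype=float)
    nv = K * n * n + n
    # Objective: sum_k w_k <M, G_k>; the barycenter entries carry no cost
    c = np.concatenate([w[k] * M.ravel() for k in range(K)] + [np.zeros(n)])
    rows, rhs = [], []
    for k in range(K):
        off = k * n * n
        for i in range(n):  # row sums of G_k equal the barycenter x
            r = np.zeros(nv)
            r[off + i * n: off + (i + 1) * n] = 1.0
            r[K * n * n + i] = -1.0
            rows.append(r)
            rhs.append(0.0)
        for j in range(n):  # column sums of G_k equal the input A[:, k]
            r = np.zeros(nv)
            r[off + j: off + n * n: n] = 1.0
            rows.append(r)
            rhs.append(A[j, k])
    res = linprog(c, A_eq=np.array(rows), b_eq=np.array(rhs),
                  bounds=(0, None), method="highs")
    return res.x[K * n * n:]


# Two Diracs at the ends of a 3-point grid, squared-distance cost
x = np.arange(3, dtype=float)
M = (x[:, None] - x[None, :]) ** 2
A = np.array([[1.0, 0.0], [0.0, 0.0], [0.0, 1.0]])
bary = lp_barycenter(A, M)
print(np.round(bary, 6))  # mass concentrates on the middle bin
```

The number of variables grows as n_distributions * n_samples^2, which is why this exact approach is limited to small supports.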
def ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=False):
"""
Compute free-support Wasserstein barycenter.
Unlike fixed-support barycenters, this method optimizes the locations of the
k support points of the barycenter; its weights b are kept fixed.
Parameters:
- measures_locations: list of arrays
List of support points for each input measure.
- measures_weights: list of arrays
List of weights for each input measure.
- X_init: array-like, shape (k, d)
Initial barycenter support points.
- b: array-like, shape (k,), optional
Barycenter weights, kept fixed during the optimization. Default is uniform.
- weights: array-like, shape (n_measures,), optional
Weights for averaging the measures.
- numItermax: int, default=100
Maximum number of iterations.
- stopThr: float, default=1e-7
Convergence threshold.
- verbose: bool, default=False
Print iteration information.
- log: bool, default=False
Return optimization log.
Returns:
- X: ndarray, shape (k, d)
Optimized barycenter support points.
- log: dict (if log=True)
"""
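A free-support barycenter under the squared Euclidean cost can be computed with a fixed-point iteration: solve an exact OT problem between the current atoms and each measure, then move every barycenter atom to the weighted average of the points it is transported to. A minimal sketch of that iteration, assuming every measure and the barycenter have the same number k of atoms with uniform weights, so each optimal plan is a permutation and scipy.optimize.linear_sum_assignment can replace a general OT solver; free_support_barycenter_uniform is a hypothetical helper, not POT's function.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def free_support_barycenter_uniform(Y_list, X_init, weights=None, n_iter=100, tol=1e-9):
    """Fixed-point iteration for a free-support W2 barycenter (hypothetical helper).

    Assumes uniform weights and k atoms everywhere, so each optimal plan is a
    permutation found by linear_sum_assignment.
    """
    K = len(Y_list)
    w = np.full(K, 1.0 / K) if weights is None else np.asarray(weights, dtype=float)
    X = np.array(X_init, dtype=float)
    for _ in range(n_iter):
        X_new = np.zeros_like(X)
        for w_k, Y in zip(w, Y_list):
            # Squared Euclidean cost between barycenter atoms and measure atoms
            C = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)
            rows, cols = linear_sum_assignment(C)
            # Each barycenter atom moves toward the atom it is matched with
            X_new[rows] += w_k * Y[cols]
        shift = np.max(np.abs(X_new - X))
        X = X_new
        if shift < tol:
            break
    return X


Y_list = [np.array([[0.0, 0.0], [1.0, 0.0]]), np.array([[0.0, 2.0], [1.0, 2.0]])]
X_bar = free_support_barycenter_uniform(Y_list, X_init=[[0.2, 0.5], [0.8, 1.5]])
print(X_bar)
```

With general weights the plans are no longer permutations, and the update divides the transported positions by the barycenter weights b instead.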
def ot.lp.generalized_free_support_barycenter(X_list, Y_list, a_list, b_list, X_init, Y_init, a=None, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=False):
"""
Compute generalized free-support barycenter for joint distributions.
Extends free-support barycenters to handle joint distributions where each
measure is defined over a product space.
Parameters:
- X_list: list of arrays
List of source support points for each measure.
- Y_list: list of arrays
List of target support points for each measure.
- a_list: list of arrays
List of source weights for each measure.
- b_list: list of arrays
List of target weights for each measure.
- X_init: array-like
Initial source barycenter support.
- Y_init: array-like
Initial target barycenter support.
- a: array-like, optional
Source barycenter weights.
- b: array-like, optional
Target barycenter weights.
- weights: array-like, optional
Measure averaging weights.
- numItermax: int, default=100
- stopThr: float, default=1e-7
- verbose: bool, default=False
- log: bool, default=False
Returns:
- X_barycenter: ndarray
- Y_barycenter: ndarray
- a_barycenter: ndarray
- b_barycenter: ndarray
- log: dict (if log=True)
"""
def ot.lp.binary_search_circle(a, b, theta, reg, maxiter=500, tol=1e-6, log=False):
"""
Solve circular optimal transport using binary search.
For optimal transport on the circle (periodic domain), specialized algorithms
can exploit the circular geometry for more efficient computation.
Parameters:
- a: array-like, shape (n,)
Source distribution on the circle.
- b: array-like, shape (n,)
Target distribution on the circle.
- theta: array-like, shape (n,)
Angular positions (in [0, 2π]) for the discretization.
- reg: float
Regularization parameter.
- maxiter: int, default=500
Maximum number of binary search iterations.
- tol: float, default=1e-6
Convergence tolerance.
- log: bool, default=False
Return optimization log.
Returns:
- transport_plan: ndarray
Optimal circular transport plan.
- log: dict (if log=True)
"""
def ot.lp.wasserstein_circle(u_weights, v_weights, u_locations=None, v_locations=None, p=1, require_sort=True):
"""
Compute Wasserstein distance on the circle.
Parameters:
- u_weights: array-like
Weights for first distribution.
- v_weights: array-like
Weights for second distribution.
- u_locations: array-like, optional
Angular locations for first distribution.
- v_locations: array-like, optional
Angular locations for second distribution.
- p: int, default=1
Order of Wasserstein distance.
- require_sort: bool, default=True
Returns:
- distance: float
Circular Wasserstein distance.
"""
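For p = 1 the circular distance has a closed form: cut the circle at an arbitrary point, write F and G for the resulting cumulative distribution functions, and W_1 = min over α of ∫ |F(θ) - G(θ) - α| dθ, where the minimizing α is a median of F - G. A minimal sketch for two histograms on n equispaced points, assuming a circle of circumference 1 (divide angles by 2π to map [0, 2π) onto this parametrization); circle_w1_grid is a hypothetical helper, whereas ot.lp.wasserstein_circle also accepts p > 1 and arbitrary locations.

```python
import numpy as np


def circle_w1_grid(u, v):
    """W1 between two histograms on n equispaced points of a circle of
    circumference 1 (hypothetical helper).

    Uses W1 = min_alpha integral |F - G - alpha|, minimized by a median of F - G.
    """
    n = len(u)
    d = np.cumsum(u) - np.cumsum(v)  # F - G on each arc [i/n, (i+1)/n)
    alpha = np.median(d)             # equal arc lengths, so a plain median
    return np.sum(np.abs(d - alpha)) / n


u = np.array([1.0, 0.0, 0.0, 0.0])  # Dirac at position 0
v = np.array([0.0, 0.0, 0.0, 1.0])  # Dirac at position 0.75
print(circle_w1_grid(u, v))  # 0.25: the mass wraps around instead of moving 0.75
```

Setting alpha = 0 instead recovers the usual W1 on the segment [0, 1), which is why the shift by the median is what accounts for the periodic geometry.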
def ot.lp.semidiscrete_wasserstein2_unif_circle(X, u_weights=None, reg=0.0):
"""
Compute semi-discrete Wasserstein distance on circle with uniform target.
Parameters:
- X: array-like
Sample locations on the circle.
- u_weights: array-like, optional
Sample weights.
- reg: float, default=0.0
Regularization parameter.
Returns:
- distance: float
Semi-discrete circular Wasserstein distance.
"""
def ot.lp.check_number_threads(numThreads):
"""
Validate and return appropriate number of threads for computation.
Parameters:
- numThreads: int
Requested number of threads.
Returns:
- validated_threads: int
Validated number of threads (clamped to available cores).
"""
def ot.lp.center_ot_dual(alpha0, beta0, a, b):
"""
Center dual potentials to avoid numerical issues.
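Dual potentials are only defined up to an additive constant: replacing (alpha, beta)
by (alpha + c, beta - c) leaves every constraint alpha_i + beta_j <= M_ij unchanged, and
also the dual objective when a and b have the same mass. Centering fixes this constant,
which keeps the potentials numerically well scaled.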
Parameters:
- alpha0: array-like
Source dual potentials.
- beta0: array-like
Target dual potentials.
- a: array-like
Source distribution.
- b: array-like
Target distribution.
Returns:
- alpha: ndarray
Centered source potentials.
- beta: ndarray
Centered target potentials.
"""
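Because (alpha + c, beta - c) is dual feasible exactly when (alpha, beta) is, the constant c can be chosen freely. A minimal sketch of one natural convention, assuming the goal is to make the weighted sums a·alpha and b·beta equal; center_duals is a hypothetical helper and POT's exact convention may differ.

```python
import numpy as np


def center_duals(alpha0, beta0, a, b):
    """Shift (alpha, beta) -> (alpha + c, beta - c) so that a @ alpha == b @ beta.

    Hypothetical helper: one possible centering convention. The shift keeps every
    alpha_i + beta_j, hence dual feasibility, unchanged, and keeps the dual
    objective a @ alpha + b @ beta when sum(a) == sum(b).
    """
    c = (b @ beta0 - a @ alpha0) / (a.sum() + b.sum())
    return alpha0 + c, beta0 - c


a = np.array([0.5, 0.5])
b = np.array([0.3, 0.7])
alpha0 = np.array([1.0, 2.0])
beta0 = np.array([-1.0, 0.5])
alpha, beta = center_duals(alpha0, beta0, a, b)
print(alpha, beta)
```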
def ot.lp.estimate_dual_null_weights(alpha0, beta0, a, b, M):
"""
Estimate feasible dual variables when distributions have null weights.
Handles the case where some entries in the source or target distributions
are zero, which can cause numerical issues in the dual formulation.
Parameters:
- alpha0: array-like
Initial source dual variables.
- beta0: array-like
Initial target dual variables.
- a: array-like
Source distribution (may contain zeros).
- b: array-like
Target distribution (may contain zeros).
- M: array-like
Cost matrix.
Returns:
- feasible_alpha: ndarray
- feasible_beta: ndarray
"""
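Dual entries attached to zero-weight bins do not enter the dual objective a·alpha + b·beta, so they can be lowered freely until every constraint alpha_i + beta_j <= M_ij holds. A minimal sketch of that repair, assuming every constraint violation involves at least one zero-weight bin; repair_null_weight_duals is a hypothetical helper, not necessarily how POT proceeds.

```python
import numpy as np


def repair_null_weight_duals(alpha0, beta0, a, b, M):
    """Lower duals of zero-weight bins until alpha_i + beta_j <= M_ij holds.

    Hypothetical sketch: assumes violations only involve rows with a_i == 0 or
    columns with b_j == 0, whose duals do not affect a @ alpha + b @ beta.
    """
    alpha, beta = alpha0.astype(float), beta0.astype(float)
    viol = alpha[:, None] + beta[None, :] - M  # > 0 where a constraint is violated
    alpha -= (a == 0) * np.maximum(viol.max(axis=1), 0.0)
    viol = alpha[:, None] + beta[None, :] - M  # recompute after the alpha update
    beta -= (b == 0) * np.maximum(viol.max(axis=0), 0.0)
    return alpha, beta


M = np.array([[0.0, 1.0], [1.0, 0.0]])
a = np.array([1.0, 0.0])  # second source bin has no mass
b = np.array([0.5, 0.5])
alpha0 = np.array([0.0, 5.0])  # alpha0[1] violates the constraints
beta0 = np.array([0.0, 1.0])
alpha, beta = repair_null_weight_duals(alpha0, beta0, a, b, M)
print(alpha, beta)
```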
def ot.lp.emd_1d_sorted(a, b, x_a, x_b, metric='sqeuclidean', p=1.0):
"""
Compute 1D EMD for pre-sorted inputs (internal function).
Optimized version when sample positions are already sorted.
Parameters:
- a: array-like
Sorted source weights.
- b: array-like
Sorted target weights.
- x_a: array-like
Sorted source positions.
- x_b: array-like
Sorted target positions.
- metric: str, default='sqeuclidean'
- p: float, default=1.0
Returns:
- cost: float
1D transport cost.
"""
def ot.lp.dmmot_monge_1dgrid_loss(a, b, alpha_a, alpha_b, log=False):
"""
Compute DMM-OT loss on 1D grid.
Deep Monge Map Optimal Transport (DMM-OT) loss computation for 1D grid data.
Parameters:
- a: array-like
Source distribution on 1D grid.
- b: array-like
Target distribution on 1D grid.
- alpha_a: float
Source regularization parameter.
- alpha_b: float
Target regularization parameter.
- log: bool, default=False
Return detailed computation log.
Returns:
- loss: float
DMM-OT loss value.
- log: dict (if log=True)
"""
def ot.lp.dmmot_monge_1dgrid_optimize(a, b, alpha_a, alpha_b, numItermax=1000, lr=1e-3, log=False):
"""
Optimize DMM-OT objective on 1D grid.
Parameters:
- a: array-like
Source distribution.
- b: array-like
Target distribution.
- alpha_a: float
Source regularization.
- alpha_b: float
Target regularization.
- numItermax: int, default=1000
Maximum optimization iterations.
- lr: float, default=1e-3
Learning rate.
- log: bool, default=False
Returns:
- optimized_map: ndarray
Optimized Monge map.
- log: dict (if log=True)
"""
import ot
import numpy as np
# Define distributions
a = np.array([0.5, 0.5]) # Source distribution
b = np.array([0.3, 0.7]) # Target distribution
# Define cost matrix
M = np.array([[0.0, 1.0],
[1.0, 0.0]])
# Compute optimal transport plan
plan = ot.lp.emd(a, b, M)
print("Transport plan:", plan)
# Compute only the cost
cost = ot.lp.emd2(a, b, M)
print("Transport cost:", cost)
# 1D samples
x_source = np.array([0.0, 1.0, 2.0])
x_target = np.array([0.5, 1.5])
# Uniform weights
a = ot.unif(3)
b = ot.unif(2)
# Compute 1D transport
plan_1d = ot.lp.emd_1d(x_source, x_target, a, b)
cost_1d = ot.lp.emd2_1d(x_source, x_target, a, b)
print("1D transport plan:", plan_1d)
print("1D transport cost:", cost_1d)
# Multiple distributions
A = np.array([[0.6, 0.2],
[0.4, 0.8]]) # 2 distributions of size 2
# Cost matrix for barycenter space
M = ot.dist(np.arange(2).reshape(-1, 1))
# Compute barycenter
barycenter = ot.lp.barycenter(A, M)
print("Barycenter:", barycenter)
The ot.lp module provides the foundation for exact optimal transport computation: it returns exact solutions, at a higher computational cost than regularized solvers such as ot.sinkhorn. These methods are the right choice when exact transport plans are required, and they serve as reference implementations for validating approximate methods.