Python Optimal Transport Library providing solvers for optimization problems related to Optimal Transport for signal, image processing and machine learning
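Algorithms for computing the regularization path of l2-penalized unbalanced optimal transport, i.e. the exact solution for every value of the regularization parameter, from the highly regularized end (where the transport plan is empty) down to the nearly unregularized end (where it approaches an exact optimal transport plan). Because the plan only changes direction at a finite number of breakpoints, the whole path can be stored compactly. These methods are particularly useful for understanding the trade-off between sparsity and transport cost.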
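For a fixed set of non-zero entries, the optimality condition of the penalized problem is linear in both the plan and gamma, so the flattened plan is an affine function of gamma between consecutive breakpoints. Linear interpolation between the two surrounding breakpoints therefore recovers the plan exactly, which is the idea behind evaluating the path at an arbitrary gamma. A minimal dependency-free sketch follows; `interpolate_plan` is a hypothetical helper, not the library's code, and it assumes the breakpoints are sorted in decreasing order of gamma with one flat plan stored per breakpoint.

```python
def interpolate_plan(gamma, gamma_list, plan_list):
    """Return the flat plan at `gamma` by linear interpolation along the path.

    Hypothetical helper. Assumes gamma_list is strictly decreasing and
    plan_list[k] is the flattened plan at gamma_list[k].
    """
    if gamma >= gamma_list[0]:
        return list(plan_list[0])   # above the first breakpoint the plan is constant
    if gamma <= gamma_list[-1]:
        return list(plan_list[-1])  # beyond the computed path: clamp to the last plan
    # Locate k with gamma_list[k] > gamma >= gamma_list[k + 1]
    k = 0
    while gamma_list[k + 1] > gamma:
        k += 1
    g_hi, g_lo = gamma_list[k], gamma_list[k + 1]
    w = (g_hi - gamma) / (g_hi - g_lo)  # 0 at g_hi, 1 at g_lo
    return [(1.0 - w) * p_hi + w * p_lo
            for p_hi, p_lo in zip(plan_list[k], plan_list[k + 1])]


# Usage: three breakpoints of a toy path with two plan entries
gammas = [1.0, 0.5, 0.1]
plans = [[0.0, 0.0], [0.2, 0.0], [0.4, 0.1]]
print(interpolate_plan(0.75, gammas, plans))  # halfway between the first two plans
print(interpolate_plan(0.3, gammas, plans))   # halfway between the last two plans
```

Because each segment is exactly affine, no finer sampling of gamma is needed: storing the breakpoints is enough.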
Compute the complete regularization path for optimal transport problems.
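def regularization_path(a, b, C, reg=1e-4, semi_relaxed=False, itmax=50000):
    """
    Compute the regularization path for l2-penalized unbalanced optimal transport.

    Starting from a regularization parameter gamma large enough that the
    optimal plan is zero, gamma is decreased down to `reg` and every
    breakpoint of the (piecewise linear) path is recorded. For a small `reg`
    and balanced a and b, the final plan approaches an unregularized optimal
    transport plan.

    Parameters:
    - a: array-like (n,), source distribution
    - b: array-like (m,), target distribution
    - C: array-like (n, m), cost matrix
    - reg: float, final (smallest) regularization parameter of the path
    - semi_relaxed: bool, compute the semi-relaxed path instead of the fully relaxed one
    - itmax: int, maximum number of iterations

    Returns:
    - t: array (n * m,), flattened transport plan at the end of the path
    - Pi_list: list of flattened transport plans, one per breakpoint
    - gamma_list: list of regularization parameters at the breakpoints (decreasing)
    """


def compute_transport_plan(gamma, gamma_list, Pi_list):
    """
    Compute the transport plan for a specific regularization parameter.

    Since the path is piecewise linear in gamma, the plan is obtained by
    linear interpolation between the two breakpoints that surround gamma.

    Parameters:
    - gamma: float, target regularization parameter
    - gamma_list: list, regularization parameters from the path computation
    - Pi_list: list, transport plans corresponding to gamma_list

    Returns:
    - flattened transport plan (n * m,) at gamma
    """

Specialized algorithms for different relaxation scenarios in optimal transport.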
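In the fully relaxed scenario, both marginal constraints become quadratic penalties. Written on the plan matrix T, the objective we assume is gamma * <C, T> + 0.5 * (||T 1 - a||^2 + ||T^T 1 - b||^2); in the semi-relaxed variant one of the two penalty terms is replaced by a hard equality constraint. The sketch below evaluates this objective with plain Python lists; `fully_relaxed_objective` is a hypothetical helper, and the equal 0.5 weights on the two penalties are an assumption about the formulation. Comparing an empty plan with one that matches both marginals shows how gamma trades transport cost against marginal fit.

```python
def fully_relaxed_objective(T, a, b, C, gamma):
    """gamma * <C, T> + 0.5 * (||T 1 - a||^2 + ||T^T 1 - b||^2).

    Hypothetical helper; T and C are lists of rows, a and b are lists.
    """
    n, m = len(a), len(b)
    cost = sum(C[i][j] * T[i][j] for i in range(n) for j in range(m))
    row_sums = [sum(T[i]) for i in range(n)]                       # mass leaving each source
    col_sums = [sum(T[i][j] for i in range(n)) for j in range(m)]  # mass reaching each target
    penalty = (sum((r - ai) ** 2 for r, ai in zip(row_sums, a))
               + sum((s - bj) ** 2 for s, bj in zip(col_sums, b)))
    return gamma * cost + 0.5 * penalty


a = b = [0.5, 0.5]
C = [[1.0, 2.0], [2.0, 1.0]]
T_empty = [[0.0, 0.0], [0.0, 0.0]]
T_diag = [[0.5, 0.0], [0.0, 0.5]]  # matches both marginals, transport cost 1.0
for gamma in (0.1, 10.0):
    print(gamma,
          fully_relaxed_objective(T_empty, a, b, C, gamma),
          fully_relaxed_objective(T_diag, a, b, C, gamma))
```

For gamma = 0.1 the marginal-matching plan has the lower objective; for gamma = 10 the empty plan does, which is why the path starts from an empty plan when gamma is large.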
def fully_relaxed_path(a, b, C, reg=1e-4, itmax=50000):
    """
    Compute the fully relaxed regularization path.

    Both the source and the target marginal constraints are replaced by
    quadratic penalties, and the path is followed from the empty plan down
    to gamma = reg.

    Parameters:
    - a: array-like (n,), source distribution
    - b: array-like (m,), target distribution
    - C: array-like (n, m), cost matrix
    - reg: float, final (smallest) regularization parameter
    - itmax: int, maximum number of iterations

    Returns:
    - t: flattened transport plan at the end of the path
    - Pi_list: flattened transport plans at the breakpoints
    - gamma_list: regularization parameters at the breakpoints
    """


def semi_relaxed_path(a, b, C, reg=1e-4, itmax=50000):
    """
    Compute the semi-relaxed regularization path.

    Only the target marginal constraint is relaxed (penalized) along the
    path; the source marginal constraint remains a hard constraint.

    Parameters:
    - a: array-like (n,), source distribution
    - b: array-like (m,), target distribution
    - C: array-like (n, m), cost matrix
    - reg: float, final (smallest) regularization parameter
    - itmax: int, maximum number of iterations

    Returns:
    - t: flattened transport plan at the end of the path
    - Pi_list: flattened transport plans at the breakpoints
    - gamma_list: regularization parameters at the breakpoints
    """

Functions for reformulating optimal transport problems as LASSO regression problems.
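With the plan flattened into a vector t, the row and column sums are linear in t, so the fully relaxed problem becomes min over t >= 0 of gamma * c^T t + 0.5 * ||H t - y||^2, a non-negative LASSO with a weighted linear penalty. The sketch builds H, y and c for small inputs, assuming row-major flattening (t[i * m + j] holds the mass sent from source i to target j, as NumPy's default `flatten` produces). It also computes where the path starts: at t = 0 the gradient is gamma * c - H^T y, so the empty plan is optimal exactly when gamma >= max_k (H^T y)_k / c_k, assuming every cost entry is positive. All function names are hypothetical and this is not the library's implementation.

```python
def recast_as_lasso_sketch(a, b, C):
    """Return (H, y, c) for min_{t >= 0} gamma * c^T t + 0.5 * ||H t - y||^2.

    Hypothetical helper. Assumes row-major flattening: t[i * m + j] = T[i][j].
    """
    n, m = len(a), len(b)
    H = [[0.0] * (n * m) for _ in range(n + m)]
    for i in range(n):
        for j in range(m):
            k = i * m + j
            H[i][k] = 1.0      # row i of H sums the mass leaving source i
            H[n + j][k] = 1.0  # row n + j of H sums the mass reaching target j
    y = list(a) + list(b)      # y^T = [a^T, b^T]
    c = [C[i][j] for i in range(n) for j in range(m)]
    return H, y, c


def matvec(H, t):
    """Dense matrix-vector product H @ t."""
    return [sum(h * x for h, x in zip(row, t)) for row in H]


def gamma_max(H, y, c):
    """Smallest gamma for which t = 0 is optimal: max_k (H^T y)_k / c_k.

    Assumes every cost c_k is strictly positive.
    """
    Hty = [sum(H[r][k] * y[r] for r in range(len(y))) for k in range(len(c))]
    return max(h / ck for h, ck in zip(Hty, c))


a, b = [0.5, 0.5], [0.3, 0.3, 0.4]
C = [[1.0, 2.0, 4.0], [2.0, 1.0, 2.0]]
H, y, c = recast_as_lasso_sketch(a, b, C)
t = [0.3, 0.2, 0.0, 0.0, 0.1, 0.4]  # a plan whose marginals are exactly a and b
print(matvec(H, t))                 # approximately equal to y: the penalty vanishes
print(gamma_max(H, y, c))
```

Here (H^T y)_k = a_i + b_j for the pair k = i * m + j, so the starting value is the largest ratio (a_i + b_j) / C_ij; for these toy inputs it is 0.8.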
def recast_ot_as_lasso(a, b, C):
    """
    Recast the fully relaxed optimal transport problem as a LASSO-type
    regression problem: min over t >= 0 of gamma * c^T t + 0.5 * ||H t - y||^2.

    Parameters:
    - a: array-like (n,), source distribution
    - b: array-like (m,), target distribution
    - C: array-like (n, m), cost matrix

    Returns:
    - H: array-like (n + m, n * m), 0/1 design matrix mapping a flattened plan to its row and column sums
    - y: array-like (n + m,), concatenation of a and b
    - c: array-like (n * m,), flattened cost matrix
    """


def recast_semi_relaxed_as_lasso(a, b, C):
    """
    Recast semi-relaxed optimal transport as a LASSO-type regression problem.

    Parameters:
    - a: array-like, source distribution
    - b: array-like, target distribution
    - C: array-like, cost matrix

    Returns:
    - H: array-like, design matrix for the semi-relaxed LASSO
    - y: array-like, target vector
    - c: array-like, cost vector
    """

Internal utilities for efficient path computation and updates.
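The path utilities revolve around one piece of algebra. On a fixed active set A, setting the gradient of gamma * c^T t + 0.5 * ||H t - y||^2 to zero on the active coordinates gives H_A^T H_A t_A = H_A^T y - gamma * c_A, so t_A(gamma) = phi - gamma * delta with phi = (H_A^T H_A)^{-1} H_A^T y and delta = (H_A^T H_A)^{-1} c_A. As gamma decreases, coordinate j reaches zero at gamma = phi_j / delta_j, and the next removal is the largest such value below the current gamma. The sketch assumes `phi` and `delta` have exactly this meaning; `next_removal_sketch` is a hypothetical stand-in that illustrates the role of compute_next_removal rather than reproducing it.

```python
import math


def next_removal_sketch(phi, delta, current_gamma, tol=1e-12):
    """Largest gamma in (0, current_gamma) at which phi[j] - gamma * delta[j] = 0.

    Hypothetical helper. phi and delta are the intercept and slope of the
    active coordinates on the current segment. Returns (gamma, j), or
    (-inf, None) if no active coordinate vanishes before gamma reaches 0.
    """
    best_gamma, best_j = -math.inf, None
    for j, (p, d) in enumerate(zip(phi, delta)):
        if abs(d) < tol:
            continue  # this coordinate is constant along the segment
        g = p / d     # root of p - g * d = 0
        if 0.0 < g < current_gamma and g > best_gamma:
            best_gamma, best_j = g, j
    return best_gamma, best_j


# Usage: all three coordinates are positive at gamma = 0.5
phi = [-0.1, 0.5, -0.06]
delta = [-1.0, 0.5, -0.2]
print(next_removal_sketch(phi, delta, current_gamma=0.5))  # coordinate 2 vanishes first
```

A full path step would compare this value with the gamma at which an inactive variable enters (the role described for ot_next_gamma) and move to whichever breakpoint comes first, that is, the larger of the two.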
def ot_next_gamma(phi, delta, HtH, Hty, c, active_index, current_gamma):
    """
    Compute the next regularization parameter of the fully relaxed path,
    i.e. the value of gamma at which an inactive variable enters the active set.

    Parameters:
    - phi: array-like, intercept of the current solution (t = phi - gamma * delta on the active set)
    - delta: array-like, slope of the current solution
    - HtH: array-like, Gram matrix H^T H (Hessian of the quadratic term)
    - Hty: array-like, vector H^T y
    - c: array-like, cost vector
    - active_index: list, indices of active variables
    - current_gamma: float, current regularization parameter

    Returns:
    - next_gamma: float, next regularization parameter
    """


def semi_relaxed_next_gamma(phi, delta, HtH, Hty, c, active_index, current_gamma, n):
    """
    Compute the next regularization parameter of the semi-relaxed path.

    Parameters:
    - phi: array-like, intercept of the current solution
    - delta: array-like, slope of the current solution
    - HtH: array-like, Hessian matrix
    - Hty: array-like, vector H^T y
    - c: array-like, cost vector
    - active_index: list, active variable indices
    - current_gamma: float, current regularization parameter
    - n: int, size of the source distribution

    Returns:
    - next_gamma: float, next regularization parameter
    """


def compute_next_removal(phi, delta, current_gamma):
    """
    Compute the next point of the path at which an active variable leaves
    the active set: the largest gamma below current_gamma at which an entry
    of phi - gamma * delta reaches zero.

    Parameters:
    - phi: array-like, intercept of the current solution
    - delta: array-like, slope of the current solution
    - current_gamma: float, current regularization parameter

    Returns:
    - removal_gamma: float, regularization parameter at which the variable becomes zero
    - removal_index: int, index of the variable to remove
    """

Efficient matrix operations for path computation algorithms.
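Each breakpoint changes the active set by one variable, so the inverse of H_A^T H_A can be updated instead of recomputed. When a variable is added, the bordered matrix [[M, b], [b^T, d]] has inverse [[M^{-1} + u u^T / s, -u / s], [-u^T / s, 1 / s]] with u = M^{-1} b and Schur complement s = d - b^T u. When variable p is removed from a symmetric matrix with inverse N, the inverse of the reduced matrix is obtained entrywise as N[i][j] - N[i][p] * N[p][j] / N[p][p] over the remaining indices. A pure-Python sketch of both standard block-inverse updates follows; the function names are hypothetical, and how complement_schur arranges its arguments is not assumed here.

```python
def add_variable_inverse(M_inv, b, d):
    """Inverse of the bordered matrix [[M, b], [b^T, d]] given M_inv = M^{-1}.

    Hypothetical helper; assumes M is symmetric (as H_A^T H_A is) and that
    the Schur complement s = d - b^T M^{-1} b is non-zero.
    """
    k = len(M_inv)
    u = [sum(M_inv[i][q] * b[q] for q in range(k)) for i in range(k)]  # u = M^{-1} b
    s = d - sum(b[i] * u[i] for i in range(k))                         # Schur complement
    top = [[M_inv[i][j] + u[i] * u[j] / s for j in range(k)] + [-u[i] / s]
           for i in range(k)]
    return top + [[-u[j] / s for j in range(k)] + [1.0 / s]]


def remove_variable_inverse(N, p):
    """Inverse of M with row and column p deleted, given N = M^{-1} (M symmetric)."""
    keep = [i for i in range(len(N)) if i != p]
    return [[N[i][j] - N[i][p] * N[p][j] / N[p][p] for j in keep] for i in keep]


# Usage: grow [[2]] into [[2, 1], [1, 3]], then shrink it back
grown = add_variable_inverse([[0.5]], b=[1.0], d=3.0)
print(grown)                              # inverse of [[2, 1], [1, 3]]
print(remove_variable_inverse(grown, 1))  # back to the inverse of [[2]]
```

Each update costs O(k^2) for an active set of size k, compared with O(k^3) for inverting from scratch at every breakpoint.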
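def complement_schur(M_current, b, d, id_pop):
    """
    Update the inverse of the active-set Gram matrix H_A^T H_A with the Schur
    complement when one variable is added to or removed from the active set,
    avoiding a full re-inversion.

    Parameters:
    - M_current: array-like, current inverse matrix
    - b: array-like, new off-diagonal column (when a variable is added)
    - d: float, new diagonal element (when a variable is added)
    - id_pop: int, index of the variable to remove (when a variable leaves the active set)

    Returns:
    - updated inverse matrix
    """


def construct_augmented_H(active_index, m, Hc, HrHr):
    """
    Construct the augmented design matrix used by the semi-relaxed path.

    Parameters:
    - active_index: list, active variable indices
    - m: int, number of constraints
    - Hc: array-like, constraint matrix
    - HrHr: array-like, reduced Hessian

    Returns:
    - H_aug: augmented design matrix
    """

Example: compute a regularization path on random data, evaluate the plan at a given gamma, and compare the fully relaxed and semi-relaxed paths.

import ot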
import ot.regpath
import numpy as np
import matplotlib.pyplot as plt

# Create a small random optimal transport problem
np.random.seed(0)  # reproducible point clouds
n, m = 50, 60
a = ot.utils.unif(n)
b = ot.utils.unif(m)
X = np.random.randn(n, 2)
Y = np.random.randn(m, 2)
C = ot.dist(X, Y)
C /= C.max()  # rescale costs to [0, 1] so gamma values are easier to interpret

# Compute the full (fully relaxed) regularization path; it returns the final
# plan, the plans at every breakpoint and the breakpoints themselves
t, Pi_list, gamma_list = ot.regpath.regularization_path(a, b, C, reg=1e-3)

# Fraction of non-zero entries of each (flattened) plan along the path
density = [np.sum(Pi > 1e-8) / Pi.size for Pi in Pi_list]
plt.figure(figsize=(10, 6))
plt.subplot(1, 2, 1)
plt.plot(gamma_list, density, 'b-', linewidth=2)
plt.xscale('log')  # gamma spans several orders of magnitude
plt.xlabel('Regularization parameter gamma')
plt.ylabel('Fraction of non-zero entries')
plt.title('Regularization Path')

# Transport plan at a specific gamma, interpolated along the path and
# reshaped from a flat vector of length n * m into an (n, m) matrix
gamma_target = 0.01
Pi_target = ot.regpath.compute_transport_plan(gamma_target, gamma_list, Pi_list).reshape(n, m)
print(f'Mass transported at gamma = {gamma_target}: {Pi_target.sum():.4f}')

# Compare with the semi-relaxed path
t_semi, Pi_list_semi, gamma_list_semi = ot.regpath.semi_relaxed_path(a, b, C, reg=1e-3)
density_semi = [np.sum(Pi > 1e-8) / Pi.size for Pi in Pi_list_semi]
plt.subplot(1, 2, 2)
plt.plot(gamma_list, density, 'b-', label='Fully relaxed', linewidth=2)
plt.plot(gamma_list_semi, density_semi, 'r--', label='Semi-relaxed', linewidth=2)
plt.xscale('log')
plt.xlabel('Regularization parameter gamma')
plt.ylabel('Fraction of non-zero entries')
plt.title('Path Comparison')
plt.legend()
plt.tight_layout()
plt.show()

Importing the module and its main functions:

import ot.regpath
from ot.regpath import regularization_path, fully_relaxed_path, semi_relaxed_path
from ot.regpath import compute_transport_plan