Python Optimal Transport Library providing solvers for optimization problems related to Optimal Transport for signal, image processing and machine learning
The ot.unbalanced module provides algorithms for unbalanced optimal transport, in which the marginal constraints are relaxed so that the source and target distributions may have different total masses. This is particularly useful for data containing outliers or noise, or when comparing distributions whose total masses naturally differ.
def ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, stopThr=1e-6, verbose=False, log=False, warn=True, **kwargs):
"""
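Solve unbalanced optimal transport using the Sinkhorn algorithm with KL relaxation.
Solves the unbalanced optimal transport problem:
min_{P >= 0} <P,M> + reg*Omega(P) + reg_m1*KL(P1|a) + reg_m2*KL(P^T1|b)
where Omega(P) is the entropic regularization term (e.g. the negative entropy of P),
and the hard marginal constraints P1 = a, P^T1 = b are relaxed into KL penalties
(reg_m1 = reg_m2 = reg_m when reg_m is a float).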
Parameters:
- a: array-like, shape (n_samples_source,)
Source distribution. Need not sum to 1.
- b: array-like, shape (n_samples_target,)
Target distribution. Need not sum to 1.
- M: array-like, shape (n_samples_source, n_samples_target)
Ground cost matrix.
- reg: float
Entropic regularization parameter (>0).
- reg_m: float or tuple of floats
Marginal relaxation parameter(s). If float, uses same value for both
marginals. If tuple (reg_m1, reg_m2), uses different values.
- method: str, default='sinkhorn'
Algorithm variant. Options: 'sinkhorn', 'sinkhorn_stabilized',
'sinkhorn_translation_invariant'
- numItermax: int, default=1000
Maximum number of iterations.
- stopThr: float, default=1e-6
Convergence threshold on marginal violation.
- verbose: bool, default=False
Print iteration information.
- log: bool, default=False
Return optimization log.
- warn: bool, default=True
Warn if algorithm doesn't converge.
Returns:
- transport_plan: ndarray, shape (n_samples_source, n_samples_target)
Unbalanced optimal transport plan.
- log: dict (if log=True)
Contains 'err': convergence errors, 'mass_source': final source mass,
'mass_target': final target mass, 'u': source scaling, 'v': target scaling.
"""
def ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, stopThr=1e-6, verbose=False, log=False, warn=True, **kwargs):
"""
Solve unbalanced optimal transport and return the cost instead of the plan.
Convenience variant of sinkhorn_unbalanced() for when only the cost is needed.
Parameters: Same as sinkhorn_unbalanced()
Returns:
- cost: float
Unbalanced optimal transport cost.
- log: dict (if log=True)
"""
def ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs):
"""
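Unbalanced Sinkhorn-Knopp algorithm with multiplicative updates.
Classic diagonal-scaling formulation: the plan is P = diag(u) K diag(v) with
K = exp(-M/reg) (for the negative-entropy regularizer), and the scalings are updated as
u = (a / Kv)^fi and v = (b / K^T u)^fi with fi = reg_m / (reg_m + reg); fi = 1 gives balanced Sinkhorn.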
Parameters: Same as sinkhorn_unbalanced()
Returns:
- transport_plan: ndarray
- log: dict (if log=True)
"""
def ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e3, numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs):
"""
Stabilized unbalanced Sinkhorn algorithm.
Absorbs the scaling vectors into log-domain dual potentials whenever their
entries exceed tau, which prevents numerical overflow while handling unbalanced marginals.
Parameters:
- a, b, M, reg, reg_m: Same as sinkhorn_unbalanced()
- tau: float, default=1e3
Absorption threshold for numerical stability.
- Other parameters same as sinkhorn_unbalanced()
Returns:
- transport_plan: ndarray
- log: dict (if log=True)
"""
def ot.unbalanced.sinkhorn_unbalanced_translation_invariant(a, b, M, reg, reg_m, c=None, rescale_plan=True, numItermax=1000, stopThr=1e-6, verbose=False, log=False):
"""
Translation-invariant unbalanced Sinkhorn algorithm.
Uses a translation-invariant formulation that can be more numerically
stable and allows for better initialization strategies.
Parameters:
- a, b, M, reg, reg_m: Same as sinkhorn_unbalanced()
- c: array-like, optional
Translation vector for numerical stability.
- rescale_plan: bool, default=True
Whether to rescale the final transport plan.
- Other parameters same as sinkhorn_unbalanced()
Returns:
- transport_plan: ndarray
- log: dict (if log=True)
"""def ot.unbalanced.barycenter_unbalanced(A, M, reg, reg_m, weights=None, method='sinkhorn', numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs):
"""
Compute unbalanced Wasserstein barycenter.
Finds the barycenter that minimizes the weighted sum of unbalanced transport
costs to all input distributions, allowing mass to be created or destroyed.
Parameters:
- A: array-like, shape (n_samples, n_distributions)
Input distributions as columns. Need not be normalized.
- M: array-like, shape (n_samples, n_samples)
Ground cost matrix on barycenter support.
- reg: float
Entropic regularization parameter.
- reg_m: float
Marginal relaxation parameter.
- weights: array-like, shape (n_distributions,), optional
Weights for barycenter combination. Default is uniform.
- method: str, default='sinkhorn'
Algorithm variant for unbalanced transport computation.
- numItermax: int, default=1000
Maximum iterations for barycenter computation.
- stopThr: float, default=1e-6
Convergence threshold.
- verbose: bool, default=False
- log: bool, default=False
Returns:
- barycenter: ndarray, shape (n_samples,)
Unbalanced Wasserstein barycenter (need not sum to 1).
- log: dict (if log=True)
Contains convergence information and transport plans.
"""
def ot.unbalanced.barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, numItermax=1000, stopThr=1e-6, verbose=False, log=False):
"""
Compute unbalanced barycenter using Sinkhorn algorithm.
Alternative implementation with explicit Sinkhorn iterations.
Parameters: Same as barycenter_unbalanced()
Returns:
- barycenter: ndarray
- log: dict (if log=True)
"""
def ot.unbalanced.barycenter_unbalanced_stabilized(A, M, reg, reg_m, tau=1e3, weights=None, numItermax=1000, stopThr=1e-6, verbose=False, log=False):
"""
Compute unbalanced barycenter using stabilized algorithm.
Parameters:
- A, M, reg, reg_m, weights: Same as barycenter_unbalanced()
- tau: float, default=1e3
Stabilization parameter.
- Other parameters same as barycenter_unbalanced()
Returns:
- barycenter: ndarray
- log: dict (if log=True)
"""def ot.unbalanced.mm_unbalanced(a, b, M, reg, reg_m, div='kl', G0=None, numItermax=1000, stopThr=1e-6, verbose=False, log=False):
"""
Solve unbalanced optimal transport using MM (Majorization-Minimization) algorithm.
Alternative optimization approach that can handle different divergences
for marginal relaxation beyond KL divergence.
Parameters:
- a: array-like, shape (n_samples_source,)
Source distribution.
- b: array-like, shape (n_samples_target,)
Target distribution.
- M: array-like, shape (n_samples_source, n_samples_target)
Ground cost matrix.
- reg: float
Entropic regularization parameter.
- reg_m: float or tuple
Marginal relaxation parameter(s).
- div: str, default='kl'
Divergence for marginal relaxation. Options: 'kl', 'l2', 'tv'
- G0: array-like, optional
Initial transport plan.
- numItermax: int, default=1000
- stopThr: float, default=1e-6
- verbose: bool, default=False
- log: bool, default=False
Returns:
- transport_plan: ndarray, shape (n_samples_source, n_samples_target)
- log: dict (if log=True)
"""
def ot.unbalanced.mm_unbalanced2(a, b, M, reg, reg_m, div='kl', G0=None, numItermax=1000, stopThr=1e-6, verbose=False, log=False):
"""
MM algorithm for unbalanced OT returning cost only.
Parameters: Same as mm_unbalanced()
Returns:
- cost: float
Unbalanced transport cost.
- log: dict (if log=True)
"""def ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', G0=None, numItermax=1000, numInnerItermax=10, stopThr=1e-6, stopThr2=1e-6, verbose=False, log=False):
"""
Solve unbalanced optimal transport using L-BFGS-B optimizer.
Uses the L-BFGS-B quasi-Newton method, which handles the bound constraints
of the problem, to solve unbalanced optimal transport; this can be more
efficient for large-scale problems.
Parameters:
- a: array-like, shape (n_samples_source,)
Source distribution.
- b: array-like, shape (n_samples_target,)
Target distribution.
- M: array-like, shape (n_samples_source, n_samples_target)
Ground cost matrix.
- reg: float
Entropic regularization parameter.
- reg_m: float or tuple
Marginal relaxation parameter(s).
- c: array-like, optional
Translation vector for numerical stability.
- reg_div: str, default='kl'
Divergence type for marginal regularization.
- G0: array-like, optional
Initial transport plan.
- numItermax: int, default=1000
Maximum outer iterations.
- numInnerItermax: int, default=10
Maximum inner iterations for line search.
- stopThr: float, default=1e-6
Convergence threshold for outer loop.
- stopThr2: float, default=1e-6
Convergence threshold for inner loop.
- verbose: bool, default=False
- log: bool, default=False
Returns:
- transport_plan: ndarray, shape (n_samples_source, n_samples_target)
- log: dict (if log=True)
Contains optimization details including L-BFGS-B convergence info.
"""
def ot.unbalanced.lbfgsb_unbalanced2(a, b, M, reg, reg_m, c=None, reg_div='kl', G0=None, numItermax=1000, numInnerItermax=10, stopThr=1e-6, stopThr2=1e-6, verbose=False, log=False):
"""
L-BFGS-B unbalanced OT returning cost only.
Parameters: Same as lbfgsb_unbalanced()
Returns:
- cost: float
- log: dict (if log=True)
"""The unbalanced transport framework supports different types of marginal relaxation:
The most common choice using Kullback-Leibler divergence for marginal penalties:
KL(π₁|a) = Σᵢ π₁(i) log(π₁(i)/a(i)) - π₁(i) + a(i)div='l2' - Quadratic penalty on marginal violationsdiv='tv' - L1 penalty on marginal differencesimport ot
import numpy as np
# Create unbalanced distributions
a = np.array([0.6, 0.4]) # Source (sums to 1.0)
b = np.array([0.2, 0.3, 0.1]) # Target (sums to 0.6)
# Cost matrix
M = np.random.rand(2, 3)
# Regularization parameters
reg = 0.1 # Entropic regularization
reg_m = 0.5 # Marginal relaxation
# Solve unbalanced transport
plan_unbalanced = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m, verbose=True)
cost_unbalanced = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg, reg_m)
print("Unbalanced transport plan:")
print(plan_unbalanced)
print(f"Unbalanced cost: {cost_unbalanced}")
# Compare the plan's marginals with the input weights (unbalanced OT does not enforce equality)
source_mass = np.sum(plan_unbalanced, axis=1)
target_mass = np.sum(plan_unbalanced, axis=0)
print(f"Source masses: {source_mass} (original: {a})")
print(f"Target masses: {target_mass} (original: {b})")# Different regularization for source and target
reg_m_source = 0.3
reg_m_target = 0.7
reg_m_tuple = (reg_m_source, reg_m_target)
plan_asym = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_tuple)
print("Asymmetric marginal regularization plan:")
print(plan_asym)
# Multiple unbalanced distributions
A = np.array([[0.6, 0.2, 0.8],
[0.4, 0.3, 1.2],
[0.0, 0.5, 0.0]]) # 3 distributions (columns) with masses 1.0, 1.0 and 2.0
# Cost matrix for barycenter space
M_bary = ot.dist(np.arange(3).reshape(-1, 1))
# Compute unbalanced barycenter
reg_bary = 0.05
reg_m_bary = 0.2
barycenter = ot.unbalanced.barycenter_unbalanced(A, M_bary, reg_bary, reg_m_bary, verbose=True)
print("Unbalanced barycenter:")
print(barycenter)
print(f"Barycenter mass: {np.sum(barycenter)}")# Use L2 divergence for marginal relaxation
plan_mm_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg, reg_m, div='l2')
cost_mm_l2 = ot.unbalanced.mm_unbalanced2(a, b, M, reg, reg_m, div='l2')
print(f"MM L2 cost: {cost_mm_l2}")
# Use Total Variation divergence
plan_mm_tv = ot.unbalanced.mm_unbalanced(a, b, M, reg, reg_m, div='tv')
cost_mm_tv = ot.unbalanced.mm_unbalanced2(a, b, M, reg, reg_m, div='tv')
print(f"MM TV cost: {cost_mm_tv}")# Generate unbalanced sample data
np.random.seed(42)
n_source, n_target = 100, 80
X_s = np.random.randn(n_source, 2)
X_t = np.random.randn(n_target, 2) + 1
# Unbalanced weights (don't sum to 1)
a_unbalanced = np.random.exponential(0.8, n_source)
b_unbalanced = np.random.exponential(1.2, n_target)
# Compute cost matrix
M_empirical = ot.dist(X_s, X_t)
# Solve unbalanced transport
reg_emp = 0.1
reg_m_emp = 0.3
plan_emp = ot.unbalanced.sinkhorn_unbalanced(a_unbalanced, b_unbalanced, M_empirical, reg_emp, reg_m_emp)
cost_emp = ot.unbalanced.sinkhorn_unbalanced2(a_unbalanced, b_unbalanced, M_empirical, reg_emp, reg_m_emp)
print(f"Empirical unbalanced cost: {cost_emp}")
print(f"Original source mass: {np.sum(a_unbalanced):.3f}")
print(f"Original target mass: {np.sum(b_unbalanced):.3f}")
print(f"Transported mass: {np.sum(plan_emp):.3f}")# Very small regularization or large costs
reg_small = 1e-4
M_large = M * 100
# Use stabilized version
plan_stable = ot.unbalanced.sinkhorn_stabilized_unbalanced(
a, b, M_large, reg_small, reg_m, tau=1e2, verbose=True
)
print("Stabilized unbalanced transport completed")# For larger problems, L-BFGS-B can be more efficient
n_large = 500
a_large = np.random.exponential(1.0, n_large)
b_large = np.random.exponential(1.5, n_large)
M_large = np.random.rand(n_large, n_large)
# Use L-BFGS-B solver
plan_lbfgs = ot.unbalanced.lbfgsb_unbalanced(
a_large, b_large, M_large, reg, reg_m,
numItermax=100, verbose=True
)
cost_lbfgs = ot.unbalanced.lbfgsb_unbalanced2(
a_large, b_large, M_large, reg, reg_m
)
print(f"L-BFGS-B unbalanced cost: {cost_lbfgs}")Unbalanced transport is particularly useful when:
The relaxed marginal constraints allow:
The ot.unbalanced module provides essential tools for real-world optimal transport applications where perfect mass conservation is not required or desired, offering both theoretical flexibility and computational advantages.