tessl/pypi-pot

Python Optimal Transport Library providing solvers for optimization problems related to Optimal Transport for signal, image processing and machine learning


Unbalanced Optimal Transport

The ot.unbalanced module provides algorithms for unbalanced optimal transport, in which the marginal constraints are relaxed so that the source and target distributions may have different total masses. This is particularly useful for applications involving data with outliers or noise, or when comparing distributions with naturally different masses.

Core Unbalanced Methods

Sinkhorn-based Unbalanced Transport

def ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, stopThr=1e-6, verbose=False, log=False, warn=True, **kwargs):
    """
    Solve unbalanced optimal transport using Sinkhorn algorithm with KL relaxation.
    
    Solves the unbalanced optimal transport problem:
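    min_{P >= 0} <P,M> + reg*Omega(P) + reg_m1*KL(P1|a) + reg_m2*KL(P^T1|b)
    where Omega is the entropic regularization term, and the hard marginal
    constraints P1 = a, P^T1 = b are replaced by KL penalties (a scalar reg_m
    sets reg_m1 = reg_m2 = reg_m).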
    
    Parameters:
    - a: array-like, shape (n_samples_source,)
         Source distribution. Need not sum to 1.
    - b: array-like, shape (n_samples_target,)
         Target distribution. Need not sum to 1.
    - M: array-like, shape (n_samples_source, n_samples_target)
         Ground cost matrix.
    - reg: float
         Entropic regularization parameter (>0).
    - reg_m: float or tuple of floats
         Marginal relaxation parameter(s). If float, uses same value for both
         marginals. If tuple (reg_m1, reg_m2), uses different values.
    - method: str, default='sinkhorn'
         Algorithm variant. Options: 'sinkhorn', 'sinkhorn_stabilized',
         'sinkhorn_translation_invariant'
    - numItermax: int, default=1000
         Maximum number of iterations.
    - stopThr: float, default=1e-6
         Convergence threshold on marginal violation.
    - verbose: bool, default=False
         Print iteration information.
    - log: bool, default=False
         Return optimization log.
    - warn: bool, default=True
         Warn if algorithm doesn't converge.
    
    Returns:
    - transport_plan: ndarray, shape (n_samples_source, n_samples_target)
         Unbalanced optimal transport plan.
    - log: dict (if log=True)
         Contains 'err': convergence errors, 'mass_source': final source mass,
         'mass_target': final target mass, 'u': source scaling, 'v': target scaling.
    """

def ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, stopThr=1e-6, verbose=False, log=False, warn=True, **kwargs):
    """
    Solve unbalanced optimal transport and return cost only.
    
    More efficient than sinkhorn_unbalanced() when only the cost is needed.
    
    Parameters: Same as sinkhorn_unbalanced()
    
    Returns:
    - cost: float
         Unbalanced optimal transport cost.
    - log: dict (if log=True)
    """

def ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs):
    """
    Unbalanced Sinkhorn-Knopp algorithm with multiplicative updates.
    
    Classic formulation using diagonal scaling matrices for unbalanced case.
    
    Parameters: Same as sinkhorn_unbalanced()
    
    Returns:
    - transport_plan: ndarray
    - log: dict (if log=True)
    """

def ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e3, numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs):
    """
    Stabilized unbalanced Sinkhorn algorithm.
    
    Uses tau-absorption technique to prevent numerical overflow while
    handling unbalanced marginals.
    
    Parameters:
    - a, b, M, reg, reg_m: Same as sinkhorn_unbalanced()
    - tau: float, default=1e3
         Absorption threshold for numerical stability.
    - Other parameters same as sinkhorn_unbalanced()
    
    Returns:
    - transport_plan: ndarray
    - log: dict (if log=True)
    """

def ot.unbalanced.sinkhorn_unbalanced_translation_invariant(a, b, M, reg, reg_m, c=None, rescale_plan=True, numItermax=1000, stopThr=1e-6, verbose=False, log=False):
    """
    Translation-invariant unbalanced Sinkhorn algorithm.
    
    Uses a translation-invariant formulation that can be more numerically
    stable and allows for better initialization strategies.
    
    Parameters:
    - a, b, M, reg, reg_m: Same as sinkhorn_unbalanced()
    - c: array-like, optional
         Translation vector for numerical stability.
    - rescale_plan: bool, default=True
         Whether to rescale the final transport plan.
    - Other parameters same as sinkhorn_unbalanced()
    
    Returns:
    - transport_plan: ndarray
    - log: dict (if log=True)
    """

Unbalanced Barycenters

def ot.unbalanced.barycenter_unbalanced(A, M, reg, reg_m, weights=None, method='sinkhorn', numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs):
    """
    Compute unbalanced Wasserstein barycenter.
    
    Finds the barycenter that minimizes the sum of unbalanced transport costs
    to all input distributions, allowing for mass creation/destruction.
    
    Parameters:
    - A: array-like, shape (n_samples, n_distributions)
         Input distributions as columns. Need not be normalized.
    - M: array-like, shape (n_samples, n_samples)
         Ground cost matrix on barycenter support.
    - reg: float
         Entropic regularization parameter.
    - reg_m: float
         Marginal relaxation parameter.
    - weights: array-like, shape (n_distributions,), optional
         Weights for barycenter combination. Default is uniform.
    - method: str, default='sinkhorn'
         Algorithm variant for unbalanced transport computation.
    - numItermax: int, default=1000
         Maximum iterations for barycenter computation.
    - stopThr: float, default=1e-6
         Convergence threshold.
    - verbose: bool, default=False
    - log: bool, default=False
    
    Returns:
    - barycenter: ndarray, shape (n_samples,)
         Unbalanced Wasserstein barycenter (need not sum to 1).
    - log: dict (if log=True)
         Contains convergence information and transport plans.
    """

def ot.unbalanced.barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, numItermax=1000, stopThr=1e-6, verbose=False, log=False):
    """
    Compute unbalanced barycenter using Sinkhorn algorithm.
    
    Alternative implementation with explicit Sinkhorn iterations.
    
    Parameters: Same as barycenter_unbalanced()
    
    Returns:
    - barycenter: ndarray
    - log: dict (if log=True)
    """

def ot.unbalanced.barycenter_unbalanced_stabilized(A, M, reg, reg_m, tau=1e3, weights=None, numItermax=1000, stopThr=1e-6, verbose=False, log=False):
    """
    Compute unbalanced barycenter using stabilized algorithm.
    
    Parameters:
    - A, M, reg, reg_m, weights: Same as barycenter_unbalanced()
    - tau: float, default=1e3
         Stabilization parameter.
    - Other parameters same as barycenter_unbalanced()
    
    Returns:
    - barycenter: ndarray
    - log: dict (if log=True)
    """

MM Algorithm

def ot.unbalanced.mm_unbalanced(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, numItermax=1000, stopThr=1e-15, verbose=False, log=False):
    """
    Solve unbalanced optimal transport using MM (Majorization-Minimization) algorithm.
    
    Solves min_{P >= 0} <P,M> + reg_m1*div(P1, a) + reg_m2*div(P^T1, b) + reg*div(P, c)
    with multiplicative updates. Unlike the Sinkhorn solvers, it does not require
    entropic regularization (reg defaults to 0), and the divergence need not be KL.
    
    Parameters:
    - a: array-like, shape (n_samples_source,)
         Source distribution.
    - b: array-like, shape (n_samples_target,)
         Target distribution.
    - M: array-like, shape (n_samples_source, n_samples_target)
         Ground cost matrix.
    - reg_m: float or tuple
         Marginal relaxation parameter(s). Note that reg_m comes before reg.
    - c: array-like, shape (n_samples_source, n_samples_target), optional
         Reference measure for the plan regularization. If None, a b^T is used.
    - reg: float, default=0
         Weight of the plan regularization term div(P, c).
    - div: str, default='kl'
         Divergence used for the marginal penalties and the plan regularization.
         Options: 'kl', 'l2' (half-squared L2)
    - G0: array-like, optional
         Initial transport plan.
    - numItermax: int, default=1000
    - stopThr: float, default=1e-15
    - verbose: bool, default=False
    - log: bool, default=False
    
    Returns:
    - transport_plan: ndarray, shape (n_samples_source, n_samples_target)
    - log: dict (if log=True)
    """

def ot.unbalanced.mm_unbalanced2(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, numItermax=1000, stopThr=1e-15, verbose=False, log=False):
    """
    MM algorithm for unbalanced OT returning cost only.
    
    Parameters: Same as mm_unbalanced()
    
    Returns:
    - cost: float
         Unbalanced transport cost.
    - log: dict (if log=True)
    """

L-BFGS-B Methods

def ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', G0=None, numItermax=1000, stopThr=1e-15, verbose=False, log=False):
    """
    Solve unbalanced optimal transport using L-BFGS-B optimizer.
    
    Solves the primal problem
    min_{P >= 0} <P,M> + reg*div(P, c) + reg_m1*div_m(P1, a) + reg_m2*div_m(P^T1, b)
    with the L-BFGS-B quasi-Newton method, which handles the bound P >= 0
    directly. This makes it possible to use marginal divergences other than KL.
    
    Parameters:
    - a: array-like, shape (n_samples_source,)
         Source distribution.
    - b: array-like, shape (n_samples_target,)
         Target distribution.
    - M: array-like, shape (n_samples_source, n_samples_target)
         Ground cost matrix.
    - reg: float
         Weight of the plan regularization term div(P, c).
    - reg_m: float or tuple
         Marginal relaxation parameter(s).
    - c: array-like, shape (n_samples_source, n_samples_target), optional
         Reference measure for the plan regularization. If None, a b^T is used.
    - reg_div: str, default='kl'
         Divergence for the plan regularization term, e.g. 'kl' or 'l2'.
    - regm_div: str, default='kl'
         Divergence for the marginal penalties. Options: 'kl', 'l2', 'tv'
    - G0: array-like, optional
         Initial transport plan.
    - numItermax: int, default=1000
         Maximum number of iterations.
    - stopThr: float, default=1e-15
         Convergence threshold.
    - verbose: bool, default=False
    - log: bool, default=False
    
    Returns:
    - transport_plan: ndarray, shape (n_samples_source, n_samples_target)
    - log: dict (if log=True)
         Contains optimization details including L-BFGS-B convergence info.
    """

def ot.unbalanced.lbfgsb_unbalanced2(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', G0=None, numItermax=1000, stopThr=1e-15, verbose=False, log=False):
    """
    L-BFGS-B unbalanced OT returning cost only.
    
    Parameters: Same as lbfgsb_unbalanced()
    
    Returns:
    - cost: float
    - log: dict (if log=True)
    """

Regularization and Divergences

The unbalanced transport framework supports different types of marginal relaxation:

KL Divergence Relaxation

The most common choice penalizes marginal deviations with the generalized Kullback-Leibler divergence, where π₁ = P1 denotes the source marginal (row sums) of the plan P:

KL(π₁|a) = Σᵢ [π₁(i) log(π₁(i)/a(i)) - π₁(i) + a(i)]
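The generalized KL can be computed directly. The sketch below uses a hypothetical helper name and the convention 0·log 0 = 0. Under that convention, destroying all mass costs exactly the reference mass Σᵢ a(i).

```python
import numpy as np


def generalized_kl(x, y):
    """KL(x|y) = sum_i x_i*log(x_i/y_i) - x_i + y_i, with 0*log(0) taken as 0."""
    x = np.atleast_1d(np.asarray(x, dtype=float))
    y = np.atleast_1d(np.asarray(y, dtype=float))
    xlogx = np.zeros_like(x)
    pos = x > 0
    xlogx[pos] = x[pos] * np.log(x[pos] / y[pos])
    return float(np.sum(xlogx - x + y))


# A marginal that keeps half the mass of a single source point:
print(generalized_kl([0.5], [1.0]))  # 0.5*log(0.5) - 0.5 + 1, about 0.1534
# Destroying all mass costs the total reference mass:
print(generalized_kl([0.0, 0.0], [0.6, 0.4]))
```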

Alternative Divergences

  • L2 Penalty: div='l2' in mm_unbalanced (regm_div='l2' in lbfgsb_unbalanced) - half-squared penalty on marginal violations
  • Total Variation: regm_div='tv' in lbfgsb_unbalanced - L1 penalty on marginal differences
  • Custom Divergences: User-defined penalty functions
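For comparison, the two non-KL penalties can be written as plain functions. These are generic definitions with hypothetical names. POT may scale these terms differently internally.

```python
import numpy as np


def half_squared_l2(x, y):
    """0.5 * ||x - y||^2: quadratic penalty on the marginal violation."""
    d = np.asarray(x, dtype=float) - np.asarray(y, dtype=float)
    return 0.5 * float(d @ d)


def total_variation(x, y):
    """||x - y||_1: linear (L1) penalty on the marginal violation."""
    d = np.asarray(x, dtype=float) - np.asarray(y, dtype=float)
    return float(np.sum(np.abs(d)))


marginal = np.array([0.3, 0.3])   # row sums of some plan
reference = np.array([0.6, 0.4])  # source weights a
print(half_squared_l2(marginal, reference))  # about 0.05
print(total_variation(marginal, reference))  # about 0.4
```

The quadratic penalty is small for small violations and grows fast for large ones. The L1 penalty grows linearly, which tends to make it more tolerant of a few large mass discrepancies.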

Usage Examples

Basic Unbalanced Transport

import ot
import numpy as np

# Create unbalanced distributions
a = np.array([0.6, 0.4])  # Source (sums to 1.0)
b = np.array([0.2, 0.3, 0.1])  # Target (sums to 0.6)

# Cost matrix
M = np.random.rand(2, 3)

# Regularization parameters
reg = 0.1        # Entropic regularization
reg_m = 0.5      # Marginal relaxation

# Solve unbalanced transport
plan_unbalanced = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m, verbose=True)
cost_unbalanced = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg, reg_m)

print("Unbalanced transport plan:")
print(plan_unbalanced)
print(f"Unbalanced cost: {cost_unbalanced}")

# Compare the plan's marginals with the input weights (they are not enforced exactly)
source_mass = np.sum(plan_unbalanced, axis=1)
target_mass = np.sum(plan_unbalanced, axis=0)
print(f"Source masses: {source_mass} (original: {a})")
print(f"Target masses: {target_mass} (original: {b})")

Different Marginal Regularizations

# Different regularization for source and target
reg_m_source = 0.3
reg_m_target = 0.7
reg_m_tuple = (reg_m_source, reg_m_target)

plan_asym = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_tuple)
print("Asymmetric marginal regularization plan:")
print(plan_asym)

Unbalanced Barycenter

# Three histograms on a 3-point grid (one per column) with masses 1.0, 1.0 and 2.0
A = np.array([[0.6, 0.2, 0.8],
              [0.4, 0.3, 1.2],
              [0.0, 0.5, 0.0]])

# Squared Euclidean cost matrix on the barycenter support (grid points 0, 1, 2)
M_bary = ot.dist(np.arange(3, dtype=float).reshape(-1, 1))

# Compute unbalanced barycenter
reg_bary = 0.05
reg_m_bary = 0.2

barycenter = ot.unbalanced.barycenter_unbalanced(A, M_bary, reg_bary, reg_m_bary, verbose=True)

print("Unbalanced barycenter:")
print(barycenter)
print(f"Barycenter mass: {np.sum(barycenter)}")

MM Algorithm with Different Divergences

# Use half-squared L2 divergence for marginal relaxation
# (reg_m is passed by keyword: in mm_unbalanced it precedes the optional reg)
plan_mm_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2')
cost_mm_l2 = ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, div='l2')

print(f"MM L2 cost: {cost_mm_l2}")

# Use KL divergence (total variation is not a div option of mm_unbalanced)
plan_mm_kl = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl')
cost_mm_kl = ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, div='kl')

print(f"MM KL cost: {cost_mm_kl}")

Empirical Unbalanced Transport

# Generate unbalanced sample data
np.random.seed(42)
n_source, n_target = 100, 80
X_s = np.random.randn(n_source, 2)
X_t = np.random.randn(n_target, 2) + 1

# Unbalanced weights (don't sum to 1)
a_unbalanced = np.random.exponential(0.8, n_source)
b_unbalanced = np.random.exponential(1.2, n_target)

# Compute cost matrix
M_empirical = ot.dist(X_s, X_t)

# Solve unbalanced transport
reg_emp = 0.1
reg_m_emp = 0.3

plan_emp = ot.unbalanced.sinkhorn_unbalanced(a_unbalanced, b_unbalanced, M_empirical, reg_emp, reg_m_emp)
cost_emp = ot.unbalanced.sinkhorn_unbalanced2(a_unbalanced, b_unbalanced, M_empirical, reg_emp, reg_m_emp)

print(f"Empirical unbalanced cost: {cost_emp}")
print(f"Original source mass: {np.sum(a_unbalanced):.3f}")
print(f"Original target mass: {np.sum(b_unbalanced):.3f}")
print(f"Transported mass: {np.sum(plan_emp):.3f}")

Stabilized Algorithm for Extreme Cases

# Very small regularization or large costs
reg_small = 1e-4
M_large = M * 100

# Use stabilized version
plan_stable = ot.unbalanced.sinkhorn_stabilized_unbalanced(
    a, b, M_large, reg_small, reg_m, tau=1e2, verbose=True
)

print("Stabilized unbalanced transport completed")

L-BFGS-B for Large-Scale Problems

# L-BFGS-B on a larger problem; it optimizes all n_large x n_large plan entries directly
n_large = 500
a_large = np.random.exponential(1.0, n_large)
b_large = np.random.exponential(1.5, n_large)
M_large = np.random.rand(n_large, n_large)

# Use L-BFGS-B solver
plan_lbfgs = ot.unbalanced.lbfgsb_unbalanced(
    a_large, b_large, M_large, reg, reg_m, 
    numItermax=100, verbose=True
)
cost_lbfgs = ot.unbalanced.lbfgsb_unbalanced2(
    a_large, b_large, M_large, reg, reg_m
)

print(f"L-BFGS-B unbalanced cost: {cost_lbfgs}")

Applications

Comparing Unnormalized Data

Unbalanced transport is particularly useful when:

  • Comparing histograms or distributions that naturally have different total masses
  • Handling data with missing values or outliers
  • Robust matching in the presence of noise
  • Domain adaptation with different sample sizes

Mass Creation and Destruction

The relaxed marginal constraints allow:

  • Mass Creation: Transport plan can have row/column sums exceeding the original marginals
  • Mass Destruction: Transport plan can have row/column sums below the original marginals
  • Outlier Handling: Points with no good matches can have reduced mass
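The outlier effect can be seen on a toy 1-D problem. The NumPy sketch below uses hypothetical names and an entropic KL-relaxed Sinkhorn that assumes the term reg * sum(P log P - P). It places one source point far from every target. With a small reg_m almost none of that point's mass is transported. A large reg_m makes destroying mass expensive, so much more of it is pushed through.

```python
import numpy as np


def unbalanced_sinkhorn(a, b, M, reg, reg_m, num_iter=5000):
    """Entropic KL-relaxed unbalanced Sinkhorn; assumes the term reg*sum(P log P - P)."""
    K = np.exp(-M / reg)
    fi = reg_m / (reg_m + reg)
    u, v = np.ones(len(a)), np.ones(len(b))
    for _ in range(num_iter):
        u = (a / (K @ v)) ** fi
        v = (b / (K.T @ u)) ** fi
    return u[:, None] * K * v[None, :]


x_source = np.array([0.0, 1.0, 10.0])  # the point at 10.0 is an outlier
x_target = np.array([0.0, 1.0])
M = (x_source[:, None] - x_target[None, :]) ** 2
a = np.full(3, 1.0 / 3.0)
b = np.full(2, 1.0 / 2.0)

for reg_m in (1.0, 100.0):
    P = unbalanced_sinkhorn(a, b, M, reg=1.0, reg_m=reg_m)
    print(reg_m, P.sum(axis=1))  # mass leaving each source point
```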

Computational Advantages

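  • Solvers accept inputs with different total masses, a case where balanced Sinkhorn iterations cannot satisfy both marginals and never reach a feasible plan
  • The strength of the marginal penalties is tunable through reg_m: small values favor robustness to outliers, large values approach balanced transport
  • Natural handling of unnormalized weights and datasets with different total masses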

The ot.unbalanced module provides essential tools for real-world optimal transport applications where perfect mass conservation is not required or desired, offering both theoretical flexibility and computational advantages.
