
tessl/pypi-pot

Python Optimal Transport Library providing solvers for optimization problems related to Optimal Transport for signal, image processing and machine learning


Weak Optimal Transport

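Weak optimal transport is a relaxed formulation of the classical optimal transport problem. The cost depends on the whole distribution of mass that each source point sends, not on individual source-target pairs. POT implements the barycentric variant. Here the transport plan minimizes the weighted squared distance between each source sample and the barycenter of the target samples it is coupled with, so the spread of that mass around the barycenter is not penalized. This makes it useful when only the average destination of each source point matters, for example when the result is read as a deformation of the source cloud.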

Capabilities

Weak Optimal Transport Solver

Solve the weak optimal transport problem between empirical distributions.

def weak_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, G0=None, **kwargs):
    """
    Solve weak optimal transport problem between two empirical distributions.
    
    The weak OT problem minimizes the weighted squared distance between each
    source sample and the barycenter of the target mass it sends:
    γ = argmin_γ Σ_i a_i ||X^a_i - (1/a_i) Σ_j γ_ij X^b_j||²
    
    subject to standard transport constraints:
    - γ 1 = a (source marginal constraint)
    - γ^T 1 = b (target marginal constraint)  
    - γ ≥ 0 (non-negativity)
    
    Parameters:
    - Xa: array-like, shape (n_samples_a, n_features), source samples
    - Xb: array-like, shape (n_samples_b, n_features), target samples
    - a: array-like, shape (n_samples_a,), source distribution (uniform if None)
    - b: array-like, shape (n_samples_b,), target distribution (uniform if None)
    - verbose: bool, print optimization information
    - log: bool, return optimization log
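    - G0: array-like, shape (n_samples_a, n_samples_b), initial transport plan
      (None uses the independent coupling a bᵀ, which is uniform when a and b are)
    - **kwargs: extra options forwarded to the conditional-gradient solver
      ot.optim.cg, e.g. numItermax, stopThr, stopThr2

    Returns:
    - transport plan of shape (n_samples_a, n_samples_b), or (plan, log) if log=True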
    """
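The objective and its gradient are short to write down directly. The sketch below is a NumPy transcription of the formula above. Writing barycenter_i = (1/a_i) Σ_j γ_ij X^b_j, the gradient is ∂f/∂γ_ij = -2 (X^a_i - barycenter_i)·X^b_j, which is the quantity a conditional-gradient solver linearizes at each step. The names weak_ot_loss and weak_ot_grad are hypothetical helpers, not part of POT's API.

```python
import numpy as np


def weak_ot_loss(G, Xa, Xb, a):
    """Barycentric weak OT objective: sum_i a_i * ||Xa_i - (1/a_i) sum_j G_ij Xb_j||^2."""
    bary = G @ Xb / a[:, None]  # barycenter of the target mass sent from each source point
    return float(np.dot(a, np.sum((Xa - bary) ** 2, axis=1)))


def weak_ot_grad(G, Xa, Xb, a):
    """Gradient of weak_ot_loss with respect to G, shape (n_a, n_b)."""
    bary = G @ Xb / a[:, None]
    return -2.0 * (Xa - bary) @ Xb.T  # the a_i factor cancels with 1/a_i


# Evaluate the loss at the independent coupling, the default starting plan
rng = np.random.default_rng(0)
Xa = rng.normal(size=(5, 2))
Xb = rng.normal(size=(7, 2))
a = np.full(5, 1 / 5)
b = np.full(7, 1 / 7)
G0 = np.outer(a, b)
print(weak_ot_loss(G0, Xa, Xb, a))
```

At the independent coupling every source point has the same barycenter, the weighted mean of Xb. The loss is then simply the weighted squared distance of the source samples to that mean.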

Theory and Applications

Weak vs Classical Optimal Transport

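Classical Optimal Transport:

  • Minimizes the total transport cost: Σ_ij γ_ij C_ij
  • With the squared Euclidean cost C_ij = ||X^a_i - X^b_j||² (the default of ot.dist), every unit of mass pays for its own displacement, so splitting a source point's mass across distant targets is penalized
  • Matches mass to actual target points, which can distort local neighborhoods when the two shapes differ

Weak Optimal Transport:

  • Minimizes the barycentric cost: Σ_i a_i ||X^a_i - barycenter_i||², where barycenter_i = (1/a_i) Σ_j γ_ij X^b_j
  • Ignores how the mass of each source point is spread around its barycenter, so its optimal value never exceeds the classical value with squared Euclidean cost
  • May preserve local neighborhood structure better, since each source point is moved to a single location
  • Suited to shape matching and morphing, where one target location per source point is what is needed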
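The relationship between the two objectives can be checked numerically. For any plan whose rows sum to a, the classical cost with squared Euclidean ground cost equals the weak (barycentric) cost plus the spread of each source point's mass around its barycenter. The spread is nonnegative, so the weak objective never exceeds the classical one for the same plan, and the weak optimum is at most the classical optimum. The sketch below uses plain NumPy on random data rather than POT.

```python
import numpy as np

rng = np.random.default_rng(0)
n_a, n_b, d = 6, 8, 2
Xa = rng.normal(size=(n_a, d))
Xb = rng.normal(size=(n_b, d))
a = np.full(n_a, 1 / n_a)

# A nonnegative plan whose rows sum to a (column marginals do not matter for the identity)
G = rng.random((n_a, n_b))
G *= (a / G.sum(axis=1))[:, None]

# Classical cost with squared Euclidean ground cost: sum_ij G_ij ||Xa_i - Xb_j||^2
M = ((Xa[:, None, :] - Xb[None, :, :]) ** 2).sum(axis=-1)
classical = np.sum(G * M)

# Weak (barycentric) cost: sum_i a_i ||Xa_i - bary_i||^2
bary = G @ Xb / a[:, None]
weak = np.dot(a, ((Xa - bary) ** 2).sum(axis=1))

# Spread of the mass sent from each source point around its barycenter
spread = np.sum(G * ((bary[:, None, :] - Xb[None, :, :]) ** 2).sum(axis=-1))

print(np.isclose(classical, weak + spread))  # True
```

The cross term vanishes because Σ_j G_ij (bary_i - X^b_j) = a_i bary_i - a_i bary_i = 0 whenever row i sums to a_i.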

Key Properties

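  1. Local Structure Preservation: The solution is naturally read as a deformation of the source cloud (X^a_i ↦ barycenter_i), which can help maintain local relationships in the source space
  2. Barycentric Transport: The cost depends only on the barycenter of the target mass sent from each source point, not on the individual source-target pairings
  3. No Variance Penalty: The spread of the transported mass around each barycenter is not penalized, so optimal plans need not be unique or sparse
  4. Conditional Gradient: The objective is a convex quadratic in γ, and POT solves it with a Frank-Wolfe (conditional gradient) algorithm whose linearized subproblems are classical OT problems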
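To make the conditional-gradient idea concrete, here is a toy Frank-Wolfe solver. It is not POT's implementation: ot.optim.cg solves each linearized subproblem with an exact OT solver and uses its own line search. This sketch assumes equal sample counts and uniform weights, so the linearized subproblem is minimized by a permutation matrix scaled by 1/n, which it finds by brute force. That only works for a handful of points. Because the objective is quadratic along each search direction, the step size comes from an exact line search. The function names are hypothetical.

```python
import itertools

import numpy as np


def weak_loss(G, Xa, Xb, a):
    """Barycentric weak OT objective, same formula as the docstring above."""
    bary = G @ Xb / a[:, None]
    return float(np.dot(a, ((Xa - bary) ** 2).sum(axis=1)))


def lmo_uniform(grad):
    """argmin_S <grad, S> over couplings with uniform marginals 1/n (equal sizes).

    A minimizer is a permutation matrix scaled by 1/n (Birkhoff-von Neumann),
    so brute force over permutations works, but only for tiny n."""
    n = grad.shape[0]
    rows = np.arange(n)
    best = min(itertools.permutations(range(n)),
               key=lambda p: grad[rows, list(p)].sum())
    S = np.zeros_like(grad)
    S[rows, list(best)] = 1.0 / n
    return S


def frank_wolfe_weak_ot(Xa, Xb, n_iter=100):
    """Toy conditional-gradient solver for equal-size, uniform-weight weak OT."""
    n = Xa.shape[0]
    a = np.full(n, 1.0 / n)
    G = np.outer(a, a)                          # independent coupling: feasible start
    for _ in range(n_iter):
        resid = Xa - G @ Xb / a[:, None]        # Xa_i minus the barycenter of row i
        grad = -2.0 * resid @ Xb.T              # gradient of weak_loss w.r.t. G
        D = lmo_uniform(grad) - G               # Frank-Wolfe direction
        slope = np.sum(grad * D)                # directional derivative, <= 0
        if slope > -1e-15:
            break                               # Frank-Wolfe gap ~ 0: (near) optimal
        curv = np.sum(((D @ Xb) ** 2).sum(axis=1) / a)
        # f is quadratic along D: f(G + s D) = f(G) + s * slope + s**2 * curv
        step = 1.0 if curv <= 0 else min(1.0, -slope / (2.0 * curv))
        G = G + step * D                        # convex combination: stays feasible
    return G


rng = np.random.default_rng(0)
Xa = rng.normal(size=(5, 2))
Xb = rng.normal(size=(5, 2)) + np.array([3.0, 0.0])
a = np.full(5, 0.2)
G = frank_wolfe_weak_ot(Xa, Xb)
print(G.sum(axis=1), G.sum(axis=0))             # both marginals stay at 0.2
print(weak_loss(np.outer(a, a), Xa, Xb, a), weak_loss(G, Xa, Xb, a))
```

Every iterate is a convex combination of feasible plans, so the marginal and non-negativity constraints hold throughout without any projection step. This is the main appeal of conditional gradient on the transport polytope.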

Usage Examples

Basic Weak Transport

import ot
import numpy as np
import matplotlib.pyplot as plt

# Create 2D point clouds
n_source, n_target = 100, 120
np.random.seed(42)

# Source: circle
theta_s = np.linspace(0, 2*np.pi, n_source)
Xa = np.column_stack([np.cos(theta_s), np.sin(theta_s)])
Xa += 0.1 * np.random.randn(n_source, 2)  # Add noise

# Target: ellipse  
theta_t = np.linspace(0, 2*np.pi, n_target)
Xb = np.column_stack([2*np.cos(theta_t), 0.5*np.sin(theta_t)])
Xb += 0.1 * np.random.randn(n_target, 2)

# Solve weak optimal transport
plan_weak = ot.weak_optimal_transport(Xa, Xb, verbose=True, log=False)

print(f"Transport plan shape: {plan_weak.shape}")
print(f"Plan sum: {np.sum(plan_weak):.6f}")
print(f"Source marginal error: {np.max(np.abs(np.sum(plan_weak, axis=1) - 1/n_source)):.6f}")
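The basic example only checks the source marginal. The hypothetical helper below is not part of POT. It checks all three constraints from the docstring and counts the nonzero entries of a plan, which is useful for comparing solvers. A classical plan returned by the network simplex solver ot.emd has at most n_source + n_target - 1 nonzero entries. A conditional-gradient solution mixes several such plans and can be denser.

```python
import numpy as np


def coupling_report(plan, a, b):
    """Check the three constraints of the docstring and measure the support of a plan."""
    plan = np.asarray(plan, dtype=float)
    return {
        "row_err": float(np.max(np.abs(plan.sum(axis=1) - a))),       # gamma 1 = a
        "col_err": float(np.max(np.abs(plan.sum(axis=0) - b))),       # gamma^T 1 = b
        "min_entry": float(plan.min()),                               # gamma >= 0
        "support": int(np.count_nonzero(plan > 1e-12 * plan.max())),  # numerically nonzero entries
    }


# Example on the independent coupling a b^T, which satisfies all constraints
a = np.full(4, 0.25)
b = np.full(5, 0.2)
print(coupling_report(np.outer(a, b), a, b))
# With the example above: coupling_report(plan_weak, ot.utils.unif(n_source), ot.utils.unif(n_target))
```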

Comparison with Classical Transport

# Compare weak vs classical optimal transport
a = ot.utils.unif(n_source)
b = ot.utils.unif(n_target)

# Classical transport
M = ot.dist(Xa, Xb)
plan_classical = ot.emd(a, b, M)

# Weak transport
plan_weak = ot.weak_optimal_transport(Xa, Xb, a, b)

# Visualize differences
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Source and target
axes[0].scatter(Xa[:, 0], Xa[:, 1], c='blue', alpha=0.6, label='Source')
axes[0].scatter(Xb[:, 0], Xb[:, 1], c='red', alpha=0.6, label='Target')
axes[0].set_title('Source and Target')
axes[0].legend()

# Classical and weak transport visualization. Plan entries are at most
# min(1/n_source, 1/n_target), so an absolute threshold of 0.01 would hide every
# line; the threshold and the line opacity are taken relative to each plan's max.
def plot_plan(ax, plan, color, title, step=5, rel_thr=0.1):
    pmax = plan.max()
    for i in range(0, n_source, step):  # Show subset of connections
        for j in np.flatnonzero(plan[i] > rel_thr * pmax):
            ax.plot([Xa[i, 0], Xb[j, 0]], [Xa[i, 1], Xb[j, 1]],
                    color=color, alpha=float(plan[i, j] / pmax), linewidth=0.5)
    ax.scatter(Xa[:, 0], Xa[:, 1], c='blue', s=20)
    ax.scatter(Xb[:, 0], Xb[:, 1], c='red', s=20)
    ax.set_title(title)

plot_plan(axes[1], plan_classical, 'k', 'Classical OT')
plot_plan(axes[2], plan_weak, 'g', 'Weak OT')

plt.tight_layout()
plt.show()

Shape Morphing Application

# Use weak transport for shape morphing
def barycentric_map(Xa, Xb, plan):
    """Map each source point to the barycenter of the target mass it sends."""
    row_mass = plan.sum(axis=1, keepdims=True)
    barycenters = Xa.copy()  # points that send no mass stay in place
    moved = row_mass[:, 0] > 0
    barycenters[moved] = plan[moved] @ Xb / row_mass[moved]
    return barycenters


def interpolate_shapes(Xa, barycenters, t=0.5):
    """Linear interpolation between the source points and their barycenters."""
    return (1 - t) * Xa + t * barycenters


# Solve weak OT once and reuse the barycenters for every time step.
# At t = 1 each point reaches its barycenter, which lies in the convex hull
# of Xb rather than exactly on a target sample.
plan = ot.weak_optimal_transport(Xa, Xb)
barycenters = barycentric_map(Xa, Xb, plan)

n_steps = 10
morphing_sequence = [interpolate_shapes(Xa, barycenters, i / n_steps)
                     for i in range(n_steps + 1)]

# Visualize morphing: 11 shapes on a 2 x 6 grid, the unused panel is hidden
fig, axes = plt.subplots(2, 6, figsize=(18, 6))
axes = axes.flatten()

for i, shape in enumerate(morphing_sequence):
    axes[i].scatter(shape[:, 0], shape[:, 1], c='purple', alpha=0.7, s=20)
    axes[i].set_title(f't = {i / n_steps:.1f}')
    axes[i].set_aspect('equal')
    axes[i].grid(True, alpha=0.3)
for ax in axes[len(morphing_sequence):]:
    ax.axis('off')

plt.tight_layout()
plt.show()

Advanced Usage with Custom Parameters

# Advanced usage with custom initialization and logging
import time

a = ot.utils.unif(n_source)
b = ot.utils.unif(n_target)

# G0 should be a feasible coupling (rows sum to a, columns sum to b, entries >= 0),
# because conditional gradient only moves along convex combinations of feasible
# plans. A nearest-neighbour matrix with rows summing to 1 is not feasible; the
# classical OT plan is a convenient feasible warm start.
G0 = ot.emd(a, b, ot.dist(Xa, Xb))

# Solve with custom initialization and detailed logging;
# numItermax and stopThr are forwarded to the conditional-gradient solver
start_time = time.time()
plan, log = ot.weak_optimal_transport(
    Xa, Xb,
    a=a,
    b=b,
    G0=G0,
    verbose=True,
    log=True,
    numItermax=1000,
    stopThr=1e-9
)
solve_time = time.time() - start_time

print(f"Solver completed in {solve_time:.3f} seconds")
print(f"Final objective: {log['loss'][-1]:.6f}")
print(f"Number of iterations: {len(log['loss']) - 1}")  # log['loss'] also stores the initial value

# Plot convergence
plt.figure(figsize=(10, 6))
plt.subplot(1, 2, 1)
plt.semilogy(log['loss'])
plt.xlabel('Iteration')
plt.ylabel('Objective value')
plt.title('Convergence of Weak OT')
plt.grid(True)

plt.subplot(1, 2, 2)
plt.imshow(plan, cmap='Blues', aspect='auto')
plt.colorbar()
plt.xlabel('Target samples')
plt.ylabel('Source samples')
plt.title('Transport Plan')
plt.tight_layout()
plt.show()
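The stopThr argument above is forwarded to POT's conditional-gradient solver. Its documentation describes stopThr as a threshold on the relative variation of the loss, with stopThr2 as the absolute counterpart. The hypothetical helper below recomputes both variations from log['loss'] so you can see which criterion ended a run. It assumes the relative variation is |f_k - f_{k-1}| / |f_k|, which is our reading of the solver rather than a documented guarantee.

```python
import numpy as np


def loss_variations(losses):
    """Absolute and relative change of the loss between consecutive iterations.

    The relative change is |f_k - f_{k-1}| / |f_k| (NaN where f_k == 0)."""
    losses = np.asarray(losses, dtype=float)
    abs_delta = np.abs(np.diff(losses))
    rel_delta = np.full_like(abs_delta, np.nan)
    np.divide(abs_delta, np.abs(losses[1:]), out=rel_delta, where=losses[1:] != 0)
    return abs_delta, rel_delta


# With the advanced example above: abs_d, rel_d = loss_variations(log['loss'])
# Illustrative values only (not solver output):
abs_d, rel_d = loss_variations([1.0, 0.5, 0.25, 0.25])
print(abs_d, rel_d)
```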

Import Statements

import ot
from ot import weak_optimal_transport
from ot.weak import weak_optimal_transport

Install with Tessl CLI

npx tessl i tessl/pypi-pot
