Python Optimal Transport Library providing solvers for optimization problems related to Optimal Transport for signal, image processing and machine learning
—
Quality
Pending
Does it follow best practices?
Impact
Pending
No eval scenarios have been run
Weak optimal transport provides a relaxed formulation of the classical optimal transport problem where the transport plan minimizes displacement variance rather than total transport cost. This approach is particularly useful for applications where preserving local structure is more important than minimizing global transport costs.
Solve the weak optimal transport problem between empirical distributions.
def weak_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, G0=None, **kwargs):
"""
Solve weak optimal transport problem between two empirical distributions.
The weak OT problem minimizes the displacement variance:
γ = argmin_γ Σ_i a_i (X^a_i - (1/a_i) Σ_j γ_ij X^b_j)²
subject to standard transport constraints:
- γ 1 = a (source marginal constraint)
- γ^T 1 = b (target marginal constraint)
- γ ≥ 0 (non-negativity)
Parameters:
- Xa: array-like, shape (n_samples_a, n_features), source samples
- Xb: array-like, shape (n_samples_b, n_features), target samples
- a: array-like, shape (n_samples_a,), source distribution (uniform if None)
- b: array-like, shape (n_samples_b,), target distribution (uniform if None)
- verbose: bool, print optimization information
- log: bool, return optimization log
- G0: array-like, shape (n_samples_a, n_samples_b), initial transport plan (None starts from the independent coupling a bᵀ, which is uniform when a and b are uniform)
Returns:
- transport plan matrix or (plan, log) if log=True
"""

Classical Optimal Transport:
Σ_ij γ_ij C_ij

Weak Optimal Transport:
Σ_i a_i ||X^a_i - barycenter_i||²

import ot
import numpy as np
import matplotlib.pyplot as plt
# Create 2D point clouds
n_source, n_target = 100, 120
np.random.seed(42)

# Source: circle
# endpoint=False: 0 and 2*pi are the same angle, so including both would
# sample that point of the closed curve twice.
theta_s = np.linspace(0, 2*np.pi, n_source, endpoint=False)
Xa = np.column_stack([np.cos(theta_s), np.sin(theta_s)])
Xa += 0.1 * np.random.randn(n_source, 2)  # Add noise

# Target: ellipse (same closed-curve sampling as the source)
theta_t = np.linspace(0, 2*np.pi, n_target, endpoint=False)
Xb = np.column_stack([2*np.cos(theta_t), 0.5*np.sin(theta_t)])
Xb += 0.1 * np.random.randn(n_target, 2)

# Solve weak optimal transport (a and b are omitted, so uniform weights are used)
plan_weak = ot.weak_optimal_transport(Xa, Xb, verbose=True, log=False)

print(f"Transport plan shape: {plan_weak.shape}")
print(f"Plan sum: {np.sum(plan_weak):.6f}")
# With uniform source weights every row of the plan should sum to 1/n_source.
print(f"Source marginal error: {np.max(np.abs(np.sum(plan_weak, axis=1) - 1/n_source)):.6f}")

# Compare weak vs classical optimal transport
a = ot.utils.unif(n_source)
b = ot.utils.unif(n_target)

# Classical transport
M = ot.dist(Xa, Xb)
plan_classical = ot.emd(a, b, M)

# Weak transport
plan_weak = ot.weak_optimal_transport(Xa, Xb, a, b)

# Visualize differences
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Source and target
axes[0].scatter(Xa[:, 0], Xa[:, 1], c='blue', alpha=0.6, label='Source')
axes[0].scatter(Xb[:, 0], Xb[:, 1], c='red', alpha=0.6, label='Target')
axes[0].set_title('Source and Target')
axes[0].legend()

# Every plan entry is bounded by min(a_i, b_j) = 1/n_target (about 0.0083
# here), so an absolute cut-off of 0.01 would never draw a single line.
# Threshold and opacity are therefore taken relative to the largest entry
# of each plan.
rel_threshold = 0.05
panels = (
    (axes[1], plan_classical, 'k-', 'Classical OT'),
    (axes[2], plan_weak, 'g-', 'Weak OT'),
)
for ax, panel_plan, line_style, title in panels:
    peak = np.max(panel_plan)
    for i in range(0, n_source, 5):  # Show subset of connections
        for j in range(n_target):
            if panel_plan[i, j] > rel_threshold * peak:
                ax.plot([Xa[i, 0], Xb[j, 0]], [Xa[i, 1], Xb[j, 1]],
                        line_style, alpha=float(panel_plan[i, j] / peak),
                        linewidth=0.5)
    ax.scatter(Xa[:, 0], Xa[:, 1], c='blue', s=20)
    ax.scatter(Xb[:, 0], Xb[:, 1], c='red', s=20)
    ax.set_title(title)

plt.tight_layout()
plt.show()

# Use weak transport for shape morphing
def interpolate_shapes(Xa, Xb, t=0.5, plan=None):
    """Interpolate between shapes using weak transport.

    Each source point moves on a straight line towards the barycenter of the
    target points it is coupled with in the transport plan.

    Parameters:
    - Xa: array-like, shape (n_samples_a, n_features), source samples
    - Xb: array-like, shape (n_samples_b, n_features), target samples
    - t: float, interpolation parameter (0 returns Xa, 1 returns the barycenters)
    - plan: array-like, shape (n_samples_a, n_samples_b), optional transport
      plan to reuse; when None it is computed with ot.weak_optimal_transport

    Returns:
    - array, shape (n_samples_a, n_features), interpolated samples
    """
    if plan is None:
        plan = ot.weak_optimal_transport(Xa, Xb)

    # Work in floating point so that integer-valued inputs do not truncate
    # the barycenters.
    Xa = np.asarray(Xa, dtype=float)
    Xb = np.asarray(Xb, dtype=float)
    plan = np.asarray(plan, dtype=float)

    # Compute barycenters for each source point; points whose row carries no
    # mass are not transported and keep their own position.
    row_mass = plan.sum(axis=1)
    has_mass = row_mass > 0
    barycenters = Xa.copy()
    barycenters[has_mass] = (plan[has_mass] @ Xb) / row_mass[has_mass, None]

    # Linear interpolation
    return (1 - t) * Xa + t * barycenters
# Create morphing sequence
n_steps = 10

# The transport plan does not depend on t, so the weak OT problem is solved
# only once: t = 1 returns the barycentric target of every source point, and
# each intermediate shape is a linear blend between Xa and those targets.
barycentric_targets = interpolate_shapes(Xa, Xb, 1.0)

morphing_sequence = []
for i in range(n_steps + 1):
    t = i / n_steps
    shape_t = (1 - t) * Xa + t * barycentric_targets
    morphing_sequence.append(shape_t)

# Visualize morphing
fig, axes = plt.subplots(2, 6, figsize=(18, 6))
axes = axes.flatten()

for i, (ax, shape) in enumerate(zip(axes, morphing_sequence)):  # Show every shape
    ax.scatter(shape[:, 0], shape[:, 1], c='purple', alpha=0.7, s=20)
    ax.set_title(f't = {i/(len(morphing_sequence)-1):.1f}')
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)

# The grid has more panels than there are shapes; hide the unused ones.
for ax in axes[len(morphing_sequence):]:
    ax.set_visible(False)

plt.tight_layout()
plt.show()

# Advanced usage with custom initialization and logging
import time

# Custom initial transport plan (e.g., based on nearest neighbors)
from sklearn.neighbors import NearestNeighbors

nn = NearestNeighbors(n_neighbors=3)
nn.fit(Xb)
distances, indices = nn.kneighbors(Xa)

# Create sparse initialization
# Each source point spreads its own mass a_i = 1/n_source evenly over its
# nearest targets, so the rows sum to the uniform source marginal and the
# whole plan has total mass 1 (rows summing to 1 would give mass n_source).
# NOTE(review): the column sums still do not match b in general - confirm
# that the solver accepts an initial plan that is only row-feasible.
G0 = np.zeros((n_source, n_target))
neighbor_mass = 1.0 / (n_source * indices.shape[1])
for i in range(n_source):
    G0[i, indices[i]] = neighbor_mass

# Solve with custom initialization and detailed logging
start_time = time.time()
plan, log = ot.weak_optimal_transport(
    Xa, Xb,
    a=ot.utils.unif(n_source),
    b=ot.utils.unif(n_target),
    G0=G0,
    verbose=True,
    log=True,
    numItermax=1000,
    stopThr=1e-9
)
solve_time = time.time() - start_time

print(f"Solver completed in {solve_time:.3f} seconds")
print(f"Final objective: {log['loss'][-1]:.6f}")
print(f"Number of iterations: {len(log['loss'])}")

# Plot convergence
plt.figure(figsize=(10, 6))

# Left panel: objective value per iteration on a logarithmic scale
plt.subplot(1, 2, 1)
plt.semilogy(log['loss'])
plt.xlabel('Iteration')
plt.ylabel('Objective value')
plt.title('Convergence of Weak OT')
plt.grid(True)

# Right panel: the transport plan as a heat map (rows = source, columns = target)
plt.subplot(1, 2, 2)
plt.imshow(plan, cmap='Blues', aspect='auto')
plt.colorbar()
plt.xlabel('Target samples')
plt.ylabel('Source samples')
plt.title('Transport Plan')

plt.tight_layout()
plt.show()

import ot
from ot import weak_optimal_transport
from ot.weak import weak_optimal_transport

Install with Tessl CLI
npx tessl i tessl/pypi-potdocs