Python Optimal Transport Library providing solvers for optimization problems related to Optimal Transport for signal, image processing and machine learning
—
Quality
Pending
Does it follow best practices?
Impact
Pending
No eval scenarios have been run
Factored optimal transport provides efficient algorithms for problems with special structure that allows factorization of the transport plan. This approach significantly reduces computational complexity for large-scale problems with structured data.
Solve optimal transport problems using factored decomposition approaches.
def factored_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, **kwargs):
    """
    Solve optimal transport using a factored (low-rank) decomposition.

    Exploits structure in the data to factorize the transport plan,
    reducing computational complexity from O(n^2) to approximately
    O(n*k), where k << n is the factorization rank.

    Parameters
    ----------
    Xa : array-like, shape (n_samples_a, n_features)
        Source samples.
    Xb : array-like, shape (n_samples_b, n_features)
        Target samples.
    a : array-like, shape (n_samples_a,), optional
        Source distribution weights (uniform if None).
    b : array-like, shape (n_samples_b,), optional
        Target distribution weights (uniform if None).
    verbose : bool, optional
        Print optimization information.
    log : bool, optional
        If True, also return the optimization log and factorization details.

    Returns
    -------
    Transport plan matrix, or ``(plan, log)`` if ``log=True``.
"""Many optimal transport problems exhibit low-rank structure that can be exploited:
Standard Transport Plan: γ ∈ R^(n×m) with O(nm) complexity
Factored Transport Plan: γ ≈ UV^T where U ∈ R^(n×k), V ∈ R^(m×k) with O((n+m)k) complexity
Factored transport is particularly effective for structured problems, as illustrated by the examples below.
import ot
import numpy as np
import matplotlib.pyplot as plt

# --- Low-rank synthetic data --------------------------------------------
# Factored transport is most effective when the point clouds have
# (approximately) low-rank structure, so build both clouds from a shared
# rank-5 basis.
n_source, n_target = 1000, 1200
n_features = 50
rank = 5

# Low-rank source data: rank-`rank` signal plus small isotropic noise.
U_source = np.random.randn(n_source, rank)
V_source = np.random.randn(rank, n_features)
Xa = U_source @ V_source + 0.1 * np.random.randn(n_source, n_features)

# Related target data: perturb the source factors, then append extra rows
# so the two clouds have different sizes (n_target > n_source).
U_target = U_source + 0.5 * np.random.randn(n_source, rank)
U_target = np.vstack([U_target, np.random.randn(n_target - n_source, rank)])
V_target = V_source + 0.3 * np.random.randn(rank, n_features)
Xb = U_target @ V_target + 0.1 * np.random.randn(n_target, n_features)

# Solve using factored transport (uniform weights are used when a/b are None).
plan_factored = ot.factored_optimal_transport(Xa, Xb, verbose=True, log=False)
print(f"Transport plan shape: {plan_factored.shape}")
print(f"Plan sparsity: {np.sum(plan_factored > 1e-8) / plan_factored.size:.4f}")
# Compare computational efficiency
import time

# --- Timing: exact EMD vs Sinkhorn vs factored transport -----------------
# Use a small subproblem so the exact (EMD) solver stays tractable.
n_small = 200
Xa_small = Xa[:n_small]
Xb_small = Xb[:n_small]
a_small = ot.utils.unif(n_small)
b_small = ot.utils.unif(n_small)

# Exact EMD (requires the full pairwise cost matrix).
start_time = time.time()
M_small = ot.dist(Xa_small, Xb_small)
plan_emd = ot.emd(a_small, b_small, M_small)
time_emd = time.time() - start_time

# Entropic regularization (Sinkhorn).
start_time = time.time()
plan_sinkhorn = ot.sinkhorn(a_small, b_small, M_small, reg=0.1)
time_sinkhorn = time.time() - start_time

# Factored transport on the same small problem (no cost matrix needed).
start_time = time.time()
plan_factored_small = ot.factored_optimal_transport(Xa_small, Xb_small)
time_factored = time.time() - start_time

print(f"Timing comparison (n={n_small}):")
print(f" EMD: {time_emd:.4f}s")
print(f" Sinkhorn: {time_sinkhorn:.4f}s")
print(f" Factored: {time_factored:.4f}s")

# Large-scale problem (only factored transport feasible)
# FIX: '\\n' printed a literal backslash-n; a real newline was intended.
print(f"\nLarge-scale problem (n_source={n_source}, n_target={n_target}):")
start_time = time.time()
plan_large = ot.factored_optimal_transport(Xa, Xb, verbose=False)
time_large = time.time() - start_time
print(f" Factored transport: {time_large:.4f}s")
# Example with Gaussian mixtures (natural factorization)
from sklearn.mixture import GaussianMixture

# --- Gaussian-mixture example (cluster structure factorizes naturally) ---
n_components = 3
n_samples_per_comp = 300

# NOTE(review): this estimator is constructed but never fitted or used
# below; kept for parity with the original example — confirm intent.
gmm_source = GaussianMixture(n_components=n_components, random_state=42)

# Source mixture: three well-separated 2-D Gaussian blobs.
Xa_gmm = np.vstack([
    np.random.multivariate_normal([0, 0], [[1, 0.5], [0.5, 1]], n_samples_per_comp),
    np.random.multivariate_normal([3, 0], [[1, -0.3], [-0.3, 1]], n_samples_per_comp),
    np.random.multivariate_normal([1.5, 3], [[0.8, 0.2], [0.2, 0.8]], n_samples_per_comp)
])

# Target mixture: shifted blobs, then rotated by theta = 30 degrees.
theta = np.pi / 6
R = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
Xb_gmm = np.vstack([
    np.random.multivariate_normal([1, 1], [[1.2, 0.4], [0.4, 1.2]], n_samples_per_comp),
    np.random.multivariate_normal([4, 1], [[1, -0.4], [-0.4, 1]], n_samples_per_comp),
    np.random.multivariate_normal([2.5, 4], [[0.9, 0.3], [0.3, 0.9]], n_samples_per_comp)
]) @ R.T

# Solve with factored transport, requesting the optimization log.
plan_gmm, log_gmm = ot.factored_optimal_transport(
    Xa_gmm, Xb_gmm,
    verbose=True,
    log=True
)

# --- Visualize source, target, and the transport plan --------------------
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Source data
axes[0].scatter(Xa_gmm[:, 0], Xa_gmm[:, 1], alpha=0.6, c='blue')
axes[0].set_title('Source Distribution')
axes[0].set_aspect('equal')

# Target data
axes[1].scatter(Xb_gmm[:, 0], Xb_gmm[:, 1], alpha=0.6, c='red')
axes[1].set_title('Target Distribution')
axes[1].set_aspect('equal')

# Transport plan as a heatmap (rows: source samples, cols: target samples).
im = axes[2].imshow(plan_gmm, cmap='Blues', aspect='auto')
axes[2].set_xlabel('Target samples')
axes[2].set_ylabel('Source samples')
axes[2].set_title('Factored Transport Plan')
plt.colorbar(im, ax=axes[2])
plt.tight_layout()
plt.show()

# FIX: the print below was not indented under the `if` (and was fused with
# the next section's comment), which is a syntax error.
if 'factorization_rank' in log_gmm:
    print(f"Effective factorization rank: {log_gmm['factorization_rank']}")
# Example with time series data
from sklearn.decomposition import PCA

# --- Time-series example: shared temporal patterns -----------------------
t = np.linspace(0, 10, 100)
n_series_source = 200
n_series_target = 250

# Four base temporal patterns; each series is a weighted combination of
# them, so the series matrices are (approximately) rank 4.
patterns = np.array([
    np.sin(t),
    np.cos(t),
    np.sin(2*t),
    np.exp(-t/5) * np.sin(t)
]).T

# Source time series: nonnegative (exponential) weights plus small noise.
weights_source = np.random.exponential(1, (n_series_source, 4))
Xa_ts = weights_source @ patterns.T + 0.1 * np.random.randn(n_series_source, len(t))

# Target time series: same patterns, temporally shifted by 5 samples.
weights_target = np.random.exponential(1.2, (n_series_target, 4))
patterns_shifted = np.roll(patterns, 5, axis=0)  # Temporal shift
Xb_ts = weights_target @ patterns_shifted.T + 0.1 * np.random.randn(n_series_target, len(t))

# PCA preprocessing: fit on the source, apply the same projection to the
# target so both live in the same 10-D space.
pca = PCA(n_components=10)
Xa_ts_pca = pca.fit_transform(Xa_ts)
Xb_ts_pca = pca.transform(Xb_ts)

# Factored transport in the PCA space.
plan_ts = ot.factored_optimal_transport(Xa_ts_pca, Xb_ts_pca, verbose=True)

# --- Visualize sample series, PCA embedding, and plan sparsity -----------
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# FIX: the two loop bodies below were not indented in the original,
# which is a syntax error; only the plot call belongs inside each loop.
for i in range(5):
    axes[0, 0].plot(t, Xa_ts[i], alpha=0.7)
axes[0, 0].set_title('Sample Source Time Series')
axes[0, 0].set_xlabel('Time')

for i in range(5):
    axes[0, 1].plot(t, Xb_ts[i], alpha=0.7)
axes[0, 1].set_title('Sample Target Time Series')
axes[0, 1].set_xlabel('Time')

# First two principal components of both sets of series.
axes[1, 0].scatter(Xa_ts_pca[:, 0], Xa_ts_pca[:, 1], alpha=0.6, label='Source')
axes[1, 0].scatter(Xb_ts_pca[:, 0], Xb_ts_pca[:, 1], alpha=0.6, label='Target')
axes[1, 0].set_xlabel('PC1')
axes[1, 0].set_ylabel('PC2')
axes[1, 0].set_title('PCA Representation')
axes[1, 0].legend()

# Sparsity pattern of the transport plan.
axes[1, 1].spy(plan_ts > 1e-6, markersize=0.1)
axes[1, 1].set_title('Transport Plan Sparsity')
axes[1, 1].set_xlabel('Target series')
axes[1, 1].set_ylabel('Source series')
plt.tight_layout()
plt.show()

import ot
from ot import factored_optimal_transport
from ot.factored import factored_optimal_transport

Install with Tessl CLI
npx tessl i tessl/pypi-potdocs