tessl/pypi-pot

Python Optimal Transport Library providing solvers for optimization problems related to Optimal Transport for signal, image processing and machine learning

Factored Optimal Transport

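Factored optimal transport approximates the transport plan between two sample sets by routing mass through a small set of r intermediate anchor points. Instead of one dense plan over all source-target pairs, the solver works with two smaller plans (source to anchors, anchors to target), which reduces memory and computation when r is much smaller than the number of samples.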

Capabilities

Factored Optimal Transport Solver

Solve optimal transport problems using factored decomposition approaches.
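
As stated in the POT documentation for this function, the solver looks for an intermediate distribution μ, supported on r uniformly weighted points, that sits between the two empirical distributions:

```latex
\min_{\mu} \; W_2^2(\mu_a, \mu) + W_2^2(\mu, \mu_b)
```

Here μ_a and μ_b are the empirical distributions of the source and target samples, and W_2 is the 2-Wasserstein distance. The optimal plans from μ_a to μ and from μ to μ_b are the two factors of the overall plan.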

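def factored_optimal_transport(Xa, Xb, a=None, b=None, reg=0.0, r=100, X0=None,
                               stopThr=1e-7, numItermax=100, verbose=False,
                               log=False, **kwargs):
    """
    Solve the factored optimal transport problem.

    Mass is routed from the source samples to the target samples through an
    intermediate distribution supported on r learned anchor points, so two
    plans of shape (n_samples_a, r) and (r, n_samples_b) are computed instead
    of one plan of shape (n_samples_a, n_samples_b).

    Parameters:
    - Xa: array-like, shape (n_samples_a, n_features), source samples
    - Xb: array-like, shape (n_samples_b, n_features), target samples
    - a: array-like, shape (n_samples_a,), source distribution (uniform if None)
    - b: array-like, shape (n_samples_b,), target distribution (uniform if None)
    - reg: float, entropic regularization (0 uses the exact solver, > 0 uses Sinkhorn)
    - r: int, number of anchor points, i.e. the rank of the factored plan
    - X0: array-like, shape (r, n_features), initial anchor positions (random if None)
    - stopThr: float, stop when the change in anchor positions falls below this value
    - numItermax: int, maximum number of iterations
    - verbose: bool, print optimization information
    - log: bool, also return a log dictionary
    - **kwargs: forwarded to the inner OT solver

    Returns:
    - Ga: array-like, shape (n_samples_a, r), plan from the source samples to the anchors
    - Gb: array-like, shape (r, n_samples_b), plan from the anchors to the target samples
    - X: array-like, shape (r, n_features), anchor positions
    - log: dict, returned as a fourth value only if log=True
    """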

Factorization Approaches

Low-Rank Transport Plans

Many optimal transport problems exhibit low-rank structure that can be exploited:

Standard transport plan: γ ∈ R^(n×m), which stores O(nm) entries.

Factored transport plan: γ ≈ UV^T with U ∈ R^(n×k) and V ∈ R^(m×k), which stores O((n+m)k) entries.
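
The saving can be illustrated with plain NumPy. The sketch below uses random nonnegative factors as stand-ins (an assumption for illustration, not the output of a solver) and applies the low-rank plan to a vector without ever forming the n×m matrix.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, k = 2000, 1500, 10

# Stand-in nonnegative factors; a real solver would produce these
U = rng.random((n, k))
V = rng.random((m, k))

x = rng.random(m)

# Dense route: build the full n x m plan, then multiply (O(nm) memory)
dense_result = (U @ V.T) @ x

# Factored route: multiply right to left, never forming the plan
# (O((n + m) k) memory and time)
factored_result = U @ (V.T @ x)

print("results agree:", np.allclose(dense_result, factored_result))
print(f"dense entries: {n * m}, factored entries: {(n + m) * k}")
```
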

Structured Data Scenarios

Factored transport tends to work best when a small number of anchor points can summarize the data, for example:

  1. Gaussian Distributions: Natural low-rank structure in transport plans
  2. Time Series: Temporal structure enables efficient factorization
  3. Images: Spatial correlation allows patch-based factorization
  4. Graph Data: Community structure supports block-wise transport
  5. High-dimensional Data: Manifold structure enables dimensionality reduction

Usage Examples

Basic Factored Transport

import ot
import numpy as np
import matplotlib.pyplot as plt

# Create high-dimensional data with low-rank structure
n_source, n_target = 1000, 1200
n_features = 50
rank = 5

# Generate low-rank source data
U_source = np.random.randn(n_source, rank)
V_source = np.random.randn(rank, n_features)
Xa = U_source @ V_source + 0.1 * np.random.randn(n_source, n_features)

# Generate related target data
U_target = U_source + 0.5 * np.random.randn(n_source, rank)
U_target = np.vstack([U_target, np.random.randn(n_target - n_source, rank)])
V_target = V_source + 0.3 * np.random.randn(rank, n_features)
Xb = U_target @ V_target + 0.1 * np.random.randn(n_target, n_features)

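# Solve using factored transport with a small number of anchor points
n_anchors = 10
Ga, Gb, X = ot.factored_optimal_transport(Xa, Xb, r=n_anchors)

# Compose the two factors into a dense (n_source, n_target) plan
plan_factored = n_anchors * Ga @ Gb

print(f"Transport plan shape: {plan_factored.shape}")
print(f"Fraction of nonzero entries: {np.sum(plan_factored > 1e-8) / plan_factored.size:.4f}")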

Comparison with Standard Methods

# Compare computational efficiency
import time

# Standard optimal transport (for smaller problem)
n_small = 200
Xa_small = Xa[:n_small]
Xb_small = Xb[:n_small]

a_small = ot.utils.unif(n_small)
b_small = ot.utils.unif(n_small)

# Standard EMD
start_time = time.time()
M_small = ot.dist(Xa_small, Xb_small)
plan_emd = ot.emd(a_small, b_small, M_small)
time_emd = time.time() - start_time

# Sinkhorn (cost rescaled to [0, 1] so that reg=0.1 stays numerically stable)
start_time = time.time()
plan_sinkhorn = ot.sinkhorn(a_small, b_small, M_small / M_small.max(), reg=0.1)
time_sinkhorn = time.time() - start_time

# Factored transport (same problem); returns two factor plans and the anchors
start_time = time.time()
Ga_small, Gb_small, X_small = ot.factored_optimal_transport(Xa_small, Xb_small, r=10)
time_factored = time.time() - start_time

print(f"Timing comparison (n={n_small}):")
print(f"  EMD: {time_emd:.4f}s")
print(f"  Sinkhorn: {time_sinkhorn:.4f}s") 
print(f"  Factored: {time_factored:.4f}s")

# Larger problem
print(f"\nLarger problem (n_source={n_source}, n_target={n_target}):")
start_time = time.time()
Ga_large, Gb_large, X_large = ot.factored_optimal_transport(Xa, Xb, r=10)
time_large = time.time() - start_time
print(f"  Factored transport: {time_large:.4f}s")

Gaussian Mixture Example

# Example with Gaussian mixtures (natural factorization)

# Create Gaussian mixture data
n_components = 3
n_samples_per_comp = 300

# Source mixture (sampled directly, one Gaussian per component)
Xa_gmm = np.vstack([
    np.random.multivariate_normal([0, 0], [[1, 0.5], [0.5, 1]], n_samples_per_comp),
    np.random.multivariate_normal([3, 0], [[1, -0.3], [-0.3, 1]], n_samples_per_comp),
    np.random.multivariate_normal([1.5, 3], [[0.8, 0.2], [0.2, 0.8]], n_samples_per_comp)
])

# Target mixture (shifted and rotated)
theta = np.pi / 6
R = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
Xb_gmm = np.vstack([
    np.random.multivariate_normal([1, 1], [[1.2, 0.4], [0.4, 1.2]], n_samples_per_comp),
    np.random.multivariate_normal([4, 1], [[1, -0.4], [-0.4, 1]], n_samples_per_comp),
    np.random.multivariate_normal([2.5, 4], [[0.9, 0.3], [0.3, 0.9]], n_samples_per_comp)
]) @ R.T

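# Solve with factored transport; log=True appends a log dictionary to the outputs
n_anchors_gmm = 30
Ga_gmm, Gb_gmm, X_gmm, log_gmm = ot.factored_optimal_transport(
    Xa_gmm, Xb_gmm,
    r=n_anchors_gmm,
    verbose=True,
    log=True
)

# Compose the two factors into a dense plan for visualization
plan_gmm = n_anchors_gmm * Ga_gmm @ Gb_gmm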

# Visualize results
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Source data
axes[0].scatter(Xa_gmm[:, 0], Xa_gmm[:, 1], alpha=0.6, c='blue')
axes[0].set_title('Source Distribution')
axes[0].set_aspect('equal')

# Target data  
axes[1].scatter(Xb_gmm[:, 0], Xb_gmm[:, 1], alpha=0.6, c='red')
axes[1].set_title('Target Distribution')
axes[1].set_aspect('equal')

# Transport plan visualization
im = axes[2].imshow(plan_gmm, cmap='Blues', aspect='auto')
axes[2].set_xlabel('Target samples')
axes[2].set_ylabel('Source samples')
axes[2].set_title('Factored Transport Plan')
plt.colorbar(im, ax=axes[2])

plt.tight_layout()
plt.show()

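# Report the transport costs of the two factor problems if the log provides them
if 'costa' in log_gmm and 'costb' in log_gmm:
    print(f"Source-to-anchor cost: {float(log_gmm['costa']):.4f}")
    print(f"Anchor-to-target cost: {float(log_gmm['costb']):.4f}")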

Time Series Transport

# Example with time series data
from sklearn.decomposition import PCA

# Generate time series with shared temporal patterns
t = np.linspace(0, 10, 100)
n_series_source = 200
n_series_target = 250

# Base temporal patterns
patterns = np.array([
    np.sin(t),
    np.cos(t), 
    np.sin(2*t),
    np.exp(-t/5) * np.sin(t)
]).T

# Source time series (linear combinations of patterns)
weights_source = np.random.exponential(1, (n_series_source, 4))
Xa_ts = weights_source @ patterns.T + 0.1 * np.random.randn(n_series_source, len(t))

# Target time series (shifted patterns)
weights_target = np.random.exponential(1.2, (n_series_target, 4))
patterns_shifted = np.roll(patterns, 5, axis=0)  # Temporal shift
Xb_ts = weights_target @ patterns_shifted.T + 0.1 * np.random.randn(n_series_target, len(t))

# Apply PCA preprocessing to enhance structure
pca = PCA(n_components=10)
Xa_ts_pca = pca.fit_transform(Xa_ts)
Xb_ts_pca = pca.transform(Xb_ts)

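# Factored transport on time series, composed into a dense plan for plotting
n_anchors_ts = 20
Ga_ts, Gb_ts, X_ts = ot.factored_optimal_transport(Xa_ts_pca, Xb_ts_pca, r=n_anchors_ts)
plan_ts = n_anchors_ts * Ga_ts @ Gb_ts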

# Visualize sample time series and their transport
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Sample source time series
for i in range(5):
    axes[0, 0].plot(t, Xa_ts[i], alpha=0.7)
axes[0, 0].set_title('Sample Source Time Series')
axes[0, 0].set_xlabel('Time')

# Sample target time series
for i in range(5):
    axes[0, 1].plot(t, Xb_ts[i], alpha=0.7)
axes[0, 1].set_title('Sample Target Time Series')
axes[0, 1].set_xlabel('Time')

# PCA representation
axes[1, 0].scatter(Xa_ts_pca[:, 0], Xa_ts_pca[:, 1], alpha=0.6, label='Source')
axes[1, 0].scatter(Xb_ts_pca[:, 0], Xb_ts_pca[:, 1], alpha=0.6, label='Target')
axes[1, 0].set_xlabel('PC1')
axes[1, 0].set_ylabel('PC2')
axes[1, 0].set_title('PCA Representation')
axes[1, 0].legend()

# Transport plan sparsity pattern
axes[1, 1].spy(plan_ts > 1e-6, markersize=0.1)
axes[1, 1].set_title('Transport Plan Sparsity')
axes[1, 1].set_xlabel('Target series')
axes[1, 1].set_ylabel('Source series')

plt.tight_layout()
plt.show()

Import Statements

import ot
from ot import factored_optimal_transport
from ot.factored import factored_optimal_transport

Install with Tessl CLI

npx tessl i tessl/pypi-pot
