tessl/pypi-pot

Python Optimal Transport Library providing solvers for optimization problems related to Optimal Transport for signal, image processing and machine learning

Sliced Wasserstein Distances

The ot.sliced module provides sliced approximations of Wasserstein distances, computed from random one-dimensional projections of the data. Each projection costs O(n log n) in the number of samples n, so these methods remain practical for large or high-dimensional data where exact optimal transport becomes computationally prohibitive.
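
To make the idea concrete, here is a minimal standard-library sketch of the Monte Carlo estimate described above. It is not POT's implementation: the function name `sliced_wasserstein_sketch` is hypothetical, and the sketch assumes equal sample counts and uniform weights so that the 1D problem reduces to matching sorted values.

```python
import math
import random


def sliced_wasserstein_sketch(xs, xt, n_projections=50, p=2, seed=None):
    """Monte Carlo sliced Wasserstein for equal-size, uniformly weighted samples."""
    if len(xs) != len(xt):
        raise ValueError("this sketch assumes equal sample counts")
    rng = random.Random(seed)
    d = len(xs[0])
    total = 0.0
    for _ in range(n_projections):
        # Draw a direction uniformly on the unit sphere: normalize a Gaussian vector.
        g = [rng.gauss(0.0, 1.0) for _ in range(d)]
        norm = math.sqrt(sum(v * v for v in g))
        theta = [v / norm for v in g]
        # Project both samples onto the direction and sort the 1D values.
        ps = sorted(sum(x_i * t_i for x_i, t_i in zip(x, theta)) for x in xs)
        pt = sorted(sum(x_i * t_i for x_i, t_i in zip(x, theta)) for x in xt)
        # Closed-form 1D cost W_p^p: match the sorted values pairwise.
        total += sum(abs(u - v) ** p for u, v in zip(ps, pt)) / len(ps)
    # Average the 1D costs over directions, then take the p-th root.
    return (total / n_projections) ** (1.0 / p)


rng = random.Random(0)
source = [[rng.gauss(0.0, 1.0) for _ in range(5)] for _ in range(200)]
target = [[v + 1.0 for v in row] for row in source]  # source shifted by 1 per coordinate

print(sliced_wasserstein_sketch(source, source, seed=1))  # identical samples give 0.0
print(sliced_wasserstein_sketch(source, target, seed=1))
```

The library functions documented below additionally handle sample weights, unequal sample counts, and multiple array backends.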

Core Sliced Wasserstein Functions

Standard Sliced Wasserstein

def ot.sliced.sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, projections=None, seed=None, log=False):
    """
    Compute Sliced Wasserstein distance between two empirical distributions.
    
    Approximates the Wasserstein distance by averaging 1D Wasserstein distances
    over multiple random projections. The method exploits the fact that 1D
    optimal transport has a closed-form solution via sorting.
    
    Parameters:
    - X_s: array-like, shape (n_samples_source, n_features)
         Source samples in d-dimensional space.
    - X_t: array-like, shape (n_samples_target, n_features)
         Target samples in d-dimensional space.
    - a: array-like, shape (n_samples_source,), optional
         Weights for source samples. If None, assumes uniform weights.
    - b: array-like, shape (n_samples_target,), optional
         Weights for target samples. If None, assumes uniform weights.
    - n_projections: int, default=50
         Number of random projections to average over. More projections
         give better approximation but increase computation time.
    - p: float, default=2
         Order of the Wasserstein distance (typically 1 or 2).
    - projections: array-like, shape (n_features, n_projections), optional
         Custom projection directions, one per column. If None, random
         directions are sampled uniformly from the unit sphere.
    - seed: int, optional
         Random seed for reproducible projection generation.
    - log: bool, default=False
         Return additional information including individual projection results.
    
    Returns:
    - sliced_distance: float
         Sliced Wasserstein distance estimated from the projections.
    - log: dict (if log=True)
         Contains 'projections': projection directions used,
         'projected_emds': 1D transport cost for each projection.
    
    Example:
        X_s = np.random.randn(100, 10)  # 100 samples in 10D
        X_t = np.random.randn(80, 10)   # 80 samples in 10D
        sw_dist = ot.sliced.sliced_wasserstein_distance(X_s, X_t, n_projections=100)
    """

def ot.sliced.max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, projections=None, seed=None, log=False):
    """
    Compute the Max-Sliced Wasserstein distance over a set of projections.
    
    Instead of averaging over random projections, returns the largest 1D
    Wasserstein distance among the sampled (or supplied) directions, i.e. the
    direction along which the two distributions differ most.
    
    Parameters:
    - X_s: array-like, shape (n_samples_source, n_features)
         Source samples.
    - X_t: array-like, shape (n_samples_target, n_features)
         Target samples.
    - a: array-like, shape (n_samples_source,), optional
         Source weights.
    - b: array-like, shape (n_samples_target,), optional
         Target weights.
    - n_projections: int, default=50
         Number of projection directions to try for finding maximum.
    - p: float, default=2
         Wasserstein distance order.
    - projections: array-like, shape (n_features, n_projections), optional
         Projection directions to evaluate instead of random ones.
    - seed: int, optional
         Random seed.
    - log: bool, default=False
         Return the projections and per-projection costs.
    
    Returns:
    - max_sliced_distance: float
         Maximum 1D Wasserstein distance over all considered projections.
    - log: dict (if log=True)
         Contains 'projections': projection directions used,
         'projected_emds': 1D transport cost for each projection.
    """

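The difference between the averaged and the max-sliced variant can be illustrated with a small standard-library sketch. The helper `projected_costs` is hypothetical (it is not a POT function) and assumes equal sample counts with uniform weights; both quantities are derived from the same list of per-direction costs.

```python
import math
import random


def projected_costs(xs, xt, n_projections, p, rng):
    """Return the 1D cost W_p^p for each random direction (equal sizes, uniform weights)."""
    d = len(xs[0])
    costs = []
    for _ in range(n_projections):
        g = [rng.gauss(0.0, 1.0) for _ in range(d)]
        norm = math.sqrt(sum(v * v for v in g))
        theta = [v / norm for v in g]
        ps = sorted(sum(a * b for a, b in zip(x, theta)) for x in xs)
        pt = sorted(sum(a * b for a, b in zip(x, theta)) for x in xt)
        costs.append(sum(abs(u - v) ** p for u, v in zip(ps, pt)) / len(ps))
    return costs


rng = random.Random(0)
p = 2
xs = [[rng.gauss(0.0, 1.0) for _ in range(10)] for _ in range(100)]
# The target differs from the source only along the first coordinate.
xt = [[row[0] + 3.0] + row[1:] for row in xs]

costs = projected_costs(xs, xt, n_projections=200, p=p, rng=rng)
sliced = (sum(costs) / len(costs)) ** (1.0 / p)  # average over directions
max_sliced = max(costs) ** (1.0 / p)             # most discriminative sampled direction
print(f"sliced: {sliced:.3f}, max-sliced: {max_sliced:.3f}")
```

Because the shift lies along a single axis, most random directions see little of it. The average is therefore well below the maximum, which in turn cannot exceed the true shift of 3.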
Spherical Sliced Wasserstein

def ot.sliced.sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, seed=None, log=False):
    """
    Compute Sliced Wasserstein distance on the unit sphere.
    
    Specialized version for data that lives on the unit sphere (e.g., directional
    data, normalized features). Uses geodesic distances and spherical projections.
    
    Parameters:
    - X_s: array-like, shape (n_samples_source, n_features)
         Source samples on unit sphere (assumed to be normalized).
    - X_t: array-like, shape (n_samples_target, n_features)  
         Target samples on unit sphere.
    - a: array-like, shape (n_samples_source,), optional
         Source weights.
    - b: array-like, shape (n_samples_target,), optional
         Target weights.
    - n_projections: int, default=50
         Number of great circle projections.
    - seed: int, optional
         Random seed for projection generation.
    - log: bool, default=False
         Return detailed results.
    
    Returns:
    - spherical_sw_distance: float
         Sliced Wasserstein distance on the sphere.
    - log: dict (if log=True)
         Contains projection information and individual distances.
    """

def ot.sliced.sliced_wasserstein_sphere_unif(X_s, n_projections=50, seed=None, log=False):
    """
    Compute Sliced Wasserstein distance between samples and uniform distribution on sphere.
    
    Efficient computation when comparing empirical distribution to the uniform
    distribution on the unit sphere, which has known properties.
    
    Parameters:
    - X_s: array-like, shape (n_samples, n_features)
         Source samples on unit sphere.
    - n_projections: int, default=50
         Number of projections to use.
    - seed: int, optional
         Random seed.
    - log: bool, default=False
    
    Returns:
    - distance_to_uniform: float
         Sliced Wasserstein distance to uniform distribution on sphere.
    - log: dict (if log=True)
    """

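As a rough illustration of what a "great circle projection" means, the sketch below maps a point on the sphere to an angle on a randomly oriented great circle. It covers the projection step only, under the assumption that the circle is spanned by two orthonormal vectors. The names `random_great_circle` and `project_to_circle` are hypothetical, and the circular 1D transport that would follow is not shown.

```python
import math
import random


def random_great_circle(d, rng):
    """Return two orthonormal vectors spanning a random 2D plane (Gram-Schmidt)."""
    u = [rng.gauss(0.0, 1.0) for _ in range(d)]
    norm_u = math.sqrt(sum(c * c for c in u))
    u = [c / norm_u for c in u]
    v = [rng.gauss(0.0, 1.0) for _ in range(d)]
    dot = sum(a * b for a, b in zip(u, v))
    v = [b - dot * a for a, b in zip(u, v)]  # remove the component along u
    norm_v = math.sqrt(sum(c * c for c in v))
    v = [c / norm_v for c in v]
    return u, v


def project_to_circle(x, u, v):
    """Angle in [-pi, pi] of x after projection onto the circle spanned by u and v."""
    return math.atan2(
        sum(a * b for a, b in zip(x, v)),
        sum(a * b for a, b in zip(x, u)),
    )


rng = random.Random(0)
u, v = random_great_circle(3, rng)

# A random point on the unit sphere in 3D.
x = [rng.gauss(0.0, 1.0) for _ in range(3)]
norm_x = math.sqrt(sum(c * c for c in x))
x = [c / norm_x for c in x]

print(f"angle on the great circle: {project_to_circle(x, u, v):.3f} rad")
```

Projecting source and target samples this way yields two sets of angles per great circle, which are then compared with a one-dimensional transport problem on the circle.
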
Utility Functions

def ot.sliced.get_random_projections(d, n_projections, seed=None, type_as=None):
    """
    Generate random projection directions on the unit sphere.
    
    Creates uniformly distributed random unit vectors for use as projection
    directions in sliced Wasserstein computations.
    
    Parameters:
    - d: int
         Dimension of the ambient space.
    - n_projections: int
         Number of projection directions to generate.
    - seed: int, optional
         Random seed for reproducible generation.
    - type_as: array-like, optional
         Reference array for determining output type and backend.
    
    Returns:
    - projections: ndarray, shape (d, n_projections)
         Random unit vectors uniformly distributed on the unit sphere.
         Each column is a normalized projection direction.
    
    Example:
        # Generate 100 random projections in 5D space
        projections = ot.sliced.get_random_projections(5, 100, seed=42)
        print(projections.shape)  # (5, 100)
        print(np.allclose(np.linalg.norm(projections, axis=0), 1.0))  # True
    """

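One standard way to draw directions uniformly on the unit sphere is to normalize isotropic Gaussian vectors. The sketch below shows that construction with the standard library only. `random_unit_directions` is a hypothetical name, and the sketch returns one direction per list entry, which may differ from the array layout the library uses.

```python
import math
import random


def random_unit_directions(d, n_projections, seed=None):
    """Return n_projections unit vectors in R^d, drawn uniformly on the sphere."""
    rng = random.Random(seed)
    directions = []
    for _ in range(n_projections):
        g = [rng.gauss(0.0, 1.0) for _ in range(d)]  # isotropic Gaussian sample
        norm = math.sqrt(sum(v * v for v in g))
        directions.append([v / norm for v in g])     # rescale to unit length
    return directions


dirs = random_unit_directions(d=5, n_projections=100, seed=42)
print(len(dirs), len(dirs[0]))  # 100 5
print(all(math.isclose(sum(v * v for v in u), 1.0) for u in dirs))  # True
```

The Gaussian distribution is rotation invariant, so the normalized vectors have no preferred direction.
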
Computational Advantages

Scalability Benefits

Sliced Wasserstein methods offer significant computational advantages:

  • Near-Linear Scaling: O(n log n) per projection vs roughly O(n³ log n) for exact solvers
  • High-Dimensional Efficiency: Dimension enters only through the O(n·d) cost of each projection
  • Parallelizable: Different projections can be computed independently
  • Memory Efficient: No need to store large transport matrices

Approximation Quality

The approximation quality depends on:

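  • Number of projections (the Monte Carlo error shrinks roughly as 1/√n_projections)
  • Data dimension (higher dimensions typically need more projections, since a random direction is less likely to be informative)
  • Distribution characteristics (differences concentrated along a few directions are harder to detect with random projections)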
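
The effect of the number of projections can be checked on a case with a known answer. For a distribution and its copy translated by a vector s, the projected 1D cost along a unit direction θ is exactly (θ·s)², and its average over the sphere is ‖s‖²/d. The sketch below uses the hypothetical helper `translation_costs`, not a POT function, to report the estimate with its standard error for increasing numbers of projections L.

```python
import math
import random
import statistics


def translation_costs(shift, n_projections, rng):
    """Per-direction W_2^2 between a distribution and its copy translated by `shift`.

    For a pure translation the projected 1D cost is exactly (theta . shift)^2.
    """
    costs = []
    for _ in range(n_projections):
        g = [rng.gauss(0.0, 1.0) for _ in shift]
        norm = math.sqrt(sum(v * v for v in g))
        costs.append((sum(gi * si for gi, si in zip(g, shift)) / norm) ** 2)
    return costs


rng = random.Random(0)
shift = [1.0] * 50  # translation by 1 in each of 50 dimensions; exact SW_2^2 is 1.0
for n_proj in (10, 100, 1000, 10000):
    costs = translation_costs(shift, n_proj, rng)
    estimate = statistics.fmean(costs)
    std_error = statistics.stdev(costs) / math.sqrt(n_proj)
    print(f"L={n_proj:5d}  SW_2^2 estimate={estimate:.3f}  std. error={std_error:.3f}")
```

The standard error falls by about a factor of √10 for each tenfold increase in L, which is the usual Monte Carlo rate.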

Usage Examples

Basic Sliced Wasserstein

import ot
import numpy as np

# Generate high-dimensional sample data
np.random.seed(42)
d = 50  # Dimension
n_s, n_t = 200, 150

# Source and target samples
X_s = np.random.randn(n_s, d)
X_t = np.random.randn(n_t, d) + 1  # Shifted distribution

# Compute sliced Wasserstein distance
n_proj = 100
sw_distance = ot.sliced.sliced_wasserstein_distance(
    X_s, X_t, n_projections=n_proj, seed=42
)

print(f"Sliced Wasserstein distance: {sw_distance:.4f}")

# Compare with different numbers of projections
projections_to_try = [10, 50, 100, 200]
for n_proj in projections_to_try:
    dist = ot.sliced.sliced_wasserstein_distance(X_s, X_t, n_projections=n_proj)
    print(f"n_projections={n_proj}: distance={dist:.4f}")

Max-Sliced Wasserstein

# Compute max-sliced distance for comparison
max_sw_distance = ot.sliced.max_sliced_wasserstein_distance(
    X_s, X_t, n_projections=100, seed=42
)

print(f"Max-Sliced Wasserstein distance: {max_sw_distance:.4f}")
print(f"Ratio (max/average): {max_sw_distance/sw_distance:.2f}")

Custom Projections

# Use custom projection directions
custom_projections = ot.sliced.get_random_projections(d, 50, seed=123)

# Compute distance with custom projections
sw_custom = ot.sliced.sliced_wasserstein_distance(
    X_s, X_t, projections=custom_projections
)

print(f"Custom projections distance: {sw_custom:.4f}")

# Get detailed results
sw_detailed = ot.sliced.sliced_wasserstein_distance(
    X_s, X_t, n_projections=20, log=True, seed=42
)

print("Detailed results:")
print(f"Distance: {sw_detailed[0]:.4f}")
print(f"Individual projection costs (first 5): {sw_detailed[1]['projected_emds'][:5]}")

Weighted Samples

# Create weighted samples
a = np.random.exponential(1.0, n_s)
a = a / np.sum(a)  # Normalize to sum to 1

b = np.random.exponential(1.5, n_t) 
b = b / np.sum(b)

# Compute weighted sliced Wasserstein
sw_weighted = ot.sliced.sliced_wasserstein_distance(
    X_s, X_t, a=a, b=b, n_projections=100
)

print(f"Weighted Sliced Wasserstein: {sw_weighted:.4f}")
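
With non-uniform weights, each 1D problem can no longer be solved by pairing sorted values one to one. One way to handle it, sketched below with the standard library, is to sort both samples and move mass greedily from left to right. The function name `wasserstein_1d_weighted` is hypothetical and the sketch is not necessarily how the library computes it.

```python
def wasserstein_1d_weighted(x, a, y, b, p=2):
    """Return W_p^p between sum_i a_i * delta(x_i) and sum_j b_j * delta(y_j) on the line.

    Assumes p >= 1 and that the weights a and b are non-negative with equal totals.
    """
    xs = sorted(zip(x, a))
    ys = sorted(zip(y, b))
    i = j = 0
    mass_x, mass_y = xs[0][1], ys[0][1]  # mass left to ship from x_i / to receive at y_j
    cost = 0.0
    while True:
        moved = min(mass_x, mass_y)
        cost += moved * abs(xs[i][0] - ys[j][0]) ** p
        mass_x -= moved
        mass_y -= moved
        if mass_x == 0.0:  # current source point is empty: advance to the next one
            i += 1
            if i == len(xs):
                break
            mass_x = xs[i][1]
        if mass_y == 0.0:  # current target point is full: advance to the next one
            j += 1
            if j == len(ys):
                break
            mass_y = ys[j][1]
    return cost


# One source point at 0 with all the mass, split between targets at 1 and 3.
print(wasserstein_1d_weighted([0.0], [1.0], [1.0, 3.0], [0.5, 0.5], p=1))  # 2.0
print(wasserstein_1d_weighted([0.0], [1.0], [1.0, 3.0], [0.5, 0.5], p=2))  # 5.0
```

This monotone coupling is optimal in one dimension for any p ≥ 1, so the sorting-based shortcut carries over to weighted samples.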

Spherical Data

# Generate data on unit sphere
X_s_sphere = np.random.randn(100, d)
X_s_sphere = X_s_sphere / np.linalg.norm(X_s_sphere, axis=1, keepdims=True)

X_t_sphere = np.random.randn(80, d)
X_t_sphere = X_t_sphere / np.linalg.norm(X_t_sphere, axis=1, keepdims=True)

# Compute spherical sliced Wasserstein
sw_sphere = ot.sliced.sliced_wasserstein_sphere(
    X_s_sphere, X_t_sphere, n_projections=100
)

print(f"Spherical Sliced Wasserstein: {sw_sphere:.4f}")

# Distance to uniform distribution on sphere
sw_unif = ot.sliced.sliced_wasserstein_sphere_unif(
    X_s_sphere, n_projections=100
)

print(f"Distance to uniform on sphere: {sw_unif:.4f}")

Performance Comparison

import time

# Compare computational time with exact methods for small problem
n_small = 50
X_s_small = np.random.randn(n_small, 2)  # 2D for exact method
X_t_small = np.random.randn(n_small, 2)

# Exact 2-Wasserstein distance (ot.dist defaults to the squared Euclidean cost,
# so take the square root to compare with the sliced W_2 distance)
tic = time.perf_counter()
M = ot.dist(X_s_small, X_t_small)
a_unif = ot.unif(n_small)
b_unif = ot.unif(n_small)
emd_distance = np.sqrt(ot.emd2(a_unif, b_unif, M))
emd_time = time.perf_counter() - tic

# Sliced Wasserstein
tic = time.perf_counter()
sw_distance = ot.sliced.sliced_wasserstein_distance(X_s_small, X_t_small, seed=42)
sw_time = time.perf_counter() - tic

print(f"Exact W_2 distance: {emd_distance:.4f} (time: {emd_time:.4f}s)")
print(f"Sliced W_2 distance: {sw_distance:.4f} (time: {sw_time:.4f}s)")
print(f"Time ratio (exact/sliced): {emd_time/sw_time:.1f}x")

Convergence Analysis

# Study convergence with number of projections
projections_range = np.logspace(1, 3, 10).astype(int)  # From 10 to 1000
distances = []

for n_proj in projections_range:
    dist = ot.sliced.sliced_wasserstein_distance(
        X_s, X_t, n_projections=n_proj, seed=42
    )
    distances.append(dist)

print("Convergence analysis:")
for n_proj, dist in zip(projections_range, distances):
    print(f"n_projections={n_proj:4d}: distance={dist:.6f}")

# Estimate convergence
final_distance = distances[-1]
print(f"\nApproximate converged value: {final_distance:.6f}")

Different Distance Orders

# Compare p=1 and p=2 distances
p_values = [1, 2]

for p in p_values:
    sw_p = ot.sliced.sliced_wasserstein_distance(
        X_s, X_t, p=p, n_projections=100, seed=42
    )
    print(f"Sliced W_{p} distance: {sw_p:.4f}")

Applications and Use Cases

High-Dimensional Data

Sliced Wasserstein is particularly effective for:

  • Image Processing: Comparing high-dimensional image features
  • Natural Language Processing: Document embeddings and word vectors
  • Bioinformatics: Gene expression profiles and protein data
  • Machine Learning: Feature representations and latent spaces

Computational Constraints

Use sliced methods when:

  • Exact optimal transport is too slow (large n or high d)
  • Memory is limited (can't store n×n matrices)
  • Real-time applications requiring fast distance computation
  • Batch processing of many distribution pairs

Theoretical Properties

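  • Consistency: The Monte Carlo estimate converges to the sliced Wasserstein distance as n_projections → ∞; this is a metric in its own right, bounded above by the true Wasserstein distance rather than equal to it
  • Robustness: Less sensitive to outliers than exact methods
  • Differentiability: Differentiable almost everywhere with respect to the sample positions, so it can serve as a loss with autodiff backends such as PyTorch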
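
For reference, the quantities discussed in this section can be written out as follows. The notation is an assumption introduced here, not taken from the library: σ is the uniform probability measure on the unit sphere, P^θ is the projection onto direction θ, and L is the number of projections.

```latex
\mathrm{SW}_p^p(\mu, \nu)
  = \int_{\mathbb{S}^{d-1}} W_p^p\!\left(P^{\theta}_{\#}\mu,\; P^{\theta}_{\#}\nu\right) \mathrm{d}\sigma(\theta),
  \qquad P^{\theta}(x) = \langle x, \theta \rangle

\widehat{\mathrm{SW}}_p^p(\mu, \nu)
  = \frac{1}{L} \sum_{l=1}^{L} W_p^p\!\left(P^{\theta_l}_{\#}\mu,\; P^{\theta_l}_{\#}\nu\right),
  \qquad \theta_l \overset{\text{i.i.d.}}{\sim} \sigma

\mathrm{SW}_p(\mu, \nu) \le W_p(\mu, \nu)
```

The first line defines the sliced distance and the second is the Monte Carlo estimate, which converges to the first as L grows. The third holds because each projection is 1-Lipschitz.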

The ot.sliced module provides tools for scalable optimal transport in high dimensions: projection-based distances that remain valid metrics while costing far less to compute than exact solvers.

Install with Tessl CLI

npx tessl i tessl/pypi-pot
