Python Optimal Transport Library providing solvers for optimization problems related to Optimal Transport for signal, image processing and machine learning
The ot.sliced module provides efficient approximation algorithms for computing Wasserstein distances in high dimensions using random projections. Each projection reduces the problem to 1D optimal transport, which is solved by sorting, so the cost per projection grows as O(n log n) in the number of samples n. This makes the methods particularly effective for large, high-dimensional datasets where exact optimal transport becomes computationally prohibitive.
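The projection-and-sort idea can be sketched in plain NumPy. This is an illustrative stand-in, not the library's implementation: `random_directions` and `sliced_wasserstein_sketch` are hypothetical names, and the sketch assumes equal sample counts and uniform weights so that sorted projections can be matched one to one.

```python
import numpy as np


def random_directions(d, n_projections, rng):
    """Draw unit vectors by normalizing Gaussian samples (one direction per column)."""
    theta = rng.standard_normal((d, n_projections))
    return theta / np.linalg.norm(theta, axis=0, keepdims=True)


def sliced_wasserstein_sketch(X_s, X_t, n_projections=50, p=2, seed=None):
    """Monte Carlo sliced Wasserstein for equal-size, uniformly weighted samples."""
    X_s = np.asarray(X_s, dtype=float)
    X_t = np.asarray(X_t, dtype=float)
    if X_s.shape != X_t.shape:
        raise ValueError("this sketch assumes equal sample counts and dimensions")
    rng = np.random.default_rng(seed)
    theta = random_directions(X_s.shape[1], n_projections, rng)
    # Project onto every direction, then sort each column: matching sorted
    # values in order is the optimal coupling in 1D.
    proj_s = np.sort(X_s @ theta, axis=0)
    proj_t = np.sort(X_t @ theta, axis=0)
    # Mean over samples gives the 1D cost W_p^p per direction; the mean over
    # directions averages the slices; the 1/p power returns a distance.
    w_pp = np.mean(np.abs(proj_s - proj_t) ** p, axis=0)
    return float(np.mean(w_pp) ** (1.0 / p))


rng = np.random.default_rng(0)
X = rng.standard_normal((200, 10))
Y = rng.standard_normal((200, 10)) + 1.0  # same law, shifted by 1 in every coordinate
print(sliced_wasserstein_sketch(X, X, seed=0))  # identical samples give 0.0
print(sliced_wasserstein_sketch(X, Y, seed=0))  # strictly positive for the shifted sample
```

Unequal sample counts or non-uniform weights require comparing quantile functions instead of matching sorted values directly, which this sketch leaves out.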
def ot.sliced.sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, projections=None, seed=None, log=False):
"""
Compute Sliced Wasserstein distance between two empirical distributions.
Approximates the Wasserstein distance by averaging 1D Wasserstein distances
over multiple random projections. The method exploits the fact that 1D
optimal transport has a closed-form solution via sorting.
Parameters:
- X_s: array-like, shape (n_samples_source, n_features)
Source samples in d-dimensional space.
- X_t: array-like, shape (n_samples_target, n_features)
Target samples in d-dimensional space.
- a: array-like, shape (n_samples_source,), optional
Weights for source samples. If None, assumes uniform weights.
- b: array-like, shape (n_samples_target,), optional
Weights for target samples. If None, assumes uniform weights.
- n_projections: int, default=50
Number of random projections to average over. More projections
give better approximation but increase computation time.
- p: int, default=2
Order of the Wasserstein distance (typically 1 or 2).
- projections: array-like, shape (n_features, n_projections), optional
Custom projection directions, one unit vector per column. If None, uses
random projections sampled uniformly from the unit sphere.
- seed: int, optional
Random seed for reproducible projection generation.
- log: bool, default=False
Return additional information including individual projection results.
Returns:
- sliced_distance: float
Approximated Wasserstein distance using sliced projections.
- log: dict (if log=True)
Contains 'projections': projection directions used,
'projected_emds': 1D transport costs for each projection.
Example:
X_s = np.random.randn(100, 10) # 100 samples in 10D
X_t = np.random.randn(80, 10) # 80 samples in 10D
sw_dist = ot.sliced.sliced_wasserstein_distance(X_s, X_t, n_projections=100)
"""
def ot.sliced.max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, projections=None, seed=None, log=False):
"""
Compute Max-Sliced Wasserstein distance over a set of projections.
Instead of averaging over random projections, keeps the largest 1D
Wasserstein distance among the sampled (or supplied) directions, which
emphasizes the direction along which the two distributions differ most.
Parameters:
- X_s: array-like, shape (n_samples_source, n_features)
Source samples.
- X_t: array-like, shape (n_samples_target, n_features)
Target samples.
- a: array-like, shape (n_samples_source,), optional
Source weights.
- b: array-like, shape (n_samples_target,), optional
Target weights.
- n_projections: int, default=50
Number of projection directions to try for finding maximum.
- p: int, default=2
Wasserstein distance order.
- projections: array-like, shape (n_features, n_projections), optional
Projection directions to evaluate. If None, random directions are sampled.
- seed: int, optional
Random seed.
- log: bool, default=False
Return optimization details.
Returns:
- max_sliced_distance: float
Maximum 1D Wasserstein distance over all considered projections.
- log: dict (if log=True)
Contains 'projections': projection directions tested,
'projected_emds': 1D transport costs for all tested projections.
"""
def ot.sliced.sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, seed=None, log=False):
"""
Compute Sliced Wasserstein distance on the unit sphere.
Specialized version for data that lives on the unit sphere (e.g., directional
data, normalized features). Uses geodesic distances and spherical projections.
Parameters:
- X_s: array-like, shape (n_samples_source, n_features)
Source samples on unit sphere (assumed to be normalized).
- X_t: array-like, shape (n_samples_target, n_features)
Target samples on unit sphere.
- a: array-like, shape (n_samples_source,), optional
Source weights.
- b: array-like, shape (n_samples_target,), optional
Target weights.
- n_projections: int, default=50
Number of great circle projections.
- seed: int, optional
Random seed for projection generation.
- log: bool, default=False
Return detailed results.
Returns:
- spherical_sw_distance: float
Sliced Wasserstein distance on the sphere.
- log: dict (if log=True)
Contains projection information and individual distances.
"""
def ot.sliced.sliced_wasserstein_sphere_unif(X_s, n_projections=50, seed=None, log=False):
"""
Compute Sliced Wasserstein distance between samples and uniform distribution on sphere.
Efficient computation when comparing empirical distribution to the uniform
distribution on the unit sphere, which has known properties.
Parameters:
- X_s: array-like, shape (n_samples, n_features)
Source samples on unit sphere.
- n_projections: int, default=50
Number of projections to use.
- seed: int, optional
Random seed.
- log: bool, default=False
Returns:
- distance_to_uniform: float
Sliced Wasserstein distance to uniform distribution on sphere.
- log: dict (if log=True)
"""
def ot.sliced.get_random_projections(d, n_projections, seed=None, type_as=None):
"""
Generate random projection directions on the unit sphere.
Creates uniformly distributed random unit vectors for use as projection
directions in sliced Wasserstein computations.
Parameters:
- d: int
Dimension of the ambient space.
- n_projections: int
Number of projection directions to generate.
- seed: int, optional
Random seed for reproducible generation.
- type_as: array-like, optional
Reference array for determining output type and backend.
Returns:
- projections: ndarray, shape (d, n_projections)
Random unit vectors uniformly distributed on the unit sphere.
Each column is a normalized projection direction.
Example:
# Generate 100 random projections in 5D space
projections = ot.sliced.get_random_projections(5, 100, seed=42)
print(projections.shape) # (5, 100)
print(np.allclose(np.linalg.norm(projections, axis=0), 1.0)) # True
"""
Sliced Wasserstein methods offer significant computational advantages: each projection reduces the problem to 1D optimal transport, which has a closed-form solution via sorting.
The approximation quality depends on the number of projections: more projections give a better approximation but increase computation time.
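How strongly the estimate fluctuates with the number of projections can be checked without the library. The sketch below is a plain NumPy stand-in rather than POT's code: `sw2_estimate` is a hypothetical helper that assumes equal sample counts and uniform weights, and the spread is measured by repeating the estimate with fresh random directions.

```python
import numpy as np


def sw2_estimate(X, Y, n_projections, rng):
    """Sort-based SW_2 estimate for equal-size, uniformly weighted samples."""
    theta = rng.standard_normal((X.shape[1], n_projections))
    theta /= np.linalg.norm(theta, axis=0, keepdims=True)  # unit-norm columns
    diff = np.sort(X @ theta, axis=0) - np.sort(Y @ theta, axis=0)
    return np.sqrt(np.mean(diff ** 2))


data_rng = np.random.default_rng(0)
X = data_rng.standard_normal((100, 20))
Y = data_rng.standard_normal((100, 20)) + 1.0

spread = {}
for n_proj in (10, 100, 1000):
    # Repeat the estimate with independent directions to measure its spread.
    estimates = [
        sw2_estimate(X, Y, n_proj, np.random.default_rng(s)) for s in range(30)
    ]
    spread[n_proj] = float(np.std(estimates))
    print(f"n_projections={n_proj:4d}: mean={np.mean(estimates):.4f}, std={spread[n_proj]:.4f}")
```

Because the estimate is a Monte Carlo average over directions, the standard deviation is expected to shrink roughly like one over the square root of the number of projections.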
import ot
import numpy as np
# Generate high-dimensional sample data
np.random.seed(42)
d = 50 # Dimension
n_s, n_t = 200, 150
# Source and target samples
X_s = np.random.randn(n_s, d)
X_t = np.random.randn(n_t, d) + 1 # Shifted distribution
# Compute sliced Wasserstein distance
n_proj = 100
sw_distance = ot.sliced.sliced_wasserstein_distance(
    X_s, X_t, n_projections=n_proj, seed=42
)
print(f"Sliced Wasserstein distance: {sw_distance:.4f}")
# Compare with different numbers of projections
projections_to_try = [10, 50, 100, 200]
for n_proj in projections_to_try:
    dist = ot.sliced.sliced_wasserstein_distance(X_s, X_t, n_projections=n_proj)
    print(f"n_projections={n_proj}: distance={dist:.4f}")

# Compute max-sliced distance for comparison
max_sw_distance = ot.sliced.max_sliced_wasserstein_distance(
    X_s, X_t, n_projections=100, seed=42
)
print(f"Max-Sliced Wasserstein distance: {max_sw_distance:.4f}")
print(f"Ratio (max/average): {max_sw_distance/sw_distance:.2f}")

# Use custom projection directions
custom_projections = ot.sliced.get_random_projections(d, 50, seed=123)
# Compute distance with custom projections
sw_custom = ot.sliced.sliced_wasserstein_distance(
    X_s, X_t, projections=custom_projections
)
print(f"Custom projections distance: {sw_custom:.4f}")
# Get detailed results (with log=True the call returns a (distance, log) tuple)
sw_detailed = ot.sliced.sliced_wasserstein_distance(
    X_s, X_t, n_projections=20, log=True, seed=42
)
print("Detailed results:")
print(f"Distance: {sw_detailed[0]:.4f}")
print(f"Individual projection costs (first 5): {sw_detailed[1]['projected_emds'][:5]}")

# Create weighted samples
a = np.random.exponential(1.0, n_s)
a = a / np.sum(a) # Normalize to sum to 1
b = np.random.exponential(1.5, n_t)
b = b / np.sum(b)
# Compute weighted sliced Wasserstein
sw_weighted = ot.sliced.sliced_wasserstein_distance(
    X_s, X_t, a=a, b=b, n_projections=100
)
print(f"Weighted Sliced Wasserstein: {sw_weighted:.4f}")

# Generate data on unit sphere
X_s_sphere = np.random.randn(100, d)
X_s_sphere = X_s_sphere / np.linalg.norm(X_s_sphere, axis=1, keepdims=True)
X_t_sphere = np.random.randn(80, d)
X_t_sphere = X_t_sphere / np.linalg.norm(X_t_sphere, axis=1, keepdims=True)
# Compute spherical sliced Wasserstein
sw_sphere = ot.sliced.sliced_wasserstein_sphere(
    X_s_sphere, X_t_sphere, n_projections=100
)
print(f"Spherical Sliced Wasserstein: {sw_sphere:.4f}")
# Distance to uniform distribution on sphere
sw_unif = ot.sliced.sliced_wasserstein_sphere_unif(
    X_s_sphere, n_projections=100
)
print(f"Distance to uniform on sphere: {sw_unif:.4f}")

import time
# Compare computational time with exact methods for small problem
n_small = 50
X_s_small = np.random.randn(n_small, 2) # 2D for exact method
X_t_small = np.random.randn(n_small, 2)
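# Exact EMD. ot.dist defaults to the squared Euclidean cost, so emd2 returns
# the squared 2-Wasserstein distance; compare the timings, not the values.
tic = time.time()
M = ot.dist(X_s_small, X_t_small)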
a_unif = ot.unif(n_small)
b_unif = ot.unif(n_small)
emd_distance = ot.emd2(a_unif, b_unif, M)
emd_time = time.time() - tic
# Sliced Wasserstein
tic = time.time()
sw_distance = ot.sliced.sliced_wasserstein_distance(X_s_small, X_t_small)
sw_time = time.time() - tic
print(f"EMD distance: {emd_distance:.4f} (time: {emd_time:.4f}s)")
print(f"Sliced W distance: {sw_distance:.4f} (time: {sw_time:.4f}s)")
print(f"Speedup: {emd_time/sw_time:.1f}x")

# Study convergence with number of projections
projections_range = np.logspace(1, 3, 10).astype(int) # From 10 to 1000
distances = []
for n_proj in projections_range:
    dist = ot.sliced.sliced_wasserstein_distance(
        X_s, X_t, n_projections=n_proj, seed=42
    )
    distances.append(dist)
print("Convergence analysis:")
for n_proj, dist in zip(projections_range, distances):
    print(f"n_projections={n_proj:4d}: distance={dist:.6f}")
# Estimate convergence
final_distance = distances[-1]
print(f"\nApproximate converged value: {final_distance:.6f}")

# Compare p=1 and p=2 distances
p_values = [1, 2]
for p in p_values:
    sw_p = ot.sliced.sliced_wasserstein_distance(
        X_s, X_t, p=p, n_projections=100, seed=42
    )
    print(f"Sliced W_{p} distance: {sw_p:.4f}")

Sliced Wasserstein is particularly effective for high-dimensional data with many samples.
Use sliced methods when exact optimal transport becomes computationally prohibitive.
The ot.sliced module provides essential tools for scalable optimal transport in high dimensions, offering practical algorithms that maintain theoretical guarantees while dramatically reducing computational requirements.