# Sliced Wasserstein Distances

The `ot.sliced` module provides efficient algorithms for computing sliced Wasserstein distances. These are a scalable alternative to the exact Wasserstein distance, built on random 1D projections. Each projection costs O(n log n) in the number of samples, because the 1D problem is solved by sorting. This makes the methods particularly effective for large or high-dimensional datasets, where exact optimal transport becomes computationally prohibitive.

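For reference, the quantity these functions estimate can be written as follows. This is the standard definition, and the notation is introduced here for illustration:

- $\sigma$ is the uniform distribution on the unit sphere $\mathbb{S}^{d-1}$.
- $\theta_\#\mu$ is the distribution of $\langle \theta, x \rangle$ for $x \sim \mu$.
- $F^{-1}$ denotes a 1D quantile function.
- $\theta_1, \dots, \theta_L$ are random directions drawn from $\sigma$.

```math
\begin{aligned}
\mathrm{SW}_p(\mu, \nu) &= \left( \int_{\mathbb{S}^{d-1}} W_p^p(\theta_\#\mu, \theta_\#\nu)\, d\sigma(\theta) \right)^{1/p}
\approx \left( \frac{1}{L} \sum_{l=1}^{L} W_p^p\big((\theta_l)_\#\mu, (\theta_l)_\#\nu\big) \right)^{1/p}, \\
W_p^p(\theta_\#\mu, \theta_\#\nu) &= \int_0^1 \left| F_{\theta_\#\mu}^{-1}(q) - F_{\theta_\#\nu}^{-1}(q) \right|^p dq .
\end{aligned}
```

The max-sliced variant replaces the average over $\theta$ with a maximum.
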
## Core Sliced Wasserstein Functions

### Standard Sliced Wasserstein

```python { .api }
def ot.sliced.sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, projections=None, seed=None, log=False):
    """
    Compute the Sliced Wasserstein distance between two empirical distributions.

    Estimates the sliced Wasserstein distance by a Monte Carlo average of 1D
    Wasserstein distances over random projection directions. The method
    exploits the fact that 1D optimal transport has a closed-form solution
    via sorting (quantile functions).

    Parameters:
    - X_s: array-like, shape (n_samples_source, n_features)
        Source samples in d-dimensional space.
    - X_t: array-like, shape (n_samples_target, n_features)
        Target samples in d-dimensional space.
    - a: array-like, shape (n_samples_source,), optional
        Weights for source samples. If None, uniform weights are used.
    - b: array-like, shape (n_samples_target,), optional
        Weights for target samples. If None, uniform weights are used.
    - n_projections: int, default=50
        Number of random projections to average over. More projections
        reduce the Monte Carlo error but increase computation time.
        Ignored when `projections` is given.
    - p: int, default=2
        Order of the Wasserstein distance (typically 1 or 2).
    - projections: array-like, shape (n_features, n_projections), optional
        Custom projection directions, one unit vector per column. If None,
        directions are sampled uniformly from the unit sphere.
    - seed: int or RandomState, optional
        Random seed for reproducible projection generation.
    - log: bool, default=False
        If True, also return the projections and per-projection results.

    Returns:
    - sliced_distance: float
        Estimated sliced Wasserstein distance: the p-th root of the mean,
        over projections, of the 1D costs W_p^p.
    - log: dict (if log=True)
        Contains 'projections': projection directions used,
        'projected_emds': 1D costs (W_p^p) for each projection.

    Example:
        X_s = np.random.randn(100, 10)  # 100 samples in 10D
        X_t = np.random.randn(80, 10)   # 80 samples in 10D
        sw_dist = ot.sliced.sliced_wasserstein_distance(X_s, X_t, n_projections=100)
    """

def ot.sliced.max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, projections=None, seed=None, log=False):
    """
    Compute the Max-Sliced Wasserstein distance.

    Instead of averaging the 1D Wasserstein distances over the projection
    directions, returns the largest one. The maximum is taken only over the
    sampled (or supplied) directions. The result is therefore a lower bound on
    the true max-sliced distance, which maximizes over the whole unit sphere.

    Parameters:
    - X_s: array-like, shape (n_samples_source, n_features)
        Source samples.
    - X_t: array-like, shape (n_samples_target, n_features)
        Target samples.
    - a: array-like, shape (n_samples_source,), optional
        Source weights. If None, uniform weights are used.
    - b: array-like, shape (n_samples_target,), optional
        Target weights. If None, uniform weights are used.
    - n_projections: int, default=50
        Number of random projection directions over which the maximum is taken.
    - p: int, default=2
        Order of the Wasserstein distance.
    - projections: array-like, shape (n_features, n_projections), optional
        Custom projection directions. If None, random directions are used.
    - seed: int or RandomState, optional
        Random seed.
    - log: bool, default=False
        If True, also return the projections and per-projection results.

    Returns:
    - max_sliced_distance: float
        Maximum 1D Wasserstein distance over the considered projections.
    - log: dict (if log=True)
        Contains 'projections': projection directions used,
        'projected_emds': 1D costs (W_p^p) for each projection.
    """
```

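To make the averaging and maximizing concrete, here is a minimal NumPy sketch of both estimators. It is not POT's implementation. It assumes equal sample sizes and uniform weights, so each 1D problem reduces to sorting and matching. The helper names `random_directions`, `projected_costs` and `sliced_w` are hypothetical. Directions are stored one per column, matching the `(n_features, n_projections)` layout POT uses for `projections`.

```python
import numpy as np


def random_directions(d, n_projections, rng):
    """Hypothetical helper: uniform unit vectors, one per column, shape (d, n_projections)."""
    theta = rng.standard_normal((d, n_projections))
    return theta / np.linalg.norm(theta, axis=0, keepdims=True)


def projected_costs(X_s, X_t, theta, p=2):
    """Hypothetical helper: 1D costs W_p^p per direction.

    Assumes equal sample sizes and uniform weights.
    """
    if X_s.shape[0] != X_t.shape[0]:
        raise ValueError("this sketch assumes equal sample sizes")
    # Project every sample on every direction, then sort along the sample axis.
    # In 1D, optimal transport matches the i-th smallest source value to the
    # i-th smallest target value.
    xs = np.sort(X_s @ theta, axis=0)
    xt = np.sort(X_t @ theta, axis=0)
    return np.mean(np.abs(xs - xt) ** p, axis=0)


def sliced_w(X_s, X_t, n_projections=50, p=2, seed=0):
    """Hypothetical helper returning (sliced, max-sliced) estimates."""
    rng = np.random.default_rng(seed)
    theta = random_directions(X_s.shape[1], n_projections, rng)
    costs = projected_costs(X_s, X_t, theta, p)
    # Sliced: p-th root of the average cost. Max-sliced: p-th root of the largest cost.
    return costs.mean() ** (1 / p), costs.max() ** (1 / p)


rng = np.random.default_rng(42)
X_s = rng.standard_normal((200, 10))
X_t = rng.standard_normal((200, 10)) + 1.0
sw, max_sw = sliced_w(X_s, X_t, n_projections=100)
print(f"sliced: {sw:.4f}, max-sliced: {max_sw:.4f}")  # max-sliced is never below sliced
```

Because the maximum of the per-direction costs is at least their mean, the max-sliced value computed on the same directions is always at least the sliced one.
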
### Spherical Sliced Wasserstein

```python { .api }
def ot.sliced.sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, p=2, seed=None, log=False):
    """
    Compute the Sliced Wasserstein distance on the unit sphere.

    Specialized version for data that lives on the unit sphere (e.g., directional
    data, normalized features). Samples are projected onto random great circles.
    The 1D Wasserstein distances on the circle, which use the geodesic (arc-length)
    ground cost, are then averaged.

    Parameters:
    - X_s: array-like, shape (n_samples_source, n_features)
        Source samples on the unit sphere (each row must have unit norm).
    - X_t: array-like, shape (n_samples_target, n_features)
        Target samples on the unit sphere.
    - a: array-like, shape (n_samples_source,), optional
        Source weights. If None, uniform weights are used.
    - b: array-like, shape (n_samples_target,), optional
        Target weights. If None, uniform weights are used.
    - n_projections: int, default=50
        Number of great-circle projections.
    - p: int, default=2
        Order of the Wasserstein distance.
    - seed: int or RandomState, optional
        Random seed for projection generation.
    - log: bool, default=False
        If True, also return detailed results.

    Returns:
    - spherical_sw_distance: float
        Sliced Wasserstein distance on the sphere.
    - log: dict (if log=True)
        Contains projection information and individual distances.
    """

def ot.sliced.sliced_wasserstein_sphere_unif(X_s, n_projections=50, seed=None, log=False):
    """
    Compute the spherical Sliced Wasserstein distance (of order 2) between samples
    and the uniform distribution on the sphere.

    Comparing against the uniform distribution is cheaper than comparing two
    empirical distributions. Every great-circle projection of the uniform
    distribution is uniform on the circle, so the 1D distance to it has a
    closed form.

    Parameters:
    - X_s: array-like, shape (n_samples, n_features)
        Source samples on the unit sphere.
    - n_projections: int, default=50
        Number of projections to use.
    - seed: int or RandomState, optional
        Random seed.
    - log: bool, default=False
        If True, also return detailed results.

    Returns:
    - distance_to_uniform: float
        Sliced Wasserstein distance to the uniform distribution on the sphere.
    - log: dict (if log=True)
        Contains projection information and individual distances.
    """
```

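One way to carry out the great-circle projection step is sketched below. This is a NumPy illustration under our own assumptions, not POT's code. It works in three steps:

1. A random great circle is represented by a random 2D plane, whose orthonormal basis comes from a QR factorization of a Gaussian matrix.
2. Each point is mapped to the nearest point on that circle.
3. The point's angle, as a fraction of a full turn, is used as a 1D coordinate.

The remaining step, computing the circular Wasserstein distance between these coordinates, is not shown. `project_to_circle` is a hypothetical name.

```python
import numpy as np


def project_to_circle(X, rng):
    """Hypothetical helper: project unit-norm rows of X (shape (n, d)) onto one random great circle.

    Returns circle coordinates as fractions of a full turn (angle / 2*pi).
    Undefined for points orthogonal to the sampled plane (a probability-zero event).
    """
    d = X.shape[1]
    # A random d x 2 matrix with orthonormal columns spans a random 2D plane;
    # the great circle is the intersection of that plane with the sphere.
    U, _ = np.linalg.qr(rng.standard_normal((d, 2)))
    Y = X @ U                                          # coordinates in the plane, shape (n, 2)
    Y = Y / np.linalg.norm(Y, axis=1, keepdims=True)   # nearest point on the great circle
    angles = np.arctan2(Y[:, 1], Y[:, 0])              # angle in [-pi, pi]
    return np.mod(angles, 2 * np.pi) / (2 * np.pi)


rng = np.random.default_rng(0)
X = rng.standard_normal((5, 3))
X /= np.linalg.norm(X, axis=1, keepdims=True)   # put the samples on the unit sphere
print(project_to_circle(X, rng))                # one circle coordinate per sample
```
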
## Utility Functions

```python { .api }
def ot.sliced.get_random_projections(d, n_projections, seed=None, type_as=None):
    """
    Generate random projection directions on the unit sphere.

    Creates uniformly distributed random unit vectors for use as projection
    directions in sliced Wasserstein computations.

    Parameters:
    - d: int
        Dimension of the ambient space.
    - n_projections: int
        Number of projection directions to generate.
    - seed: int or RandomState, optional
        Random seed for reproducible generation.
    - type_as: array-like, optional
        Reference array for determining output type and backend.

    Returns:
    - projections: ndarray, shape (d, n_projections)
        Random unit vectors uniformly distributed on the unit sphere.
        Each column is a normalized projection direction.

    Example:
        # Generate 100 random projections in 5D space
        projections = ot.sliced.get_random_projections(5, 100, seed=42)
        print(projections.shape)  # (5, 100)
        print(np.allclose(np.linalg.norm(projections, axis=0), 1.0))  # True
    """
```

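The standard way to draw such directions is to normalize standard Gaussian vectors, which are rotation invariant. The sketch below is a NumPy illustration with a hypothetical helper name, not POT's source. It returns one direction per column, shape `(d, n_projections)`, which is the layout used by POT's `get_random_projections`. It also checks the uniformity property $E[\theta\theta^\top] = I/d$ empirically.

```python
import numpy as np


def random_unit_directions(d, n_projections, seed=None):
    """Hypothetical helper: n_projections directions uniform on the unit sphere in R^d.

    Returned with shape (d, n_projections), one direction per column.
    """
    rng = np.random.default_rng(seed)
    # A standard Gaussian vector is rotation invariant, so normalizing it
    # gives a direction that is uniformly distributed on the sphere.
    theta = rng.standard_normal((d, n_projections))
    return theta / np.linalg.norm(theta, axis=0, keepdims=True)


theta = random_unit_directions(5, 100_000, seed=42)
print(theta.shape)  # (5, 100000)

# For uniform directions E[theta theta^T] = I / d, so the empirical average
# below should be close to 0.2 on the diagonal and close to 0 elsewhere.
second_moment = theta @ theta.T / theta.shape[1]
print(np.round(second_moment, 2))
```
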
## Computational Advantages

### Scalability Benefits
Sliced Wasserstein methods offer significant computational advantages:

- **Near-Linear Scaling**: O(L·n·(d + log n)) for L projections and n samples (projection plus sorting), vs roughly O(n³ log n) for exact network-simplex solvers
- **High-Dimensional Efficiency**: the dimension only enters through the projection step, whose cost grows linearly in d
- **Parallelizable**: different projections can be computed independently
- **Memory Efficient**: no n×m cost matrix or transport plan needs to be stored

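A back-of-the-envelope calculation makes the memory point concrete. The sizes below are illustrative assumptions. The sliced figure also assumes that all projections are materialized at once, as happens when projecting with a single matrix product.

```python
# Rough memory comparison with float64 values (8 bytes each); illustrative sizes only.
n = m = 100_000      # source and target sample counts
L = 100              # number of projections
bytes_per_float = 8

cost_matrix_gb = n * m * bytes_per_float / 1e9         # dense n x m cost matrix
projected_mb = (n + m) * L * bytes_per_float / 1e6     # projected samples, all directions

print(f"dense cost matrix: {cost_matrix_gb:.0f} GB")   # 80 GB
print(f"projected samples: {projected_mb:.0f} MB")     # 160 MB
```

An exact solver would additionally need a transport plan of the same n×m size as the cost matrix.
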
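### Approximation Quality
The accuracy of the Monte Carlo estimate depends on:
- Number of projections (the error decreases roughly as 1/√L for L projections)
- Data dimension (in high dimensions, random directions tend to be nearly orthogonal to the directions in which the distributions differ, so more projections are usually needed for a stable estimate)
- Distribution characteristics (differences spread over many directions are captured with fewer projections than differences concentrated along a few directions)
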
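The dependence on the number of projections can be checked directly. The projection average is a Monte Carlo estimate, so its spread across random seeds shrinks roughly like $1/\sqrt{L}$. The NumPy sketch below compares that spread for 10 and 1000 projections on fixed samples. It uses a simplified estimator that assumes equal sample sizes and uniform weights. It is not POT's code, and `sw2_estimate` is a hypothetical name.

```python
import numpy as np

# Fixed samples: two Gaussians in 20D, the target shifted along the first axis only.
rng = np.random.default_rng(0)
n, d = 500, 20
X_s = rng.standard_normal((n, d))
X_t = rng.standard_normal((n, d))
X_t[:, 0] += 2.0


def sw2_estimate(X_s, X_t, n_projections, seed):
    """Hypothetical sliced W_2 estimate for equal-size, uniformly weighted samples."""
    r = np.random.default_rng(seed)
    theta = r.standard_normal((X_s.shape[1], n_projections))
    theta /= np.linalg.norm(theta, axis=0, keepdims=True)
    xs = np.sort(X_s @ theta, axis=0)   # sorting solves each 1D problem
    xt = np.sort(X_t @ theta, axis=0)
    return np.sqrt(np.mean((xs - xt) ** 2))


# Spread of the estimate across 50 different sets of random directions.
spread = {}
for L in (10, 1000):
    estimates = [sw2_estimate(X_s, X_t, L, seed) for seed in range(50)]
    spread[L] = np.std(estimates)
    print(f"L={L:5d}: mean={np.mean(estimates):.4f}, std={spread[L]:.4f}")
```

The shift lies along a single axis, so most random directions in 20D are nearly orthogonal to it. Each projection therefore sees only a small part of the difference.
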
## Usage Examples

### Basic Sliced Wasserstein
```python
import ot
import numpy as np

# Generate high-dimensional sample data
np.random.seed(42)
d = 50  # Dimension
n_s, n_t = 200, 150

# Source and target samples
X_s = np.random.randn(n_s, d)
X_t = np.random.randn(n_t, d) + 1  # Shifted distribution (every coordinate +1)

# Compute sliced Wasserstein distance
n_proj = 100
sw_distance = ot.sliced.sliced_wasserstein_distance(
    X_s, X_t, n_projections=n_proj, seed=42
)

print(f"Sliced Wasserstein distance: {sw_distance:.4f}")

# Compare with different numbers of projections (fixed seed for reproducibility)
projections_to_try = [10, 50, 100, 200]
for n_proj in projections_to_try:
    dist = ot.sliced.sliced_wasserstein_distance(X_s, X_t, n_projections=n_proj, seed=42)
    print(f"n_projections={n_proj}: distance={dist:.4f}")
```

### Max-Sliced Wasserstein
```python
# Compute max-sliced distance for comparison. With the same d, n_projections
# and seed, it uses the same directions as sw_distance, so the ratio is >= 1
max_sw_distance = ot.sliced.max_sliced_wasserstein_distance(
    X_s, X_t, n_projections=100, seed=42
)

print(f"Max-Sliced Wasserstein distance: {max_sw_distance:.4f}")
print(f"Ratio (max/average): {max_sw_distance/sw_distance:.2f}")
```

### Custom Projections
```python
# Use custom projection directions: an array of shape (d, n_projections)
custom_projections = ot.sliced.get_random_projections(d, 50, seed=123)

# Compute distance with custom projections (n_projections is taken from the array)
sw_custom = ot.sliced.sliced_wasserstein_distance(
    X_s, X_t, projections=custom_projections
)

print(f"Custom projections distance: {sw_custom:.4f}")

# Get detailed results: with log=True, a (distance, log_dict) tuple is returned
sw_value, sw_log = ot.sliced.sliced_wasserstein_distance(
    X_s, X_t, n_projections=20, log=True, seed=42
)

print("Detailed results:")
print(f"Distance: {sw_value:.4f}")
# 'projected_emds' holds the 1D cost W_p^p for each projection
print(f"Individual projection costs (first 5): {sw_log['projected_emds'][:5]}")
```

### Weighted Samples
```python
# Create weighted samples
a = np.random.exponential(1.0, n_s)
a = a / np.sum(a)  # Normalize to sum to 1

b = np.random.exponential(1.5, n_t)
b = b / np.sum(b)

# Compute weighted sliced Wasserstein
sw_weighted = ot.sliced.sliced_wasserstein_distance(
    X_s, X_t, a=a, b=b, n_projections=100
)

print(f"Weighted Sliced Wasserstein: {sw_weighted:.4f}")
```

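When sample sizes differ or weights are non-uniform, the 1D problem can no longer be solved by matching sorted samples one-to-one. The closed form instead compares quantile functions. Below is a NumPy sketch of that computation, plus a weighted sliced estimate built on it. It follows the quantile-function formula (the approach behind 1D solvers such as `ot.wasserstein_1d`) rather than POT's exact code. Both function names are hypothetical.

```python
import numpy as np


def wasserstein_1d_weighted(u, v, u_w, v_w, p=2):
    """Hypothetical helper: W_p^p between two weighted 1D empirical distributions.

    Integrates |F_u^{-1}(q) - F_v^{-1}(q)|^p over q in [0, 1].
    Weights must be non-negative; they are normalized to sum to 1 here.
    """
    u_order, v_order = np.argsort(u), np.argsort(v)
    u, v = u[u_order], v[v_order]
    u_cdf = np.cumsum(u_w[u_order]) / np.sum(u_w)
    v_cdf = np.cumsum(v_w[v_order]) / np.sum(v_w)
    # Quantile levels at which either quantile function may jump.
    qs = np.unique(np.concatenate([u_cdf, v_cdf]))
    widths = np.diff(qs, prepend=0.0)
    # On each interval (q_{k-1}, q_k] both quantile functions are constant;
    # evaluate them at q_k (clip guards against round-off at q = 1).
    iu = np.clip(np.searchsorted(u_cdf, qs, side="left"), 0, len(u) - 1)
    iv = np.clip(np.searchsorted(v_cdf, qs, side="left"), 0, len(v) - 1)
    return np.sum(widths * np.abs(u[iu] - v[iv]) ** p)


def sliced_w_weighted(X_s, X_t, a, b, n_projections=50, p=2, seed=0):
    """Hypothetical helper: weighted sliced W_p averaged over random unit directions."""
    rng = np.random.default_rng(seed)
    theta = rng.standard_normal((X_s.shape[1], n_projections))
    theta /= np.linalg.norm(theta, axis=0, keepdims=True)
    costs = [wasserstein_1d_weighted(X_s @ t, X_t @ t, a, b, p) for t in theta.T]
    return np.mean(costs) ** (1 / p)


# Two source points with equal weight vs one target point: half the mass moves by 1.
print(wasserstein_1d_weighted(np.array([0.0, 1.0]), np.array([0.0]),
                              np.array([0.5, 0.5]), np.array([1.0])))  # 0.5
```
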
### Spherical Data
```python
# Generate data on the unit sphere
X_s_sphere = np.random.randn(100, d)
X_s_sphere = X_s_sphere / np.linalg.norm(X_s_sphere, axis=1, keepdims=True)

X_t_sphere = np.random.randn(80, d)
X_t_sphere = X_t_sphere / np.linalg.norm(X_t_sphere, axis=1, keepdims=True)

# Compute spherical sliced Wasserstein
sw_sphere = ot.sliced.sliced_wasserstein_sphere(
    X_s_sphere, X_t_sphere, n_projections=100
)

print(f"Spherical Sliced Wasserstein: {sw_sphere:.4f}")

# Distance to the uniform distribution on the sphere
sw_unif = ot.sliced.sliced_wasserstein_sphere_unif(
    X_s_sphere, n_projections=100
)

print(f"Distance to uniform on sphere: {sw_unif:.4f}")
```

### Performance Comparison
```python
import time

# Compare computational time with the exact method on a small problem.
# At this size timings are dominated by overheads; the gap widens as n grows.
n_small = 50
X_s_small = np.random.randn(n_small, 2)  # 2D, small enough for exact EMD
X_t_small = np.random.randn(n_small, 2)

# Exact EMD: ot.dist uses the squared Euclidean cost by default,
# so ot.emd2 returns W_2^2; take the square root to get W_2
tic = time.perf_counter()
M = ot.dist(X_s_small, X_t_small)
a_unif = ot.unif(n_small)
b_unif = ot.unif(n_small)
emd_distance = np.sqrt(ot.emd2(a_unif, b_unif, M))
emd_time = time.perf_counter() - tic

# Sliced Wasserstein (p=2 by default, comparable to W_2; expect SW_2 <= W_2)
tic = time.perf_counter()
sw_distance = ot.sliced.sliced_wasserstein_distance(X_s_small, X_t_small)
sw_time = time.perf_counter() - tic

print(f"EMD distance (W_2): {emd_distance:.4f} (time: {emd_time:.4f}s)")
print(f"Sliced W distance: {sw_distance:.4f} (time: {sw_time:.4f}s)")
print(f"Speedup: {emd_time/sw_time:.1f}x")
```

### Convergence Analysis
```python
# Study convergence with the number of projections
projections_range = np.logspace(1, 3, 10).astype(int)  # From 10 to 1000
distances = []

for n_proj in projections_range:
    dist = ot.sliced.sliced_wasserstein_distance(
        X_s, X_t, n_projections=n_proj, seed=42
    )
    distances.append(dist)

print("Convergence analysis:")
for n_proj, dist in zip(projections_range, distances):
    print(f"n_projections={n_proj:4d}: distance={dist:.6f}")

# The estimate with the most projections approximates the limiting
# sliced Wasserstein value (not the exact Wasserstein distance)
final_distance = distances[-1]
print(f"\nApproximate converged value: {final_distance:.6f}")
```

### Different Distance Orders
```python
# Compare p=1 and p=2 distances
p_values = [1, 2]

for p in p_values:
    sw_p = ot.sliced.sliced_wasserstein_distance(
        X_s, X_t, p=p, n_projections=100, seed=42
    )
    print(f"Sliced W_{p} distance: {sw_p:.4f}")
```

## Applications and Use Cases

### High-Dimensional Data
Sliced Wasserstein is particularly effective for:
- **Image Processing**: Comparing high-dimensional image features
- **Natural Language Processing**: Document embeddings and word vectors
- **Bioinformatics**: Gene expression profiles and protein data
- **Machine Learning**: Feature representations and latent spaces

### Computational Constraints
Use sliced methods when:
- Exact optimal transport is too slow (large number of samples)
- Memory is limited (an n×m cost matrix or transport plan does not fit)
- Distances must be computed quickly, e.g. in real-time applications
- Many pairs of distributions must be compared (batch processing)

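### Theoretical Properties
- **Consistency**: as n_projections → ∞, the Monte Carlo estimate converges to the sliced Wasserstein distance (the average over all directions on the sphere). This is a different metric from the Wasserstein distance and is never larger than it: SW_p ≤ W_p.
- **Statistical Efficiency**: the error from estimating the distance with finite samples does not grow exponentially with the dimension, unlike for the exact Wasserstein distance
- **Differentiability**: the estimate is differentiable almost everywhere with respect to the sample positions, so with a differentiable backend it can serve as a loss in gradient-based optimization
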
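The sliced distances are related to the exact one by $\mathrm{SW}_p \le \text{max-}\mathrm{SW}_p \le W_p$. This holds because projecting onto a unit vector never increases distances, so every 1D cost is at most $W_p^p$. With a finite set of random directions, the computed max-sliced value is itself below the true maximum, so the chain still holds. The NumPy sketch below checks this numerically. It is not POT code. It uses tiny equal-size, uniformly weighted samples so that the exact $W_p$ can be found by brute force over all permutations.

```python
import itertools

import numpy as np

rng = np.random.default_rng(1)
n, d, p = 6, 3, 2
X_s = rng.standard_normal((n, d))
X_t = rng.standard_normal((n, d)) + 0.5

# Exact W_p for equal-size uniform samples: an optimal plan is a permutation,
# so brute force over all n! matchings is feasible for tiny n.
w_p = min(
    np.mean(np.linalg.norm(X_s - X_t[list(perm)], axis=1) ** p)
    for perm in itertools.permutations(range(n))
) ** (1 / p)

# Per-direction 1D costs W_p^p on random unit directions (sorting solves 1D OT).
theta = rng.standard_normal((d, 200))
theta /= np.linalg.norm(theta, axis=0, keepdims=True)
diffs = np.sort(X_s @ theta, axis=0) - np.sort(X_t @ theta, axis=0)
costs = np.mean(np.abs(diffs) ** p, axis=0)
sw_p = costs.mean() ** (1 / p)
max_sw_p = costs.max() ** (1 / p)

print(f"SW_p={sw_p:.4f} <= max-SW_p={max_sw_p:.4f} <= W_p={w_p:.4f}")
```
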
The `ot.sliced` module provides essential tools for scalable optimal transport in high dimensions, offering practical algorithms that maintain theoretical guarantees while dramatically reducing computational requirements.