
# Sliced Wasserstein Distances

The `ot.sliced` module provides efficient algorithms that approximate Wasserstein distances in high dimensions using random projections. Each projection costs O(n log n) because it only requires sorting. As a result, these methods scale nearly linearly with the number of samples. They are particularly effective for high-dimensional data, where exact optimal transport becomes computationally prohibitive.
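To make the mechanism concrete, here is a minimal NumPy sketch of the Monte Carlo estimator. It assumes equal-size, uniformly weighted samples, which is the case where each 1D optimal transport problem reduces to matching sorted values. The helper `sliced_wasserstein_sketch` is hypothetical and is not POT's implementation, which also handles weights and multiple backends.

```python
import numpy as np

def sliced_wasserstein_sketch(X_s, X_t, n_projections=50, p=2, seed=None):
    """Monte Carlo sliced Wasserstein estimate for equal-size, uniformly weighted samples."""
    X_s = np.asarray(X_s, dtype=float)
    X_t = np.asarray(X_t, dtype=float)
    if X_s.shape != X_t.shape:
        raise ValueError("this sketch assumes equal sample counts and feature dimensions")
    rng = np.random.default_rng(seed)
    # Random directions: normalized Gaussian columns are uniform on the unit sphere
    theta = rng.standard_normal((X_s.shape[1], n_projections))
    theta /= np.linalg.norm(theta, axis=0, keepdims=True)
    # Project onto all directions at once, then sort: for equal-size uniform samples,
    # 1D optimal transport matches the k-th smallest value to the k-th smallest value
    proj_s = np.sort(X_s @ theta, axis=0)
    proj_t = np.sort(X_t @ theta, axis=0)
    # Per-direction cost W_p^p, averaged over directions, then the p-th root
    costs = np.mean(np.abs(proj_s - proj_t) ** p, axis=0)
    return float(np.mean(costs) ** (1.0 / p))

# Usage: for a pure shift by v, each direction theta has cost (theta . v)^2 when p = 2
rng = np.random.default_rng(0)
X = rng.standard_normal((100, 5))
v = np.array([3.0, 0.0, 0.0, 0.0, 0.0])
print(sliced_wasserstein_sketch(X, X + v, n_projections=2000, seed=1))
```

Because a random unit vector satisfies E[(θ·v)²] = ‖v‖²/d, the squared estimate for a pure shift approaches ‖v‖²/d (here 9/5) as the number of directions grows. This makes it a convenient sanity check.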

## Core Sliced Wasserstein Functions

### Standard Sliced Wasserstein

```python { .api }
def ot.sliced.sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, projections=None, seed=None, log=False):
    """
    Compute the Sliced Wasserstein distance between two empirical distributions.

    Projects both sample sets onto random directions and solves each resulting
    1D optimal transport problem in closed form by sorting. It then averages the
    1D costs over the projections. The result is a Monte Carlo estimate of the
    sliced Wasserstein distance, which is a lower bound on the exact Wasserstein distance.

    Parameters:
    - X_s: array-like, shape (n_samples_source, n_features)
        Source samples in d-dimensional space.
    - X_t: array-like, shape (n_samples_target, n_features)
        Target samples in d-dimensional space.
    - a: array-like, shape (n_samples_source,), optional
        Weights for source samples. If None, uniform weights are used.
    - b: array-like, shape (n_samples_target,), optional
        Weights for target samples. If None, uniform weights are used.
    - n_projections: int, default=50
        Number of random projections to average over. More projections
        reduce the Monte Carlo error but increase computation time.
        Ignored when `projections` is given.
    - p: int, default=2
        Order of the Wasserstein distance (typically 1 or 2).
    - projections: array-like, shape (n_features, n_projections), optional
        Custom projection directions, one unit vector per column. If None,
        random directions sampled uniformly from the unit sphere are used.
    - seed: int, optional
        Random seed for reproducible projection generation.
    - log: bool, default=False
        If True, also return a dict with per-projection results.

    Returns:
    - sliced_distance: float
        Estimated sliced Wasserstein distance, computed as
        (mean of the per-projection 1D costs W_p^p) ** (1/p).
    - log: dict (if log=True)
        Contains 'projections': the projection directions used, and
        'projected_emds': the 1D cost (W_p^p) for each projection.

    Example:
        X_s = np.random.randn(100, 10)  # 100 samples in 10D
        X_t = np.random.randn(80, 10)   # 80 samples in 10D
        sw_dist = ot.sliced.sliced_wasserstein_distance(X_s, X_t, n_projections=100)
    """

def ot.sliced.max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, projections=None, seed=None, log=False):
    """
    Compute the Max-Sliced Wasserstein distance over a set of projections.

    Instead of averaging the 1D costs, this function returns the largest one
    among the sampled projection directions. The result is a Monte Carlo
    approximation of the max-sliced distance, which is a maximum over all unit
    directions. The approximation never exceeds that maximum and tends to
    approach it as n_projections grows.

    Parameters:
    - X_s: array-like, shape (n_samples_source, n_features)
        Source samples.
    - X_t: array-like, shape (n_samples_target, n_features)
        Target samples.
    - a: array-like, shape (n_samples_source,), optional
        Source weights.
    - b: array-like, shape (n_samples_target,), optional
        Target weights.
    - n_projections: int, default=50
        Number of random projection directions among which the maximum is taken.
    - p: int, default=2
        Wasserstein distance order.
    - projections: array-like, shape (n_features, n_projections), optional
        Custom projection directions to consider. If None, random directions are sampled.
    - seed: int, optional
        Random seed.
    - log: bool, default=False
        If True, also return per-projection results.

    Returns:
    - max_sliced_distance: float
        Largest 1D Wasserstein distance over the considered projections.
    - log: dict (if log=True)
        Contains 'projections': the projection directions used, and
        'projected_emds': the 1D cost (W_p^p) for each projection.
    """
```
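Computing both estimators from the same per-projection costs shows how they differ. In this NumPy sketch, the helper `sliced_and_max_sliced` is hypothetical and assumes equal-size, uniformly weighted samples. The average of the costs gives the sliced estimate, and the maximum gives the max-sliced estimate. The max-sliced value is therefore never smaller when the directions are shared.

```python
import numpy as np

def sliced_and_max_sliced(X_s, X_t, n_projections=100, p=2, seed=None):
    """Return (sliced, max-sliced, best direction) from one shared set of random directions.

    Sketch for equal-size, uniformly weighted samples only.
    """
    if X_s.shape != X_t.shape:
        raise ValueError("this sketch assumes equal sample counts and feature dimensions")
    rng = np.random.default_rng(seed)
    theta = rng.standard_normal((X_s.shape[1], n_projections))
    theta /= np.linalg.norm(theta, axis=0, keepdims=True)  # unit-norm columns
    # Sorted projections solve each 1D transport problem
    proj_s = np.sort(X_s @ theta, axis=0)
    proj_t = np.sort(X_t @ theta, axis=0)
    costs = np.mean(np.abs(proj_s - proj_t) ** p, axis=0)  # W_p^p per direction
    sliced = np.mean(costs) ** (1.0 / p)      # average over directions
    max_sliced = np.max(costs) ** (1.0 / p)   # worst sampled direction
    best_direction = theta[:, np.argmax(costs)]
    return float(sliced), float(max_sliced), best_direction

# Usage: a shift along the first axis should be detected by a direction close to +/- e_1
rng = np.random.default_rng(0)
X = rng.standard_normal((200, 2))
sw, max_sw, best = sliced_and_max_sliced(X, X + np.array([3.0, 0.0]), n_projections=500, seed=1)
print(sw, max_sw, best)
```

With only finitely many sampled directions, the max-sliced value is limited by how close the best sampled direction comes to the truly worst one. This is why it tends to increase with `n_projections`.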

### Spherical Sliced Wasserstein

```python { .api }
def ot.sliced.sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, seed=None, log=False):
    """
    Compute the Sliced Wasserstein distance on the unit sphere.

    Specialized version for data that lives on the unit sphere (e.g., directional
    data, normalized features). Instead of projecting onto lines, it projects the
    samples onto random great circles and compares the resulting distributions
    on the circle. Distances are therefore measured along the sphere (geodesically)
    rather than through the ambient space.

    Parameters:
    - X_s: array-like, shape (n_samples_source, n_features)
        Source samples on the unit sphere (rows must already be normalized).
    - X_t: array-like, shape (n_samples_target, n_features)
        Target samples on the unit sphere.
    - a: array-like, shape (n_samples_source,), optional
        Source weights.
    - b: array-like, shape (n_samples_target,), optional
        Target weights.
    - n_projections: int, default=50
        Number of great-circle projections.
    - seed: int, optional
        Random seed for projection generation.
    - log: bool, default=False
        Return detailed results.

    Returns:
    - spherical_sw_distance: float
        Sliced Wasserstein distance on the sphere.
    - log: dict (if log=True)
        Contains projection information and individual distances.
    """

def ot.sliced.sliced_wasserstein_sphere_unif(X_s, n_projections=50, seed=None, log=False):
    """
    Compute the Sliced Wasserstein distance between samples and the uniform distribution on the sphere.

    Compares an empirical distribution with the uniform distribution on the
    unit sphere. Because the target distribution is known analytically, no
    target samples are needed.

    Parameters:
    - X_s: array-like, shape (n_samples, n_features)
        Source samples on the unit sphere.
    - n_projections: int, default=50
        Number of projections to use.
    - seed: int, optional
        Random seed.
    - log: bool, default=False
        Return detailed results.

    Returns:
    - distance_to_uniform: float
        Sliced Wasserstein distance to the uniform distribution on the sphere.
    - log: dict (if log=True)
        Additional information about the computation.
    """
```
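Spherical sliced Wasserstein replaces straight-line projections with projections onto great circles. The sketch below shows only that projection step. It assumes the construction in which a unit vector x is mapped to `U.T @ x / ||U.T @ x||` for a random d×2 matrix `U` with orthonormal columns. It illustrates the idea and is not POT's internal code: the helper names are hypothetical, and the subsequent 1D comparison on the circle is omitted.

```python
import numpy as np

def normalize_rows(X):
    """Scale each row to unit norm (the spherical functions expect points on the sphere)."""
    X = np.asarray(X, dtype=float)
    return X / np.linalg.norm(X, axis=1, keepdims=True)

def project_to_great_circle(X, seed=None):
    """Map unit-norm rows of X to positions on a random great circle, rescaled to the unit interval.

    Sketch of x -> U.T @ x / ||U.T @ x|| with U a random d x 2 matrix with orthonormal
    columns. Points orthogonal to the plane (a measure-zero event) are not handled.
    """
    rng = np.random.default_rng(seed)
    d = X.shape[1]
    U, _ = np.linalg.qr(rng.standard_normal((d, 2)))  # orthonormal basis of a random plane
    Z = X @ U                                          # coordinates in that plane, shape (n, 2)
    Z /= np.linalg.norm(Z, axis=1, keepdims=True)      # nearest point on the great circle
    angles = np.arctan2(Z[:, 1], Z[:, 0])              # angle in [-pi, pi]
    return np.mod(angles, 2 * np.pi) / (2 * np.pi)

rng = np.random.default_rng(0)
X_sphere = normalize_rows(rng.standard_normal((100, 3)))
positions = project_to_great_circle(X_sphere, seed=1)
print(positions[:5])
```

Each sample becomes a position on a circle, so two spherical distributions can then be compared through 1D transport on the circle for every sampled great circle.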

## Utility Functions

```python { .api }
def ot.sliced.get_random_projections(d, n_projections, seed=None, type_as=None):
    """
    Generate random projection directions on the unit sphere.

    Creates uniformly distributed random unit vectors for use as projection
    directions in sliced Wasserstein computations.

    Parameters:
    - d: int
        Dimension of the ambient space.
    - n_projections: int
        Number of projection directions to generate.
    - seed: int, optional
        Random seed for reproducible generation.
    - type_as: array-like, optional
        Reference array for determining output type and backend.

    Returns:
    - projections: ndarray, shape (d, n_projections)
        Random unit vectors uniformly distributed on the unit sphere.
        Each column is a normalized projection direction.

    Example:
        # Generate 100 random projections in 5D space
        projections = ot.sliced.get_random_projections(5, 100, seed=42)
        print(projections.shape)  # (5, 100)
        print(np.allclose(np.linalg.norm(projections, axis=0), 1.0))  # True
    """
```
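A common way to obtain uniformly distributed unit vectors, and the one assumed in this sketch, is to normalize standard Gaussian draws. Because the Gaussian is rotation-invariant, the normalized vectors are uniform on the sphere. The hypothetical helpers below store one direction per column, shape (d, n_projections), so that `X @ theta` projects samples `X` of shape (n, d) onto every direction at once. An identity matrix gives a deterministic, axis-aligned alternative.

```python
import numpy as np

def random_unit_directions(d, n_projections, seed=None):
    """Uniform random unit vectors stored as columns, shape (d, n_projections)."""
    rng = np.random.default_rng(seed)
    theta = rng.standard_normal((d, n_projections))
    # Normalizing rotation-invariant Gaussian vectors gives a uniform distribution on the sphere
    return theta / np.linalg.norm(theta, axis=0, keepdims=True)

def axis_aligned_directions(d):
    """Deterministic custom directions: column j is the j-th coordinate axis."""
    return np.eye(d)

theta = random_unit_directions(5, 100, seed=42)
print(theta.shape)  # (5, 100)
```

Projecting onto axis-aligned columns simply returns the original features. Using such directions as custom projections would therefore turn the sliced distance into an average of per-feature 1D distances.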

## Computational Advantages

### Scalability Benefits

Sliced Wasserstein methods offer significant computational advantages:

- **Near-Linear Scaling**: O(n log n) per projection (dominated by sorting) vs. roughly O(n³) for exact solvers
- **High-Dimensional Efficiency**: Dimension enters only through the projection step, whose cost grows linearly with d
- **Parallelizable**: Different projections can be computed independently
- **Memory Efficient**: No need to store n×m cost or transport matrices
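In formulas, the averaged estimator and its cost can be written as follows. The notation is introduced here for illustration: L projection directions θ_1 to θ_L are drawn uniformly from the unit sphere, there are n source and m target samples in d dimensions, and (θ_l)_# μ denotes the distribution of the samples projected onto θ_l.

```latex
\widehat{SW}_p(\mu, \nu) = \left( \frac{1}{L} \sum_{l=1}^{L} W_p^p\big( (\theta_l)_\# \mu,\, (\theta_l)_\# \nu \big) \right)^{1/p},
\qquad
\text{cost} = \mathcal{O}\Big( L \big[ (n + m)\, d + n \log n + m \log m \big] \Big)
```

The first cost term comes from projecting the samples and the second from sorting them. No n×m matrix is ever formed.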

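### Approximation Quality

The quality of the Monte Carlo estimate depends on:

- Number of projections (more projections give a better estimate; the error shrinks roughly as 1/√n_projections)
- Data dimension (higher dimensions often need more projections, since random directions are less likely to align with the directions in which the distributions differ)
- Distribution characteristics (differences concentrated in a few directions are harder to capture with random projections)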
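Because the estimate is an average over random directions, its Monte Carlo error can be gauged from a single run. This NumPy sketch uses a hypothetical helper, `sw_estimate_with_stderr`. It takes the per-projection 1D costs (W_p^p), such as those returned in the log dictionary with `log=True`; check the key name against your installed POT version. It returns the estimate together with the standard error of the averaged cost.

```python
import numpy as np

def sw_estimate_with_stderr(costs, p=2):
    """Sliced estimate plus the Monte Carlo standard error of the averaged cost.

    `costs` holds one 1D cost (W_p^p) per projection direction.
    """
    costs = np.asarray(costs, dtype=float)
    mean_cost = costs.mean()
    # Standard error of a Monte Carlo mean: sample std / sqrt(number of projections)
    stderr = costs.std(ddof=1) / np.sqrt(costs.size)
    return float(mean_cost ** (1.0 / p)), float(stderr)

# Demo: for a pure shift of 3 along the first axis in 10D, each direction theta has
# exact cost (3 * theta[0])**2, so only the choice of directions introduces error.
rng = np.random.default_rng(0)
for n_projections in (25, 400):
    theta = rng.standard_normal((10, n_projections))
    theta /= np.linalg.norm(theta, axis=0, keepdims=True)
    estimate, stderr = sw_estimate_with_stderr((3.0 * theta[0]) ** 2)
    print(n_projections, round(estimate, 3), round(stderr, 4))
```

Going from 25 to 400 projections should shrink the standard error by roughly a factor of four, consistent with the 1/√n_projections rate.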

## Usage Examples

### Basic Sliced Wasserstein

```python
import ot
import numpy as np

# Generate high-dimensional sample data
np.random.seed(42)
d = 50  # Dimension
n_s, n_t = 200, 150

# Source and target samples
X_s = np.random.randn(n_s, d)
X_t = np.random.randn(n_t, d) + 1  # Shifted distribution

# Compute sliced Wasserstein distance
n_proj = 100
sw_distance = ot.sliced.sliced_wasserstein_distance(
    X_s, X_t, n_projections=n_proj, seed=42
)

print(f"Sliced Wasserstein distance: {sw_distance:.4f}")

# Compare with different numbers of projections (fixed seed for reproducibility)
projections_to_try = [10, 50, 100, 200]
for n_proj in projections_to_try:
    dist = ot.sliced.sliced_wasserstein_distance(X_s, X_t, n_projections=n_proj, seed=42)
    print(f"n_projections={n_proj}: distance={dist:.4f}")
```

### Max-Sliced Wasserstein

```python
# Compute max-sliced distance for comparison
# (same seed and n_projections as sw_distance above, so the same directions are used)
max_sw_distance = ot.sliced.max_sliced_wasserstein_distance(
    X_s, X_t, n_projections=100, seed=42
)

print(f"Max-Sliced Wasserstein distance: {max_sw_distance:.4f}")
print(f"Ratio (max/average): {max_sw_distance/sw_distance:.2f}")  # >= 1 with shared directions
```

### Custom Projections

```python
# Use custom projection directions: shape (d, n_projections), one unit vector per column
custom_projections = ot.sliced.get_random_projections(d, 50, seed=123)

# Compute distance with custom projections
sw_custom = ot.sliced.sliced_wasserstein_distance(
    X_s, X_t, projections=custom_projections
)

print(f"Custom projections distance: {sw_custom:.4f}")

# Get detailed results: with log=True the function returns (distance, log_dict)
sw_value, sw_log = ot.sliced.sliced_wasserstein_distance(
    X_s, X_t, n_projections=20, log=True, seed=42
)

print("Detailed results:")
print(f"Distance: {sw_value:.4f}")
print(f"Individual projection costs (first 5): {sw_log['projected_emds'][:5]}")
```

### Weighted Samples

```python
# Create weighted samples
a = np.random.exponential(1.0, n_s)
a = a / np.sum(a)  # Normalize to sum to 1

b = np.random.exponential(1.5, n_t)
b = b / np.sum(b)

# Compute weighted sliced Wasserstein
sw_weighted = ot.sliced.sliced_wasserstein_distance(
    X_s, X_t, a=a, b=b, n_projections=100
)

print(f"Weighted Sliced Wasserstein: {sw_weighted:.4f}")
```
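Weights change how each 1D problem is solved. Instead of matching sorted values one to one, the projected cumulative distribution functions are compared. For p = 1 this has a particularly simple form, W_1 = ∫|F(x) - G(x)| dx, which this NumPy sketch uses. The helpers are hypothetical illustrations of how `a` and `b` enter the computation. They are not POT's implementation, which also handles p > 1.

```python
import numpy as np

def _cdf_on_grid(x, w, grid):
    """Right-continuous CDF of the weighted 1D sample (x, w), evaluated at each grid point."""
    order = np.argsort(x)
    cumulative = np.concatenate([[0.0], np.cumsum(w[order])])
    return cumulative[np.searchsorted(x[order], grid, side="right")]

def weighted_w1_1d(u, v, a, b):
    """W_1 between weighted 1D samples, as the integral of |F - G| (p = 1 only)."""
    grid = np.sort(np.concatenate([u, v]))
    F = _cdf_on_grid(u, a, grid)
    G = _cdf_on_grid(v, b, grid)
    # Both CDFs are constant between consecutive grid points
    return float(np.sum(np.abs(F[:-1] - G[:-1]) * np.diff(grid)))

def weighted_sliced_w1(X_s, X_t, a, b, n_projections=50, seed=None):
    """Average weighted 1D W_1 over random unit directions (sliced W_1 estimate)."""
    rng = np.random.default_rng(seed)
    theta = rng.standard_normal((X_s.shape[1], n_projections))
    theta /= np.linalg.norm(theta, axis=0, keepdims=True)
    P_s, P_t = X_s @ theta, X_t @ theta
    return float(np.mean([weighted_w1_1d(P_s[:, k], P_t[:, k], a, b)
                          for k in range(n_projections)]))

# Usage with normalized, non-uniform weights and different sample sizes
rng = np.random.default_rng(0)
X_s, X_t = rng.standard_normal((60, 4)), rng.standard_normal((40, 4)) + 1.0
a = rng.exponential(1.0, 60)
a /= a.sum()
b = rng.exponential(1.5, 40)
b /= b.sum()
print(weighted_sliced_w1(X_s, X_t, a, b, n_projections=100, seed=1))
```

Sample sizes no longer need to match, because the weights only enter through the cumulative sums.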

### Spherical Data

```python
# Generate data on unit sphere
X_s_sphere = np.random.randn(100, d)
X_s_sphere = X_s_sphere / np.linalg.norm(X_s_sphere, axis=1, keepdims=True)

X_t_sphere = np.random.randn(80, d)
X_t_sphere = X_t_sphere / np.linalg.norm(X_t_sphere, axis=1, keepdims=True)

# Compute spherical sliced Wasserstein
sw_sphere = ot.sliced.sliced_wasserstein_sphere(
    X_s_sphere, X_t_sphere, n_projections=100
)

print(f"Spherical Sliced Wasserstein: {sw_sphere:.4f}")

# Distance to uniform distribution on sphere
sw_unif = ot.sliced.sliced_wasserstein_sphere_unif(
    X_s_sphere, n_projections=100
)

print(f"Distance to uniform on sphere: {sw_unif:.4f}")
```

### Performance Comparison

```python
import time

# Compare computational time with the exact solver on a small 2D toy problem
# (timings at this size are noisy; the gap widens as n grows)
n_small = 50
X_s_small = np.random.randn(n_small, 2)
X_t_small = np.random.randn(n_small, 2)

# Exact W_2: ot.dist returns squared Euclidean costs by default, so take the square root
tic = time.perf_counter()
M = ot.dist(X_s_small, X_t_small)
a_unif = ot.unif(n_small)
b_unif = ot.unif(n_small)
emd_distance = np.sqrt(ot.emd2(a_unif, b_unif, M))
emd_time = time.perf_counter() - tic

# Sliced Wasserstein (p=2 by default)
tic = time.perf_counter()
sw_small = ot.sliced.sliced_wasserstein_distance(X_s_small, X_t_small, seed=42)
sw_time = time.perf_counter() - tic

print(f"Exact W_2 distance: {emd_distance:.4f} (time: {emd_time:.4f}s)")
print(f"Sliced W_2 distance: {sw_small:.4f} (time: {sw_time:.4f}s)")
print(f"Speedup: {emd_time/sw_time:.1f}x")
```

### Convergence Analysis

```python
# Study convergence with number of projections
projections_range = np.logspace(1, 3, 10).astype(int)  # From 10 to 1000
distances = []

for n_proj in projections_range:
    dist = ot.sliced.sliced_wasserstein_distance(
        X_s, X_t, n_projections=n_proj, seed=42
    )
    distances.append(dist)

print("Convergence analysis:")
for n_proj, dist in zip(projections_range, distances):
    print(f"n_projections={n_proj:4d}: distance={dist:.6f}")

# Estimate convergence
final_distance = distances[-1]
print(f"\nApproximate converged value: {final_distance:.6f}")
```

### Different Distance Orders

```python
# Compare p=1 and p=2 distances
p_values = [1, 2]

for p in p_values:
    sw_p = ot.sliced.sliced_wasserstein_distance(
        X_s, X_t, p=p, n_projections=100, seed=42
    )
    print(f"Sliced W_{p} distance: {sw_p:.4f}")
```

## Applications and Use Cases

### High-Dimensional Data

Sliced Wasserstein is particularly effective for:

- **Image Processing**: Comparing high-dimensional image features
- **Natural Language Processing**: Document embeddings and word vectors
- **Bioinformatics**: Gene expression profiles and protein data
- **Machine Learning**: Feature representations and latent spaces

### Computational Constraints

Use sliced methods when:

- Exact optimal transport is too slow (large n or high d)
- Memory is limited (can't store n×m cost or transport matrices)
- Real-time applications require fast distance computation
- Many distribution pairs must be processed in batch
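For batch processing, one useful design choice is to share a single set of directions across all distributions. Each dataset is then projected and sorted only once, and every pair reduces to a vectorized subtraction. The helper `pairwise_sliced_w2` below is a hypothetical NumPy sketch that assumes equal-size, uniformly weighted datasets.

```python
import numpy as np

def pairwise_sliced_w2(datasets, n_projections=100, seed=None):
    """Pairwise sliced W_2 estimates between equal-size, uniformly weighted datasets."""
    n, d = datasets[0].shape
    if any(X.shape != (n, d) for X in datasets):
        raise ValueError("this sketch assumes all datasets have the same shape")
    rng = np.random.default_rng(seed)
    theta = rng.standard_normal((d, n_projections))
    theta /= np.linalg.norm(theta, axis=0, keepdims=True)
    # Project and sort each dataset once, reusing the same directions for every pair
    sorted_proj = [np.sort(X @ theta, axis=0) for X in datasets]
    k = len(datasets)
    D = np.zeros((k, k))
    for i in range(k):
        for j in range(i + 1, k):
            # Mean over samples and directions equals the average per-direction W_2^2
            D[i, j] = D[j, i] = np.sqrt(np.mean((sorted_proj[i] - sorted_proj[j]) ** 2))
    return D

rng = np.random.default_rng(0)
base = rng.standard_normal((100, 8))
datasets = [base + shift for shift in (0.0, 0.5, 2.0)]
print(np.round(pairwise_sliced_w2(datasets, seed=1), 3))
```

Sharing directions also makes all entries of the matrix comparable, since they are estimated from the same Monte Carlo sample of projections.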

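### Theoretical Properties

- **Consistency**: As n_projections → ∞, the estimate converges to the sliced Wasserstein distance. This is a true metric between distributions, but it is generally smaller than the exact Wasserstein distance (SW_p ≤ W_p) rather than equal to it
- **Robustness to Dimension**: The error from estimating the distance with finite samples does not grow with dimension the way it does for the exact Wasserstein distance
- **Differentiability**: Differentiable almost everywhere with respect to the sample positions, which makes it usable as a loss for gradient-based optimization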
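Sliced W_2 never exceeds exact W_2, because projecting onto a unit vector can only shrink the distance between two points. A quick numerical check works on tiny equal-size uniform samples. For these, an optimal transport plan is a one-to-one matching, so exact W_2 can be found by trying every permutation. This brute-force helper is a hypothetical illustration that scales factorially; it is not how POT solves exact problems.

```python
import itertools
import numpy as np

def exact_w2_bruteforce(X, Y):
    """Exact W_2 between equal-size uniform samples by checking every matching (tiny n only)."""
    n = len(X)
    best = min(np.mean(np.sum((X - Y[list(perm)]) ** 2, axis=1))
               for perm in itertools.permutations(range(n)))
    return float(np.sqrt(best))

def sliced_w2(X, Y, n_projections=200, seed=None):
    """Monte Carlo sliced W_2 for equal-size uniform samples (sorted 1D matching)."""
    rng = np.random.default_rng(seed)
    theta = rng.standard_normal((X.shape[1], n_projections))
    theta /= np.linalg.norm(theta, axis=0, keepdims=True)
    diff = np.sort(X @ theta, axis=0) - np.sort(Y @ theta, axis=0)
    return float(np.sqrt(np.mean(diff ** 2)))

rng = np.random.default_rng(0)
X, Y = rng.standard_normal((6, 3)), rng.standard_normal((6, 3)) + 0.5
w2, sw2 = exact_w2_bruteforce(X, Y), sliced_w2(X, Y, seed=1)
print(f"exact W_2 = {w2:.4f}, sliced W_2 = {sw2:.4f}")
```

The bound holds direction by direction. The optimal matching in d dimensions, once projected, is a valid 1D matching whose cost is no larger. The sorted 1D matching is at least as good.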

The `ot.sliced` module provides essential tools for scalable optimal transport in high dimensions, offering practical algorithms that maintain theoretical guarantees while dramatically reducing computational requirements.