# POT: Python Optimal Transport

POT is a Python library of solvers for optimal transport (OT) problems arising in signal processing, image processing, and machine learning. It covers exact linear OT with a network simplex solver, entropic regularization with Sinkhorn algorithms, Wasserstein barycenters, Gromov-Wasserstein distances, unbalanced and partial OT variants, sliced Wasserstein distances, and stochastic solvers for large-scale problems.

## Package Information

- **Package Name**: POT
- **Language**: Python
- **Installation**: `pip install POT`
- **Version**: 0.9.5

## Core Imports

```python
import ot
```

Import specific functions:

```python
from ot import emd, emd2, sinkhorn, sinkhorn2, gromov_wasserstein
```

Import submodules:

```python
import ot.lp
import ot.bregman
import ot.gromov
import ot.unbalanced
```

## Basic Usage

```python
import ot
import numpy as np

# Define source and target distributions (each must sum to 1)
a = np.array([0.6, 0.4])  # Source distribution
b = np.array([0.4, 0.6])  # Target distribution

# Define cost matrix: M[i, j] is the cost of moving mass from source i to target j
M = np.array([[0.5, 2.0],
              [1.0, 0.5]])

# Compute optimal transport plan using exact solver
plan = ot.emd(a, b, M)
print("Transport plan:", plan)

# Compute transport cost
cost = ot.emd2(a, b, M)
print("Transport cost:", cost)

# Compute using entropic regularization (Sinkhorn)
reg = 0.1
plan_sinkhorn = ot.sinkhorn(a, b, M, reg)
cost_sinkhorn = ot.sinkhorn2(a, b, M, reg)
print("Sinkhorn plan:", plan_sinkhorn)
print("Sinkhorn cost:", cost_sinkhorn)
```
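
To make the notions of "transport plan" and "transport cost" concrete, the following sketch brute-forces a tiny 2x2 problem in plain Python, without calling POT. The weights, the cost matrix, and the helper names `plan_from_t` and `transport_cost` are assumptions made for this illustration. It relies on the fact that a 2x2 plan with fixed marginals has a single free entry.

```python
# Illustrative 2x2 problem; the values are chosen for this sketch.
a = [0.6, 0.4]            # source weights (sum to 1)
b = [0.4, 0.6]            # target weights (sum to 1)
M = [[0.5, 2.0],
     [1.0, 0.5]]          # M[i][j]: cost of moving mass from source i to target j


def plan_from_t(t):
    """Return the 2x2 plan with marginals (a, b) whose top-left entry is t."""
    return [[t, a[0] - t],
            [b[0] - t, a[1] - b[0] + t]]


def transport_cost(plan):
    """Total cost: sum over i, j of plan[i][j] * M[i][j]."""
    return sum(plan[i][j] * M[i][j] for i in range(2) for j in range(2))


# Non-negativity of all four entries bounds t from both sides.
t_min = max(0.0, b[0] - a[1])
t_max = min(a[0], b[0])

# Scan a grid of feasible plans and keep the cheapest one.
grid = [t_min + (t_max - t_min) * k / 200 for k in range(201)]
best_plan = min((plan_from_t(t) for t in grid), key=transport_cost)
best_cost = transport_cost(best_plan)

print("Best plan on the grid:", best_plan)
print("Its cost:", best_cost)
```

Rows of the plan sum to `a` and columns sum to `b`; an exact solver searches the same feasible set, only far more efficiently than a grid scan.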

## Architecture

POT is organized into specialized modules covering different aspects of optimal transport:

- **Linear Programming** (`ot.lp`): Exact optimal transport solvers using network simplex and linear programming
- **Bregman Projections** (`ot.bregman`): Entropic regularization methods including Sinkhorn algorithms and variants
- **Gromov-Wasserstein** (`ot.gromov`): Structured optimal transport for comparing metric spaces
- **Unbalanced Transport** (`ot.unbalanced`): Methods for unbalanced optimal transport problems
- **Domain Adaptation** (`ot.da`): Transport-based methods for domain adaptation in machine learning
- **Backend System** (`ot.backend`): Multi-framework support (NumPy, PyTorch, JAX, TensorFlow, CuPy)

High-level functions are exposed directly in the main `ot` module, while specialized implementations live in the submodules, so you can choose the level of granularity your application needs.

## Capabilities

### Linear Programming Solvers

Exact optimal transport via the Earth Mover's Distance (EMD), solved with a network simplex algorithm, plus specialized 1D solvers and free-support barycenters.

```python { .api }
def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
    """
    Solve the Earth Mover's Distance problem and return optimal transport plan.

    Parameters:
    - a: array-like, source distribution (histogram)
    - b: array-like, target distribution (histogram)
    - M: array-like, cost matrix
    - numItermax: int, maximum number of iterations
    - log: bool, return optimization log
    - center_dual: bool, center dual potentials
    - numThreads: int, number of threads for parallel computation

    Returns:
    - transport plan matrix or (plan, log) if log=True
    """

def emd2(a, b, M, processes=1, numItermax=100000, log=False, return_matrix=False, center_dual=True, numThreads=1):
    """
    Solve EMD and return transport cost only.

    Returns:
    - transport cost (scalar) or (cost, log) if log=True
    """

def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1.0, dense=True, log=False):
    """
    Solve 1D optimal transport problem.

    Parameters:
    - x_a, x_b: array-like, sample positions
    - a, b: array-like, sample weights (uniform if None)
    - metric: str, cost metric ('sqeuclidean', 'euclidean', 'cityblock', 'minkowski')
    - p: float, exponent for Minkowski metric
    - dense: bool, return dense transport matrix
    - log: bool, return optimization log
    """

def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True):
    """
    Compute 1D Wasserstein distance between two distributions.

    Parameters:
    - u_values, v_values: array-like, sample positions
    - u_weights, v_weights: array-like, sample weights (uniform if None)
    - p: int, Wasserstein distance order
    - require_sort: bool, sort the inputs first (set to False only if they are already sorted)
    """
```

[Linear Programming Solvers](./linear-programming.md)
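
The reason 1D problems get specialized solvers is that, for convex costs such as |x - y|^p, the optimal matching in one dimension is simply the sorted one. The sketch below shows that idea in plain Python for the restricted case of equal sample counts and uniform weights. The function name `wasserstein_1d_uniform` is hypothetical, and this is not POT's implementation, which also handles arbitrary weights.

```python
def wasserstein_1d_uniform(xs, ys, p=1):
    """W_p between two equal-size 1D samples with uniform weights.

    Sorting both samples and pairing them in order gives the optimal
    matching in 1D, so no linear program is needed.
    """
    if len(xs) != len(ys):
        raise ValueError("this sketch assumes equal sample counts")
    xs_sorted, ys_sorted = sorted(xs), sorted(ys)
    mean_cost = sum(abs(x - y) ** p for x, y in zip(xs_sorted, ys_sorted)) / len(xs)
    return mean_cost ** (1.0 / p)


# Sorted pairs are (0, 4), (1, 5), (3, 6), so the mean displacement is 11 / 3.
print(wasserstein_1d_uniform([0.0, 1.0, 3.0], [5.0, 4.0, 6.0]))
```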

### Entropic Regularized Transport

Sinkhorn algorithm and variants for solving regularized optimal transport problems, including stabilized versions, epsilon-scaling, and specialized algorithms like Greenkhorn and Screenkhorn.

```python { .api }
def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs):
    """
    Solve entropic regularized optimal transport with Sinkhorn algorithm.

    Parameters:
    - a, b: array-like, source and target distributions
    - M: array-like, cost matrix
    - reg: float, regularization parameter
    - method: str, algorithm variant ('sinkhorn', 'sinkhorn_log', 'sinkhorn_stabilized',
      'sinkhorn_epsilon_scaling', 'greenkhorn', 'screenkhorn')
    - numItermax: int, maximum iterations
    - stopThr: float, convergence threshold
    - verbose: bool, print information
    - log: bool, return optimization log

    Returns:
    - transport plan matrix or (plan, log) if log=True
    """

def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, stopThr=1e-4, verbose=False, log=False, **kwargs):
    """
    Compute Wasserstein barycenter of distributions.

    Parameters:
    - A: array-like, input distributions (columns)
    - M: array-like, cost matrix
    - reg: float, regularization parameter
    - weights: array-like, barycenter weights
    - method: str, algorithm ('sinkhorn', 'sinkhorn_log', etc.)

    Returns:
    - barycenter distribution or (barycenter, log) if log=True
    """
```

[Entropic Regularized Transport](./entropic-transport.md)
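
For intuition about what a Sinkhorn solver does, here is a minimal plain-Python sketch of the basic scaling iteration: build the kernel `K = exp(-M / reg)`, then alternately rescale rows and columns until both marginals match. The function name `sinkhorn_plan`, the fixed iteration count, and the sample values are assumptions for illustration; POT's own solvers add stopping criteria and numerically stabilized variants.

```python
import math


def sinkhorn_plan(a, b, M, reg, num_iter=500):
    """Entropic OT plan via the basic Sinkhorn scaling iteration."""
    n, m = len(a), len(b)
    # Gibbs kernel: small reg makes the kernel sharper (closer to exact OT).
    K = [[math.exp(-M[i][j] / reg) for j in range(m)] for i in range(n)]
    u = [1.0] * n
    v = [1.0] * m
    for _ in range(num_iter):
        # Rescale rows so that the row sums of diag(u) K diag(v) equal a.
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        # Rescale columns so that the column sums equal b.
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    return [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]


a = [0.6, 0.4]
b = [0.4, 0.6]
M = [[0.5, 2.0],
     [1.0, 0.5]]

plan = sinkhorn_plan(a, b, M, reg=0.5)
cost = sum(plan[i][j] * M[i][j] for i in range(2) for j in range(2))
print("Regularized plan:", plan)
print("Transport cost of that plan:", cost)
```

The regularized plan is dense (every entry is positive), and its transport cost is never below the exact OT cost for the same marginals.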

### Gromov-Wasserstein Distances

Structured optimal transport for comparing metric spaces, including fused variants, barycenters, entropic regularization, and advanced methods like partial and semi-relaxed formulations.

```python { .api }
def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', armijo=False, log=False, **kwargs):
    """
    Compute the Gromov-Wasserstein transport plan between two metric spaces.

    Parameters:
    - C1, C2: array-like, cost matrices for source and target spaces
    - p, q: array-like, source and target distributions
    - loss_fun: str or function, loss function ('square_loss', 'kl_loss')
    - armijo: bool, use Armijo line search
    - log: bool, return optimization log

    Returns:
    - transport plan matrix or (plan, log) if log=True
    """

def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
    """
    Compute the Fused Gromov-Wasserstein transport plan combining structure and features.

    Parameters:
    - M: array-like, feature cost matrix
    - C1, C2: array-like, structure cost matrices
    - alpha: float, trade-off between the structure term (alpha) and the feature term (1 - alpha)
    - Additional parameters as in gromov_wasserstein

    Returns:
    - transport plan matrix or (plan, log) if log=True
    """

def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun='square_loss', max_iter=1000, tol=1e-9, verbose=False, log=False, **kwargs):
    """
    Compute Gromov-Wasserstein barycenter of metric spaces.

    Parameters:
    - N: int, size of barycenter space
    - Cs: list, cost matrices of input spaces
    - ps: list, distributions of input spaces
    - p: array-like, barycenter distribution
    - lambdas: array-like, barycenter weights
    - loss_fun: str or function, loss function

    Returns:
    - barycenter cost matrix or (barycenter, log) if log=True
    """
```

[Gromov-Wasserstein Transport](./gromov-wasserstein.md)
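
Unlike standard OT, Gromov-Wasserstein compares pairwise distances within each space rather than distances across spaces. The sketch below evaluates the square-loss objective for a given coupling in plain Python; the helper name `gw_square_loss` and the three-point example are assumptions for illustration. A solver minimizes this quantity over couplings; here it is only evaluated.

```python
def gw_square_loss(C1, C2, T):
    """Square-loss GW objective: sum of (C1[i][k] - C2[j][l])^2 * T[i][j] * T[k][l]."""
    n, m = len(C1), len(C2)
    total = 0.0
    for i in range(n):
        for j in range(m):
            for k in range(n):
                for l in range(m):
                    total += (C1[i][k] - C2[j][l]) ** 2 * T[i][j] * T[k][l]
    return total


# Pairwise distances of three points at positions 0, 1, 2 on a line.
C = [[0.0, 1.0, 2.0],
     [1.0, 0.0, 1.0],
     [2.0, 1.0, 0.0]]

identity = [[1 / 3 if i == j else 0.0 for j in range(3)] for i in range(3)]
flipped = [[1 / 3 if i + j == 2 else 0.0 for j in range(3)] for i in range(3)]
uniform = [[1 / 9] * 3 for _ in range(3)]

print(gw_square_loss(C, C, identity))  # matching each point to itself
print(gw_square_loss(C, C, flipped))   # matching through a reflection
print(gw_square_loss(C, C, uniform))   # spreading mass everywhere
```

Both the identity and the reflected coupling preserve all pairwise distances, so both give a zero objective. This is why Gromov-Wasserstein is invariant to isometries of either space.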

### Unbalanced Optimal Transport

Methods for optimal transport between measures with different total masses, supporting various divergences and regularization approaches.

```python { .api }
def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs):
    """
    Solve unbalanced optimal transport with KL relaxation.

    Parameters:
    - a, b: array-like, source and target distributions
    - M: array-like, cost matrix
    - reg: float, entropic regularization parameter
    - reg_m: float or tuple, marginal relaxation parameter(s)
    - method: str, algorithm variant
    - Additional parameters as in sinkhorn

    Returns:
    - transport plan matrix or (plan, log) if log=True
    """

def barycenter_unbalanced(A, M, reg, reg_m, weights=None, numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs):
    """
    Compute unbalanced Wasserstein barycenter.

    Parameters:
    - A: array-like, input distributions
    - M: array-like, cost matrix
    - reg: float, entropic regularization
    - reg_m: float, marginal relaxation
    - weights: array-like, barycenter weights

    Returns:
    - barycenter distribution or (barycenter, log) if log=True
    """
```

[Unbalanced Optimal Transport](./unbalanced-transport.md)
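
As a rough illustration of how KL relaxation changes the Sinkhorn iteration, the sketch below uses the standard scaling update in which each rescaling step is damped by the exponent `reg_m / (reg_m + reg)`. The function name `unbalanced_sinkhorn_plan`, the fixed iteration count, and the sample values are assumptions; POT's `sinkhorn_unbalanced` may differ in stabilization, stopping rules, and supported divergences.

```python
import math


def unbalanced_sinkhorn_plan(a, b, M, reg, reg_m, num_iter=1000):
    """Entropic unbalanced OT plan with KL-relaxed marginals (scaling iteration)."""
    n, m = len(a), len(b)
    K = [[math.exp(-M[i][j] / reg) for j in range(m)] for i in range(n)]
    # Exponent below 1 softens the marginal constraints; it tends to 1
    # (balanced Sinkhorn) as reg_m grows.
    power = reg_m / (reg_m + reg)
    u = [1.0] * n
    v = [1.0] * m
    for _ in range(num_iter):
        u = [(a[i] / sum(K[i][j] * v[j] for j in range(m))) ** power for i in range(n)]
        v = [(b[j] / sum(K[i][j] * u[i] for i in range(n))) ** power for j in range(m)]
    return [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]


a = [0.5, 0.5]          # total mass 1
b = [1.0, 1.0]          # total mass 2, so no balanced plan exists
M = [[0.0, 1.0],
     [1.0, 0.0]]

plan = unbalanced_sinkhorn_plan(a, b, M, reg=0.5, reg_m=1.0)
mass = sum(sum(row) for row in plan)
print("Plan:", plan)
print("Transported mass:", mass)
```

With a very large `reg_m` and inputs of equal mass, the same function behaves like balanced Sinkhorn and its marginals approach `a` and `b`.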

### Utility Functions and Tools

Essential utilities for optimal transport including distance computation, distribution generation, timing functions, and array operations.

```python { .api }
def dist(x1, x2=None, metric='sqeuclidean'):
    """
    Compute distance matrix between samples.

    Parameters:
    - x1, x2: array-like, input samples
    - metric: str, distance metric

    Returns:
    - distance matrix
    """

def unif(n, type_as=None):
    """
    Generate uniform distribution.

    Parameters:
    - n: int, distribution size
    - type_as: array-like, reference for array type

    Returns:
    - uniform distribution array
    """

def tic():
    """Start timer for performance measurement."""

def toc(message="Elapsed time : {} s"):
    """End timer and print elapsed time."""

def toq():
    """End timer and return elapsed time."""
```

[Utility Functions](./utilities.md)
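
Most solvers take a cost matrix and two weight vectors, so it helps to see what those inputs look like. The plain-Python sketch below mirrors the roles of `dist` (with the default squared Euclidean metric) and `unif`; the names `sq_euclidean_dist` and `uniform_weights` are hypothetical stand-ins, not POT functions.

```python
def sq_euclidean_dist(x1, x2=None):
    """Matrix of squared Euclidean distances; entry [i][j] compares x1[i] with x2[j]."""
    if x2 is None:
        x2 = x1
    return [[sum((p - q) ** 2 for p, q in zip(row1, row2)) for row2 in x2]
            for row1 in x1]


def uniform_weights(n):
    """Weights that give each of n samples the same mass 1 / n."""
    return [1.0 / n] * n


X_s = [[0.0, 0.0], [1.0, 0.0]]               # two source samples in 2D
X_t = [[0.0, 1.0], [3.0, 4.0], [1.0, 1.0]]   # three target samples in 2D

M = sq_euclidean_dist(X_s, X_t)   # shape (2, 3): one row per source sample
a = uniform_weights(len(X_s))
b = uniform_weights(len(X_t))
print(M, a, b)
```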

### Sliced Wasserstein Distances

Efficient approximation methods using random projections for high-dimensional optimal transport, including spherical variants and max-sliced approaches.

```python { .api }
def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, projections=None, seed=None, log=False):
    """
    Compute sliced Wasserstein distance between empirical distributions.

    Parameters:
    - X_s, X_t: array-like, source and target samples
    - a, b: array-like, sample weights
    - n_projections: int, number of random projections
    - p: int, Wasserstein distance order
    - projections: array-like, custom projection directions
    - seed: int, random seed
    - log: bool, return detailed results

    Returns:
    - sliced Wasserstein distance or (distance, log) if log=True
    """

def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, projections=None, seed=None, log=False):
    """
    Compute max-sliced Wasserstein distance, taking the maximum over the sampled projections.
    """
```

[Sliced Wasserstein Distances](./sliced-wasserstein.md)
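
The idea behind slicing can be sketched in a few lines: project both sample sets onto random directions, solve each resulting 1D problem by sorting, and average. The plain-Python version below assumes equal sample counts and uniform weights, and the function name `sliced_wasserstein` is hypothetical. It is a Monte Carlo illustration, not POT's implementation.

```python
import math
import random


def sliced_wasserstein(X_s, X_t, n_projections=50, p=2, seed=0):
    """Monte Carlo sliced W_p for equal-size, uniformly weighted samples."""
    if len(X_s) != len(X_t):
        raise ValueError("this sketch assumes equal sample counts")
    rng = random.Random(seed)
    dim = len(X_s[0])
    total = 0.0
    for _ in range(n_projections):
        # Normalizing a Gaussian vector gives a uniform direction on the sphere.
        direction = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        norm = math.sqrt(sum(d * d for d in direction))
        direction = [d / norm for d in direction]
        # Project onto the direction, then solve the 1D problem by sorting.
        proj_s = sorted(sum(x * d for x, d in zip(point, direction)) for point in X_s)
        proj_t = sorted(sum(x * d for x, d in zip(point, direction)) for point in X_t)
        total += sum(abs(s - t) ** p for s, t in zip(proj_s, proj_t)) / len(proj_s)
    return (total / n_projections) ** (1.0 / p)


X_s = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
X_t = [[x + 3.0, y + 4.0] for x, y in X_s]   # same shape, shifted by (3, 4)

print(sliced_wasserstein(X_s, X_s))  # identical sets
print(sliced_wasserstein(X_s, X_t))  # shifted copy
```

For a pure shift, each projection sees a displacement no larger than the length of the shift, so the sliced distance is bounded by that length (5 in this example).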

### Domain Adaptation

Transport-based methods for machine learning domain adaptation, including label-regularized transport and various transport classes for different adaptation scenarios.

```python { .api }
def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerItermax=200, stopInnerThr=1e-9, verbose=False, log=False):
    """
    Solve optimal transport with label regularization using MM algorithm.

    Parameters:
    - a: array-like, source distribution
    - labels_a: array-like, source labels
    - b: array-like, target distribution
    - M: array-like, cost matrix
    - reg: float, entropic regularization
    - eta: float, label regularization parameter
    - numItermax: int, outer iterations
    - numInnerItermax: int, inner iterations

    Returns:
    - transport plan matrix or (plan, log) if log=True
    """

class SinkhornTransport:
    """
    Sinkhorn transport class for domain adaptation.

    Parameters:
    - reg_e: float, entropic regularization
    - max_iter: int, maximum iterations
    - tol: float, convergence tolerance
    - verbose: bool, print information
    - log: bool, keep optimization log
    """

    def fit(self, Xs=None, Xt=None, ys=None, yt=None):
        """Fit transport from source to target."""

    def transform(self, Xs=None, Xt=None, ys=None, yt=None, batch_size=128):
        """Transform source samples to target domain."""
```

[Domain Adaptation](./domain-adaptation.md)
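
One common way to turn a transport plan into a mapping of samples is barycentric projection: each source sample moves to the weighted average of the target samples it sends mass to. The plain-Python sketch below illustrates that idea under the assumption that this is the mapping of interest. The helper name `barycentric_map` is hypothetical, and POT's transport classes may treat new, unseen samples differently.

```python
def barycentric_map(plan, X_t):
    """Map each source sample to the plan-weighted average of the target samples."""
    dim = len(X_t[0])
    mapped = []
    for row in plan:
        mass = sum(row)  # total mass sent by this source sample
        mapped.append([
            sum(row[j] * X_t[j][k] for j in range(len(X_t))) / mass
            for k in range(dim)
        ])
    return mapped


X_t = [[0.0, 0.0], [2.0, 4.0]]      # two target samples in 2D
plan = [[0.5, 0.0],                 # source 0 sends everything to target 0
        [0.25, 0.25]]               # source 1 splits evenly between both targets

print(barycentric_map(plan, X_t))
```

The first source sample lands exactly on target 0, while the second lands halfway between the two targets.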

### Partial Optimal Transport

Methods for optimal transport with relaxed mass constraints, allowing transport of only partial mass between distributions.

```python { .api }
def partial_wasserstein(a, b, M, m=None, numItermax=1000000, log=False, **kwargs):
    """
    Solve partial optimal transport problem.

    Parameters:
    - a, b: array-like, source and target distributions
    - M: array-like, cost matrix
    - m: float, amount of mass to transport (default: min(sum(a), sum(b)))
    - numItermax: int, maximum iterations
    - log: bool, return optimization log

    Returns:
    - transport plan matrix or (plan, log) if log=True
    """

def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, stopThr=1e-9, verbose=False, log=False):
    """
    Solve entropic regularized partial optimal transport.
    """
```

[Partial Optimal Transport](./partial-transport.md)

### Backend System

Multi-framework backend system enabling computation with NumPy, PyTorch, JAX, TensorFlow, and CuPy for flexible deployment and GPU acceleration.

```python { .api }
def get_backend(*args):
    """
    Get appropriate backend for input arrays.

    Parameters:
    - args: arrays to determine backend from

    Returns:
    - backend instance
    """

def to_numpy(*args):
    """
    Convert arrays to numpy format.

    Parameters:
    - args: arrays to convert

    Returns:
    - numpy arrays
    """

class Backend:
    """Base backend class defining array operations interface."""

class NumpyBackend(Backend):
    """NumPy backend implementation."""

class TorchBackend(Backend):
    """PyTorch backend implementation."""
```

[Backend System](./backend-system.md)

### Advanced Methods

Specialized algorithms including smooth optimal transport, stochastic solvers for large-scale problems, low-rank methods, and Gaussian optimal transport.

```python { .api }
def smooth_ot_dual(a, b, C, regul, method='L-BFGS-B', numItermax=500, log=False, **kwargs):
    """
    Solve smooth optimal transport using dual formulation.

    Parameters:
    - a, b: array-like, source and target distributions
    - C: array-like, cost matrix
    - regul: Regularization instance
    - method: str, optimization method
    - numItermax: int, maximum iterations
    - log: bool, return optimization log

    Returns:
    - optimal transport plan or (plan, log) if log=True
    """

def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=1e-3, rank=10, numItermax=100, stopThr=1e-5, log=False):
    """
    Solve optimal transport using low-rank Sinkhorn algorithm.

    Parameters:
    - X_s, X_t: array-like, source and target samples
    - a, b: array-like, sample weights
    - reg: float, regularization parameter
    - rank: int, rank constraint
    - numItermax: int, maximum iterations
    - stopThr: float, convergence threshold
    - log: bool, return optimization log

    Returns:
    - transport plan or (plan, log) if log=True
    """
```

[Advanced Methods](./advanced-methods.md)

### Smooth Optimal Transport

Smooth optimal transport with dual and semi-dual formulations supporting KL divergence, L2 regularization, and sparsity constraints for regularized transport solutions.

```python { .api }
def smooth_ot_dual(a, b, C, regul, method='L-BFGS-B', numItermax=500, log=False, **kwargs):
    """
    Solve smooth optimal transport using dual formulation.

    Parameters:
    - a, b: array-like, source and target distributions
    - C: array-like, cost matrix
    - regul: Regularization, regularization instance (NegEntropy, SquaredL2, SparsityConstrained)
    - method: str, optimization method
    - numItermax: int, maximum iterations
    - log: bool, return optimization log

    Returns:
    - transport plan matrix or (plan, log) if log=True
    """

def smooth_ot_semi_dual(a, b, C, regul, method='L-BFGS-B', numItermax=500, log=False, **kwargs):
    """
    Solve smooth optimal transport using semi-dual formulation.
    """
```

[Smooth Optimal Transport](./smooth-transport.md)

### Stochastic Solvers

Stochastic algorithms for large-scale optimal transport using SAG and SGD methods, enabling efficient computation for problems with many samples.

```python { .api }
def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None, random_state=None):
    """
    Solve entropic regularized OT with Stochastic Average Gradient algorithm.

    Parameters:
    - a, b: array-like, source and target distributions
    - M: array-like, cost matrix
    - reg: float, regularization parameter
    - numItermax: int, maximum iterations
    - lr: float, learning rate
    - random_state: int, random seed

    Returns:
    - transport plan matrix
    """

def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax=10000, lr=0.1, log=False):
    """
    Solve entropic regularized OT using SGD on dual formulation.
    """
```

[Stochastic Solvers](./stochastic-solvers.md)

### Regularization Path

Algorithms for computing optimal transport regularization paths, exploring the full range from unregularized to highly regularized solutions.

```python { .api }
def regularization_path(a, b, C, reg=1e-4, itmax=50000):
    """
    Compute regularization path for optimal transport.

    Parameters:
    - a, b: array-like, source and target distributions
    - C: array-like, cost matrix
    - reg: float, final regularization parameter
    - itmax: int, maximum iterations

    Returns:
    - gamma_list: list of regularization parameters
    - Pi_list: list of corresponding transport plans
    """

def fully_relaxed_path(a, b, C, reg=1e-4, itmax=50000):
    """
    Compute fully relaxed regularization path.
    """
```

[Regularization Path](./regularization-path.md)

### Unified Solvers

High-level unified interface providing automatic algorithm selection and a consistent API across different problem types and scales.

```python { .api }
def solve(M, a=None, b=None, reg=None, c=None, reg_type="KL", unbalanced=None, unbalanced_type="KL", method=None, n_threads=1, max_iter=None, plan_init=None, potentials_init=None, tol=None, verbose=False, grad="autodiff"):
    """
    General optimal transport solver with automatic method selection.

    Parameters:
    - M: array-like, cost matrix (first positional argument)
    - a, b: array-like, source and target distributions (uniform if None)
    - reg: float, regularization parameter (None solves the exact problem)
    - reg_type: str, regularization type (for example 'KL', 'entropy', 'L2')
    - unbalanced: float, marginal relaxation parameter (None keeps the problem balanced)
    - unbalanced_type: str, marginal relaxation type (for example 'KL', 'L2', 'TV')
    - method: str, optional solver method (None lets POT choose from the other arguments)
    - max_iter: int, maximum iterations
    - tol: float, convergence threshold

    Returns:
    - result object exposing the transport plan (`res.plan`) and the cost (`res.value`)
    """

def solve_gromov(C1, C2, p=None, q=None, M=None, alpha=0.0, reg=None, method='auto', **kwargs):
    """
    General Gromov-Wasserstein solver with automatic method selection.
    """
```

[Unified Solvers](./unified-solvers.md)

### Weak Optimal Transport

Weak optimal transport minimizing displacement variance rather than total cost, preserving local structure for shape matching applications.

```python { .api }
def weak_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, G0=None, **kwargs):
    """
    Solve weak optimal transport problem between empirical distributions.

    Parameters:
    - Xa, Xb: array-like, source and target samples
    - a, b: array-like, source and target distributions
    - verbose: bool, print optimization information
    - log: bool, return optimization log
    - G0: array-like, initial transport plan

    Returns:
    - transport plan matrix or (plan, log) if log=True
    """
```

[Weak Optimal Transport](./weak-transport.md)

### Factored Transport

Factored optimal transport exploiting structure for efficient large-scale computation using low-rank decompositions.

```python { .api }
def factored_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, **kwargs):
    """
    Solve optimal transport using factored decomposition.

    Parameters:
    - Xa, Xb: array-like, source and target samples
    - a, b: array-like, distributions
    - verbose: bool, print information
    - log: bool, return optimization log

    Returns:
    - transport plan matrix or (plan, log) if log=True
    """
```

[Factored Transport](./factored-transport.md)

## Types

```python { .api }
# Common array types accepted by POT functions
ArrayLike = Union[numpy.ndarray, List, Tuple]

# Backend-specific array types
BackendArray = Union[numpy.ndarray, torch.Tensor, jax.numpy.ndarray, tensorflow.Tensor, cupy.ndarray]

# Log dictionary returned by functions with log=True
LogDict = Dict[str, Union[float, int, List, numpy.ndarray]]

# Transport plan matrix type
TransportPlan = numpy.ndarray  # Shape: (n_samples_source, n_samples_target)

# Cost matrix type
CostMatrix = numpy.ndarray  # Shape: (n_samples_source, n_samples_target)

# Distribution vector type
Distribution = numpy.ndarray  # Shape: (n_samples,), non-negative, typically sums to 1
```