# Entropic Regularized Transport

The `ot.bregman` module provides algorithms for solving entropic regularized optimal transport problems using Bregman projections. The Sinkhorn algorithm and its variants are the core methods: they replace the exact linear program with cheap matrix-scaling iterations, at the price of an approximation error controlled by the regularization parameter `reg`.

## Core Sinkhorn Algorithms

### Standard Sinkhorn Algorithm

```python { .api }
def ot.bregman.sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, warn=True, warmstart=None, **kwargs):
    """
    Solve entropic regularized optimal transport using the Sinkhorn algorithm.

    The Sinkhorn algorithm iteratively projects onto the marginal constraints to
    find the optimal transport plan for the entropy-regularized problem:
        min_P <P, M> + reg * sum_ij P_ij * log(P_ij)   subject to P 1 = a, P^T 1 = b
    Up to an additive constant, this objective equals reg * KL(P | K) with the
    Gibbs kernel K = exp(-M / reg).

    Parameters:
    - a: array-like, shape (n_samples_source,)
        Source distribution (histogram). Must be positive and sum to 1.
    - b: array-like, shape (n_samples_target,)
        Target distribution (histogram). Must be positive and sum to 1.
    - M: array-like, shape (n_samples_source, n_samples_target)
        Ground cost matrix.
    - reg: float
        Regularization parameter (>0). Lower values give solutions closer to EMD.
    - method: str, default='sinkhorn'
        Algorithm variant. Options: 'sinkhorn', 'sinkhorn_log', 'sinkhorn_stabilized',
        'sinkhorn_epsilon_scaling', 'greenkhorn'
    - numItermax: int, default=1000
        Maximum number of iterations.
    - stopThr: float, default=1e-9
        Convergence threshold on marginal difference.
    - verbose: bool, default=False
        Print iteration information.
    - log: bool, default=False
        Return optimization log with convergence details.
    - warn: bool, default=True
        Warn if the algorithm does not converge.
    - warmstart: tuple, default=None
        Tuple (u, v) of dual variables for warm start initialization.

    Returns:
    - transport_plan: ndarray, shape (n_samples_source, n_samples_target)
        Entropic optimal transport plan.
    - log: dict (if log=True)
        Contains 'err': convergence errors, 'niter': iterations used,
        'u': source scaling, 'v': target scaling.
    """

def ot.bregman.sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, warn=True, **kwargs):
    """
    Solve entropic regularized OT and return transport cost only.

    Use instead of sinkhorn() when only the loss value, not the plan, is needed.

    Parameters: Same as sinkhorn()

    Returns:
    - cost: float
        Entropic regularized transport cost.
    - log: dict (if log=True)
    """

def ot.bregman.sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs):
    """
    Sinkhorn-Knopp algorithm for entropic optimal transport.

    Classic formulation using multiplicative updates with diagonal scaling matrices.

    Parameters: Similar to sinkhorn()

    Returns:
    - transport_plan: ndarray
    - log: dict (if log=True)
    """
```
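
The multiplicative updates described above fit in a few lines. What follows is a minimal, dependency-free sketch on nested lists, not the library's implementation; the name `sinkhorn_plan` and its `num_iter` and `tol` arguments are illustrative assumptions.

```python
import math

def sinkhorn_plan(a, b, M, reg, num_iter=1000, tol=1e-9):
    """Plain Sinkhorn-Knopp scaling on nested lists (illustrative only)."""
    n, m = len(a), len(b)
    # Gibbs kernel K = exp(-M / reg)
    K = [[math.exp(-M[i][j] / reg) for j in range(m)] for i in range(n)]
    u = [1.0] * n
    v = [1.0] * m
    for _ in range(num_iter):
        # Column scaling: v = b / (K^T u)
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
        # Row scaling: u = a / (K v)
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        # Rows now match a; stop once the column marginal also matches b
        col = [v[j] * sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
        if max(abs(col[j] - b[j]) for j in range(m)) < tol:
            break
    # Transport plan P = diag(u) K diag(v)
    return [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]

a = [0.5, 0.5]
b = [0.3, 0.7]
M = [[0.0, 1.0], [1.0, 0.0]]
P = sinkhorn_plan(a, b, M, reg=0.1)
row_sums = [sum(row) for row in P]
col_sums = [sum(P[i][j] for i in range(2)) for j in range(2)]
cost = sum(P[i][j] * M[i][j] for i in range(2) for j in range(2))
```

On this 2x2 problem `reg` is small relative to the costs, so the transport cost `<P, M>` lands very close to 0.2, the value of the unregularized optimum.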

### Advanced Sinkhorn Variants

```python { .api }
def ot.bregman.sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs):
    """
    Sinkhorn algorithm in log-domain for numerical stability.

    Performs computations in log space to avoid numerical overflow/underflow
    issues when the regularization parameter is small or the cost matrix has
    large values.

    Parameters: Same as sinkhorn()

    Returns:
    - transport_plan: ndarray
    - log: dict (if log=True)
    """

def ot.bregman.sinkhorn_stabilized(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, log=False, tau=1e3, **kwargs):
    """
    Stabilized Sinkhorn algorithm with absorption technique.

    Uses tau-absorption to prevent numerical overflow while maintaining
    precision. Automatically switches between normal and log computations.

    Parameters:
    - tau: float, default=1e3
        Absorption threshold. When scaling factors exceed tau, the algorithm
        absorbs them into the dual variables.
    - Other parameters same as sinkhorn()

    Returns:
    - transport_plan: ndarray
    - log: dict (if log=True)
    """

def ot.bregman.sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInnerItermax=100, stopThr=1e-9, verbose=False, log=False, **kwargs):
    """
    Epsilon-scaling Sinkhorn for better convergence with small regularization.

    Starts with a large regularization parameter and progressively decreases it
    to the target value, warm-starting each scale from the previous one.

    Parameters:
    - epsilon0: float, default=1e4
        Initial (large) regularization parameter.
    - numInnerItermax: int, default=100
        Maximum iterations per epsilon scale.
    - Other parameters same as sinkhorn()

    Returns:
    - transport_plan: ndarray
    - log: dict (if log=True)
    """

def ot.bregman.greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=False):
    """
    Greenkhorn algorithm: greedy variant of Sinkhorn.

    Coordinate-wise variant of Sinkhorn that updates one row or column at a
    time, chosen greedily as the one whose marginal is most violated. Each
    step is cheap, so many more iterations are needed than with sinkhorn().

    Parameters: Same as sinkhorn() with typically larger numItermax

    Returns:
    - transport_plan: ndarray
    - log: dict (if log=True)
    """

def ot.bregman.screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, restricted=True, maxiter=10000, maxfun=10000, pgtol=1e-09, verbose=False, log=False):
    """
    Screenkhorn algorithm for large-scale optimal transport.

    Uses screening techniques to identify and ignore negligible entries in the
    transport matrix, significantly reducing computational cost for large problems.

    Parameters:
    - ns_budget: int, optional
        Maximum number of active source samples.
    - nt_budget: int, optional
        Maximum number of active target samples.
    - uniform: bool, default=False
        Use uniform sampling for screening.
    - restricted: bool, default=True
        Use restricted Sinkhorn on screened samples.
    - maxiter: int, default=10000
    - maxfun: int, default=10000
        Maximum function evaluations.
    - pgtol: float, default=1e-09
        Projected gradient tolerance.

    Returns:
    - transport_plan: ndarray
    - log: dict (if log=True)
    """
```
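
To illustrate why a log-domain formulation helps, here is a hedged, dependency-free sketch that iterates on dual potentials `f` and `g` with a log-sum-exp instead of on the scalings `u` and `v`. The helper names (`logsumexp`, `sinkhorn_log_plan`) and the fixed iteration count are assumptions of this sketch, not the library's code.

```python
import math

def logsumexp(values):
    """Numerically stable log(sum(exp(values)))."""
    top = max(values)
    return top + math.log(sum(math.exp(x - top) for x in values))

def sinkhorn_log_plan(a, b, M, reg, num_iter=1000):
    """Sinkhorn updates on dual potentials f, g (illustrative only)."""
    n, m = len(a), len(b)
    log_a = [math.log(x) for x in a]
    log_b = [math.log(x) for x in b]
    f = [0.0] * n
    g = [0.0] * m
    for _ in range(num_iter):
        # g_j = reg * (log b_j - logsumexp_i((f_i - M_ij) / reg))
        g = [reg * (log_b[j] - logsumexp([(f[i] - M[i][j]) / reg for i in range(n)]))
             for j in range(m)]
        # f_i = reg * (log a_i - logsumexp_j((g_j - M_ij) / reg))
        f = [reg * (log_a[i] - logsumexp([(g[j] - M[i][j]) / reg for j in range(m)]))
             for i in range(n)]
    # P_ij = exp((f_i + g_j - M_ij) / reg)
    return [[math.exp((f[i] + g[j] - M[i][j]) / reg) for j in range(m)]
            for i in range(n)]

a = [0.5, 0.5]
b = [0.3, 0.7]
M = [[0.0, 1.0], [1.0, 0.0]]
reg = 1e-3

# With this reg the off-diagonal kernel entry exp(-1 / reg) underflows to zero,
# so the plain scaling form could never move mass between the two bins.
kernel_entry = math.exp(-1.0 / reg)

# Small reg also means slow convergence, hence the larger iteration budget.
P = sinkhorn_log_plan(a, b, M, reg, num_iter=3000)
col_sums = [sum(P[i][j] for i in range(2)) for j in range(2)]
```

The iteration budget is raised because the potentials move by an amount proportional to `reg` per step; this is the slow convergence that epsilon scaling is meant to mitigate.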

## Barycenter Algorithms

### Standard Barycenters

```python { .api }
def ot.bregman.barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, stopThr=1e-4, verbose=False, log=False, **kwargs):
    """
    Compute Wasserstein barycenter using entropic regularization.

    Finds the distribution that minimizes the weighted sum of
    entropic-regularized transport costs to all input distributions.

    Parameters:
    - A: array-like, shape (n_samples, n_distributions)
        Input distributions as columns of matrix A.
    - M: array-like, shape (n_samples, n_samples)
        Ground cost matrix on barycenter support.
    - reg: float
        Entropic regularization parameter.
    - weights: array-like, shape (n_distributions,), optional
        Weights for barycenter combination. Default is uniform.
    - method: str, default="sinkhorn"
        Algorithm to use for transport computation.
    - numItermax: int, default=10000
        Maximum iterations for barycenter computation.
    - stopThr: float, default=1e-4
        Convergence threshold.
    - verbose: bool, default=False
    - log: bool, default=False

    Returns:
    - barycenter: ndarray, shape (n_samples,)
        Wasserstein barycenter distribution.
    - log: dict (if log=True)
        Contains convergence information and transport plans.
    """

def ot.bregman.barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False):
    """
    Compute barycenter using the Sinkhorn algorithm (the solver selected by
    barycenter() with method="sinkhorn").

    Parameters: Same as barycenter()

    Returns:
    - barycenter: ndarray
    - log: dict (if log=True)
    """

def ot.bregman.barycenter_stabilized(A, M, reg, tau=1e3, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False):
    """
    Compute barycenter using stabilized Sinkhorn algorithm.

    Parameters:
    - tau: float, default=1e3
        Stabilization parameter.
    - Other parameters same as barycenter()

    Returns:
    - barycenter: ndarray
    - log: dict (if log=True)
    """

def ot.bregman.barycenter_debiased(A, M, reg, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False):
    """
    Compute debiased Wasserstein barycenter.

    Applies debiasing correction to reduce bias introduced by entropic regularization.

    Parameters: Same as barycenter()

    Returns:
    - barycenter: ndarray
    - log: dict (if log=True)
    """
```
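
A fixed-support entropic barycenter can be sketched with iterative Bregman projections: each input gets its own scaling vectors, and the barycenter is the weighted geometric mean of the plans' shared-side marginals. The code below is a dependency-free illustration under that assumption. It takes the distributions as a list of rows, unlike the column layout of `A` above, and the name `entropic_barycenter` is hypothetical.

```python
import math

def entropic_barycenter(dists, M, reg, weights=None, num_iter=200):
    """Iterative Bregman projections on a fixed support (illustrative only)."""
    n = len(M)
    k = len(dists)
    w = weights if weights is not None else [1.0 / k] * k
    K = [[math.exp(-M[i][j] / reg) for j in range(n)] for i in range(n)]
    u = [[1.0] * n for _ in range(k)]
    bary = [1.0 / n] * n
    for _ in range(num_iter):
        Kv = []
        for s in range(k):
            # Match the s-th input marginal: v = p_s / (K^T u_s)
            v = [dists[s][j] / sum(K[i][j] * u[s][i] for i in range(n))
                 for j in range(n)]
            Kv.append([sum(K[i][j] * v[j] for j in range(n)) for i in range(n)])
        # Weighted geometric mean of the current barycenter-side marginals
        bary = [math.exp(sum(w[s] * math.log(u[s][i] * Kv[s][i]) for s in range(k)))
                for i in range(n)]
        # Rescale so that every plan has `bary` as its shared marginal
        u = [[bary[i] / Kv[s][i] for i in range(n)] for s in range(k)]
    return bary

grid = [0.0, 0.25, 0.5, 0.75, 1.0]
M = [[(x - y) ** 2 for y in grid] for x in grid]
p = [0.4, 0.3, 0.15, 0.1, 0.05]
# Two mirrored inputs with equal weights
bary = entropic_barycenter([p, p[::-1]], M, reg=0.1)
```

Because the two inputs are mirror images and the weights are equal, the result is symmetric about the middle of the grid, which makes a convenient sanity check.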

### Free Support Barycenters

```python { .api }
def ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=False, **kwargs):
    """
    Compute free-support Wasserstein barycenter using Sinkhorn algorithm.

    Optimizes the support locations of the barycenter, unlike fixed-support
    methods that only optimize the weights on a given grid. The barycenter
    weights b are kept fixed.

    Parameters:
    - measures_locations: list of arrays
        Support points for each input measure.
    - measures_weights: list of arrays
        Weights for each input measure.
    - X_init: array-like, shape (k, d)
        Initial barycenter support points.
    - reg: float
        Entropic regularization parameter.
    - b: array-like, shape (k,), optional
        Barycenter weights (uniform if None).
    - weights: array-like, optional
        Weights for combining input measures.
    - numItermax: int, default=100
    - stopThr: float, default=1e-7
    - verbose: bool, default=False
    - log: bool, default=False

    Returns:
    - X: ndarray, shape (k, d)
        Optimal barycenter support points.
    - log: dict (if log=True)
    """

def ot.bregman.jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, stopThr=1e-6, verbose=False, log=False, **kwargs):
    """
    Joint Class Proportion and Optimal Transport (JCPOT) barycenter.

    Jointly estimates the class proportions of an unlabeled target domain and
    the transport plans from several labeled source domains. Used in
    multi-source domain adaptation under target shift.

    Parameters:
    - Xs: list of arrays
        Samples of each source domain.
    - Ys: list of arrays
        Class labels of the samples of each source domain.
    - Xt: array-like
        Samples of the target domain.
    - reg: float
        Entropic regularization.
    - metric: str, default='sqeuclidean'
        Ground metric for cost computation.
    - numItermax: int, default=100
    - stopThr: float, default=1e-6
    - verbose: bool, default=False
    - log: bool, default=False

    Returns:
    - h: ndarray, shape (n_classes,)
        Estimated class proportions in the target domain.
    - log: dict (if log=True)
    """
```
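
For the free-support case, one common form of the location update moves each barycenter point to the weighted average of the input points it is coupled with (a barycentric projection). The sketch below shows a single such update for given transport plans. It assumes the plans were computed beforehand (for example with a Sinkhorn solver); `update_locations` and its arguments are hypothetical names, not part of the API above.

```python
def update_locations(plans, locations, b, weights):
    """One location update: X_i = sum_k w_k * (1 / b_i) * sum_j P_k[i][j] * Y_k[j]."""
    k = len(b)                      # number of barycenter points
    d = len(locations[0][0])        # dimension of the ambient space
    X = [[0.0] * d for _ in range(k)]
    for P, Y, w in zip(plans, locations, weights):
        for i in range(k):
            for j, y in enumerate(Y):
                for c in range(d):
                    X[i][c] += w * P[i][j] * y[c] / b[i]
    return X

# Two 1-D input measures, each coupled one-to-one with the two barycenter points
b = [0.5, 0.5]
plans = [[[0.5, 0.0], [0.0, 0.5]],
         [[0.5, 0.0], [0.0, 0.5]]]
locations = [[[0.0], [2.0]],
             [[2.0], [4.0]]]
X_new = update_locations(plans, locations, b, weights=[0.5, 0.5])
```

With these one-to-one plans, each barycenter point moves to the midpoint of its two partners, here 1.0 and 3.0.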

## Convolutional Barycenters

```python { .api }
def ot.bregman.convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, stabThr=1e-30, verbose=False, log=False):
    """
    Compute 2D convolutional Wasserstein barycenter.

    Specialized algorithm for 2D images on a regular grid. The Gibbs kernel of
    the squared Euclidean cost is separable, so it is applied as 1D
    convolutions along each image axis instead of as a dense matrix.

    Parameters:
    - A: array-like, shape (n_images, h, w)
        Stack of 2D images/distributions.
    - reg: float
        Entropic regularization parameter.
    - weights: array-like, shape (n_images,), optional
        Barycenter weights.
    - numItermax: int, default=10000
    - stopThr: float, default=1e-9
    - stabThr: float, default=1e-30
        Numerical stability threshold.
    - verbose: bool, default=False
    - log: bool, default=False

    Returns:
    - barycenter: ndarray, shape (h, w)
        2D convolutional Wasserstein barycenter.
    - log: dict (if log=True)
    """

def ot.bregman.convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000, stopThr=1e-9, stabThr=1e-30, verbose=False, log=False):
    """
    Compute debiased 2D convolutional Wasserstein barycenter.

    Applies debiasing to reduce regularization bias in convolutional barycenters.

    Parameters: Same as convolutional_barycenter2d()

    Returns:
    - barycenter: ndarray, shape (h, w)
    - log: dict (if log=True)
    """
```
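
The efficiency of the convolutional approach rests on the fact that a Gaussian-type kernel on a grid factorizes across the two image axes. The following dependency-free sketch checks this on a tiny image by comparing two 1D passes against the dense kernel. The pixel-index grid and the helper names are assumptions of the sketch; a library implementation may normalize the grid differently.

```python
import math

def kernel_1d(size, reg):
    """1-D kernel exp(-(i - j)^2 / reg) on pixel indices 0..size-1."""
    return [[math.exp(-((i - j) ** 2) / reg) for j in range(size)]
            for i in range(size)]

def apply_separable(img, reg):
    """Apply the 2-D kernel as two 1-D passes: K_h @ img @ K_w^T."""
    h, w = len(img), len(img[0])
    Kh, Kw = kernel_1d(h, reg), kernel_1d(w, reg)
    # Pass 1: mix rows (vertical direction)
    tmp = [[sum(Kh[i][k] * img[k][j] for k in range(h)) for j in range(w)]
           for i in range(h)]
    # Pass 2: mix columns (horizontal direction)
    return [[sum(tmp[i][l] * Kw[j][l] for l in range(w)) for j in range(w)]
            for i in range(h)]

def apply_dense(img, reg):
    """Reference: full kernel over all pixel pairs, squared Euclidean distance."""
    h, w = len(img), len(img[0])
    return [[sum(math.exp(-((i - k) ** 2 + (j - l) ** 2) / reg) * img[k][l]
                 for k in range(h) for l in range(w))
             for j in range(w)] for i in range(h)]

img = [[0.1, 0.0, 0.2, 0.0],
       [0.0, 0.3, 0.0, 0.1],
       [0.2, 0.0, 0.1, 0.0]]
fast = apply_separable(img, reg=2.0)
slow = apply_dense(img, reg=2.0)
max_diff = max(abs(fast[i][j] - slow[i][j]) for i in range(3) for j in range(4))
```

For an h x w image the dense kernel touches (h*w)^2 pixel pairs, while the two passes cost on the order of h*w*(h + w) operations.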

## Empirical Methods

```python { .api }
def ot.bregman.empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numItermax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
    """
    Compute Sinkhorn transport between empirical distributions.

    Convenient wrapper that computes cost matrix from sample coordinates
    and applies Sinkhorn algorithm.

    Parameters:
    - X_s: array-like, shape (n_samples_source, n_features)
        Source samples.
    - X_t: array-like, shape (n_samples_target, n_features)
        Target samples.
    - reg: float
        Entropic regularization parameter.
    - a: array-like, shape (n_samples_source,), optional
        Source sample weights. Default is uniform.
    - b: array-like, shape (n_samples_target,), optional
        Target sample weights. Default is uniform.
    - metric: str, default='sqeuclidean'
        Ground metric for cost matrix computation.
    - numItermax: int, default=10000
    - stopThr: float, default=1e-9
    - verbose: bool, default=False
    - log: bool, default=False

    Returns:
    - transport_plan: ndarray, shape (n_samples_source, n_samples_target)
    - log: dict (if log=True)
    """

def ot.bregman.empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numItermax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
    """
    Compute empirical Sinkhorn transport cost only.

    Parameters: Same as empirical_sinkhorn()

    Returns:
    - cost: float
        Empirical Sinkhorn transport cost.
    - log: dict (if log=True)
    """

def ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numItermax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
    """
    Compute empirical Sinkhorn divergence (debiased).

    Computes the Sinkhorn divergence: W_reg(a,b) - 0.5*W_reg(a,a) - 0.5*W_reg(b,b)
    which removes the regularization bias for better approximation of Wasserstein distance.

    Parameters: Same as empirical_sinkhorn()

    Returns:
    - divergence: float
        Sinkhorn divergence value.
    - log: dict (if log=True)
    """

def ot.bregman.empirical_sinkhorn2_geomloss(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numItermax=10000, stopThr=1e-9, verbose=False, log=False):
    """
    GeomLoss-compatible implementation of empirical Sinkhorn.

    Parameters: Same as empirical_sinkhorn()

    Returns:
    - cost: float
    - log: dict (if log=True)
    """

def ot.bregman.geomloss(X_s, X_t, a=None, b=None, loss='sinkhorn', p=2, blur=0.1, reach=1.0, diameter=1.0, scaling=0.5, truncate=5, cost=None, kernel=None, cluster_scale=None, debias=True, potentials=False, verbose=False, backend='auto'):
    """
    GeomLoss wrapper for various optimal transport losses.

    Unified interface supporting Wasserstein, energy, Hausdorff and Sinkhorn losses
    with automatic differentiation support.

    Parameters:
    - X_s: array-like, shape (n_s, d)
        Source samples.
    - X_t: array-like, shape (n_t, d)
        Target samples.
    - a: array-like, shape (n_s,), optional
        Source weights.
    - b: array-like, shape (n_t,), optional
        Target weights.
    - loss: str, default='sinkhorn'
        Loss type: 'wasserstein', 'sinkhorn', 'energy', 'hausdorff'
    - p: int, default=2
        Ground metric exponent.
    - blur: float, default=0.1
        Regularization/smoothing parameter.
    - reach: float, default=1.0
        Kernel reach parameter.
    - diameter: float, default=1.0
        Point cloud diameter estimate.
    - scaling: float, default=0.5
        Multi-scale algorithm parameter.
    - truncate: int, default=5
        Kernel truncation parameter.
    - cost: callable, optional
        Custom cost function.
    - kernel: callable, optional
        Custom kernel function.
    - cluster_scale: float, optional
        Clustering scale for acceleration.
    - debias: bool, default=True
        Apply debiasing for Sinkhorn loss.
    - potentials: bool, default=False
        Return dual potentials.
    - verbose: bool, default=False
    - backend: str, default='auto'
        Computation backend.

    Returns:
    - loss_value: float
        Computed loss value.
    - potentials: tuple (if potentials=True)
        Dual potentials (f, g).
    """
```
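
The divergence formula in the docstring above can be spelled out on 1D samples with uniform weights. This is a hedged, dependency-free sketch: it assumes that `W_reg` denotes the transport cost `<P, M>` of the entropic plan under a squared Euclidean cost, and the names `sinkhorn_cost` and `sinkhorn_divergence` are illustrative only.

```python
import math

def sinkhorn_cost(xs, xt, reg, num_iter=500):
    """Transport cost <P, M> of the entropic plan between uniform 1-D samples."""
    n, m = len(xs), len(xt)
    M = [[(x - y) ** 2 for y in xt] for x in xs]   # squared Euclidean cost
    K = [[math.exp(-c / reg) for c in row] for row in M]
    a, b = [1.0 / n] * n, [1.0 / m] * m
    u = [1.0] * n
    for _ in range(num_iter):
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
    return sum(u[i] * K[i][j] * v[j] * M[i][j]
               for i in range(n) for j in range(m))

def sinkhorn_divergence(xs, xt, reg):
    """W_reg(s, t) - 0.5 * W_reg(s, s) - 0.5 * W_reg(t, t)."""
    return (sinkhorn_cost(xs, xt, reg)
            - 0.5 * sinkhorn_cost(xs, xs, reg)
            - 0.5 * sinkhorn_cost(xt, xt, reg))

xs = [0.0, 1.0]
xt = [3.0, 4.0]
div_same = sinkhorn_divergence(xs, xs, reg=1.0)   # identical inputs
div_diff = sinkhorn_divergence(xs, xt, reg=1.0)   # well-separated inputs
```

The two self-transport terms cancel exactly for identical inputs, whereas the raw entropic cost `sinkhorn_cost(xs, xs, reg)` is strictly positive; that gap is the bias the divergence removes.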

## Utility Functions

```python { .api }
def ot.bregman.geometricBar(weights, alldistribT):
    """
    Compute geometric barycenter in Bregman divergence sense.

    Parameters:
    - weights: array-like
        Barycenter combination weights.
    - alldistribT: array-like
        Matrix of input distributions (columns).

    Returns:
    - barycenter: ndarray
        Geometric barycenter.
    """

def ot.bregman.geometricMean(alldistribT):
    """
    Compute geometric mean of distributions.

    Parameters:
    - alldistribT: array-like
        Matrix of distributions (columns).

    Returns:
    - geometric_mean: ndarray
    """

def ot.bregman.projR(gamma, p):
    """
    Project transport matrix onto row constraints.

    Parameters:
    - gamma: array-like
        Transport matrix.
    - p: array-like
        Row marginal constraints.

    Returns:
    - projected_gamma: ndarray
    """

def ot.bregman.projC(gamma, q):
    """
    Project transport matrix onto column constraints.

    Parameters:
    - gamma: array-like
        Transport matrix.
    - q: array-like
        Column marginal constraints.

    Returns:
    - projected_gamma: ndarray
    """

def ot.bregman.unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, stopThr=1e-3, verbose=False, log=False):
    """
    Solve optimal transport unmixing problem with regularization.

    Decompose a distribution as a convex combination of dictionary atoms
    using optimal transport as the fidelity term.

    Parameters:
    - a: array-like
        Distribution to unmix.
    - D: array-like
        Dictionary of atoms (columns).
    - M: array-like
        Cost matrix for transport.
    - M0: array-like
        Cost matrix for dictionary regularization.
    - h0: array-like
        Prior on dictionary coefficients.
    - reg: float
        Transport regularization.
    - reg0: float
        Dictionary regularization.
    - alpha: float
        Fidelity vs regularization trade-off.
    - numItermax: int, default=1000
    - stopThr: float, default=1e-3
    - verbose: bool, default=False
    - log: bool, default=False

    Returns:
    - h: ndarray
        Dictionary coefficients.
    - log: dict (if log=True)
    """
```
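
The row and column projections are the building blocks of Sinkhorn: the KL projection onto a marginal constraint rescales each row (or column) to the prescribed sum, and alternating the two projections from the Gibbs kernel yields the entropic plan. Below is a dependency-free sketch of that idea; `proj_rows` and `proj_cols` are hypothetical stand-ins written for illustration and omit the safeguards against zero sums that a library routine would need.

```python
import math

def proj_rows(gamma, p):
    """Rescale each row of gamma so that the row sums equal p."""
    return [[x * p[i] / sum(row) for x in row] for i, row in enumerate(gamma)]

def proj_cols(gamma, q):
    """Rescale each column of gamma so that the column sums equal q."""
    col_sums = [sum(row[j] for row in gamma) for j in range(len(q))]
    return [[row[j] * q[j] / col_sums[j] for j in range(len(q))] for row in gamma]

a = [0.5, 0.5]
b = [0.3, 0.7]
M = [[0.0, 1.0], [1.0, 0.0]]
reg = 0.5

# Start from the Gibbs kernel and alternate the two projections
gamma = [[math.exp(-c / reg) for c in row] for row in M]
for _ in range(200):
    gamma = proj_cols(gamma, b)
    gamma = proj_rows(gamma, a)

row_sums = [sum(row) for row in gamma]
col_sums = [sum(row[j] for row in gamma) for j in range(2)]
```

After the loop both marginals are satisfied to numerical precision, which is the fixed point the scaling form of Sinkhorn reaches without ever forming the intermediate matrices.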

## Usage Examples

### Basic Sinkhorn Algorithm
```python
import ot
import numpy as np

# Define distributions
a = np.array([0.5, 0.5])
b = np.array([0.3, 0.7])

# Cost matrix
M = np.array([[0.0, 1.0],
              [1.0, 0.0]])

# Regularization parameter
reg = 0.1

# Compute regularized transport
plan_sinkhorn = ot.bregman.sinkhorn(a, b, M, reg)
cost_sinkhorn = ot.bregman.sinkhorn2(a, b, M, reg)

print("Sinkhorn plan:", plan_sinkhorn)
print("Sinkhorn cost:", cost_sinkhorn)
```

### Barycenter Computation
```python
import ot
import numpy as np

# Multiple distributions: one per column, defined on 2 bins
A = np.array([[0.6, 0.2, 0.4],
              [0.4, 0.8, 0.6]])  # 3 distributions

# Cost matrix (squared Euclidean distance between bin positions)
M = ot.dist(np.arange(2, dtype=float).reshape(-1, 1))

# Regularization
reg = 0.05

# Compute barycenter
barycenter = ot.bregman.barycenter(A, M, reg)
print("Barycenter:", barycenter)

# With custom weights
weights = np.array([0.5, 0.3, 0.2])
weighted_barycenter = ot.bregman.barycenter(A, M, reg, weights=weights)
print("Weighted barycenter:", weighted_barycenter)
```

### Empirical Sinkhorn
```python
import ot
import numpy as np

# Generate sample data
np.random.seed(42)
X_s = np.random.randn(100, 2)
X_t = np.random.randn(80, 2) + 1

# Regularization
reg = 0.1

# Compute empirical transport
plan = ot.bregman.empirical_sinkhorn(X_s, X_t, reg)
cost = ot.bregman.empirical_sinkhorn2(X_s, X_t, reg)

print("Empirical transport cost:", cost)
print("Transport plan shape:", plan.shape)

# Sinkhorn divergence (debiased)
divergence = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, reg)
print("Sinkhorn divergence:", divergence)
```

### Stabilized Sinkhorn for Small Regularization
```python
import ot
import numpy as np

# Same small problem as in the basic example
a = np.array([0.5, 0.5])
b = np.array([0.3, 0.7])
M = np.array([[0.0, 1.0],
              [1.0, 0.0]])

# Small regularization parameter
reg_small = 1e-3

# Use stabilized version to avoid numerical issues
plan_stable = ot.bregman.sinkhorn_stabilized(a, b, M, reg_small, verbose=True)
print("Stabilized Sinkhorn plan:", plan_stable)

# Or use epsilon scaling
plan_eps = ot.bregman.sinkhorn_epsilon_scaling(a, b, M, reg_small, verbose=True)
print("Epsilon-scaled plan:", plan_eps)
```

The `ot.bregman` module covers the entropic-regularization side of computational optimal transport: plain and stabilized Sinkhorn solvers, fixed-support, free-support, and convolutional barycenters, and sample-based wrappers. Entropic regularization trades a bias controlled by `reg` for fast iterations built from matrix-vector products, which is why the Sinkhorn algorithm and its variants are popular for large-scale applications and differentiable optimal transport.