# Unbalanced Optimal Transport

The `ot.unbalanced` module provides algorithms for unbalanced optimal transport, in which the marginal constraints are relaxed so that the source and target distributions may have different total masses. This is particularly useful for data with outliers or noise, or for comparing distributions whose total masses naturally differ.

## Core Unbalanced Methods

### Sinkhorn-based Unbalanced Transport

```python { .api }
def ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, stopThr=1e-6, verbose=False, log=False, warn=True, **kwargs):
    """
    Solve unbalanced optimal transport using the Sinkhorn algorithm with KL relaxation.

    Solves the unbalanced optimal transport problem:
        min_{P >= 0} <P,M> + reg*Omega(P) + reg_m*KL(P1|a) + reg_m*KL(P^T1|b)
    where Omega(P) is the entropic regularization term and the marginal
    constraints are relaxed using KL divergences.

    Parameters:
    - a: array-like, shape (n_samples_source,)
      Source distribution. Need not sum to 1.
    - b: array-like, shape (n_samples_target,)
      Target distribution. Need not sum to 1.
    - M: array-like, shape (n_samples_source, n_samples_target)
      Ground cost matrix.
    - reg: float
      Entropic regularization parameter (>0).
    - reg_m: float or tuple of floats
      Marginal relaxation parameter(s). If float, the same value is used for
      both marginals. If tuple (reg_m1, reg_m2), reg_m1 weights the source
      marginal term and reg_m2 the target marginal term.
    - method: str, default='sinkhorn'
      Algorithm variant. Options: 'sinkhorn', 'sinkhorn_stabilized',
      'sinkhorn_translation_invariant'
    - numItermax: int, default=1000
      Maximum number of iterations.
    - stopThr: float, default=1e-6
      Convergence threshold on marginal violation.
    - verbose: bool, default=False
      Print iteration information.
    - log: bool, default=False
      Return optimization log.
    - warn: bool, default=True
      Warn if the algorithm does not converge.

    Returns:
    - transport_plan: ndarray, shape (n_samples_source, n_samples_target)
      Unbalanced optimal transport plan.
    - log: dict (if log=True)
      Contains 'err': convergence errors, 'mass_source': final source mass,
      'mass_target': final target mass, 'u': source scaling, 'v': target scaling.
    """

def ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, stopThr=1e-6, verbose=False, log=False, warn=True, **kwargs):
    """
    Solve unbalanced optimal transport and return the cost only.

    Use instead of sinkhorn_unbalanced() when only the cost is needed.

    Parameters: Same as sinkhorn_unbalanced()

    Returns:
    - cost: float
      Unbalanced optimal transport cost.
    - log: dict (if log=True)
    """

def ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs):
    """
    Unbalanced Sinkhorn-Knopp algorithm with multiplicative updates.

    Classic scaling formulation: the plan is diag(u) K diag(v) for a Gibbs
    kernel K, and the scaling vectors u and v are updated multiplicatively.

    Parameters: Same as sinkhorn_unbalanced()

    Returns:
    - transport_plan: ndarray
    - log: dict (if log=True)
    """

def ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e3, numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs):
    """
    Stabilized unbalanced Sinkhorn algorithm.

    Absorbs scaling factors that grow beyond tau into log-domain dual
    potentials, which prevents numerical overflow while handling
    unbalanced marginals.

    Parameters:
    - a, b, M, reg, reg_m: Same as sinkhorn_unbalanced()
    - tau: float, default=1e3
      Absorption threshold for numerical stability.
    - Other parameters: same as sinkhorn_unbalanced()

    Returns:
    - transport_plan: ndarray
    - log: dict (if log=True)
    """

def ot.unbalanced.sinkhorn_unbalanced_translation_invariant(a, b, M, reg, reg_m, c=None, rescale_plan=True, numItermax=1000, stopThr=1e-6, verbose=False, log=False):
    """
    Translation-invariant unbalanced Sinkhorn algorithm.

    Uses a translation-invariant reformulation of the unbalanced Sinkhorn
    updates, intended to converge faster than the standard updates.

    Parameters:
    - a, b, M, reg, reg_m: Same as sinkhorn_unbalanced()
    - c: array-like, optional
      Reference measure for the regularization term. If None, a b^T is used.
    - rescale_plan: bool, default=True
      Whether to rescale the final transport plan.
    - Other parameters: same as sinkhorn_unbalanced()

    Returns:
    - transport_plan: ndarray
    - log: dict (if log=True)
    """
```
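
The scaling iterations behind `sinkhorn_knopp_unbalanced` can be illustrated with a minimal NumPy sketch of the classic KL-relaxed unbalanced Sinkhorn updates. It assumes the entropic term is `reg * sum(P * (log P - 1))` and reads a tuple `reg_m` as `(reg_m1, reg_m2)` for the source and target marginals; the function name `sinkhorn_unbalanced_sketch` is hypothetical, and POT's own implementation (default regularizer, stopping rule, logging) may differ.

```python
import numpy as np


def sinkhorn_unbalanced_sketch(a, b, M, reg, reg_m, num_iter=5000, tol=1e-10):
    """Hypothetical sketch of KL-relaxed unbalanced Sinkhorn (entropy term assumed)."""
    # reg_m may be a float or a (reg_m1, reg_m2) pair, mirroring the API above
    reg_m1, reg_m2 = (reg_m, reg_m) if np.isscalar(reg_m) else reg_m
    K = np.exp(-M / reg)                 # Gibbs kernel
    fi1 = reg_m1 / (reg_m1 + reg)        # damping exponents; both -> 1 gives balanced Sinkhorn
    fi2 = reg_m2 / (reg_m2 + reg)
    u = np.ones(len(a))
    v = np.ones(len(b))
    for _ in range(num_iter):
        u_prev = u
        u = (a / (K @ v)) ** fi1         # damped source-side scaling
        v = (b / (K.T @ u)) ** fi2       # damped target-side scaling
        if np.max(np.abs(u - u_prev)) < tol:
            break
    return u[:, None] * K * v[None, :]   # P = diag(u) K diag(v)


a = np.array([0.6, 0.4])
b = np.array([0.2, 0.3, 0.1])            # total mass 0.6 versus 1.0 for a
M = np.array([[0.0, 0.2, 0.4], [0.2, 0.0, 0.2]])
P = sinkhorn_unbalanced_sketch(a, b, M, reg=0.1, reg_m=0.5)
print(P.sum(axis=1), P.sum(axis=0))      # relaxed marginals, to compare with a and b
```

Each update is the balanced Sinkhorn update raised to the power `reg_m / (reg_m + reg)`: as `reg_m` grows the exponent tends to 1 and the marginals are enforced, while a small `reg_m` lets the plan's marginals move away from `a` and `b`.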

### Unbalanced Barycenters

```python { .api }
def ot.unbalanced.barycenter_unbalanced(A, M, reg, reg_m, weights=None, method='sinkhorn', numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs):
    """
    Compute unbalanced Wasserstein barycenter.

    Finds the barycenter that minimizes the weighted sum of unbalanced transport
    costs to all input distributions, allowing for mass creation/destruction.

    Parameters:
    - A: array-like, shape (n_samples, n_distributions)
      Input distributions as columns. Need not be normalized.
    - M: array-like, shape (n_samples, n_samples)
      Ground cost matrix on barycenter support.
    - reg: float
      Entropic regularization parameter.
    - reg_m: float
      Marginal relaxation parameter.
    - weights: array-like, shape (n_distributions,), optional
      Weights for barycenter combination. Default is uniform.
    - method: str, default='sinkhorn'
      Algorithm variant for unbalanced transport computation.
    - numItermax: int, default=1000
      Maximum iterations for barycenter computation.
    - stopThr: float, default=1e-6
      Convergence threshold.
    - verbose: bool, default=False
    - log: bool, default=False

    Returns:
    - barycenter: ndarray, shape (n_samples,)
      Unbalanced Wasserstein barycenter (need not sum to 1).
    - log: dict (if log=True)
      Contains convergence information and transport plans.
    """

def ot.unbalanced.barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, numItermax=1000, stopThr=1e-6, verbose=False, log=False):
    """
    Compute unbalanced barycenter using Sinkhorn algorithm.

    Alternative implementation with explicit Sinkhorn iterations.

    Parameters: Same as barycenter_unbalanced()

    Returns:
    - barycenter: ndarray
    - log: dict (if log=True)
    """

def ot.unbalanced.barycenter_unbalanced_stabilized(A, M, reg, reg_m, tau=1e3, weights=None, numItermax=1000, stopThr=1e-6, verbose=False, log=False):
    """
    Compute unbalanced barycenter using stabilized algorithm.

    Parameters:
    - A, M, reg, reg_m, weights: Same as barycenter_unbalanced()
    - tau: float, default=1e3
      Stabilization parameter.
    - Other parameters: same as barycenter_unbalanced()

    Returns:
    - barycenter: ndarray
    - log: dict (if log=True)
    """
```
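
The Sinkhorn barycenter iterations can be sketched in the same NumPy style. In this minimal version (hypothetical name `barycenter_unbalanced_sketch`), each input column is coupled to the barycenter `q` by an entropic, KL-relaxed problem with a shared `reg_m`, and `q` is eliminated in closed form at each step: for fixed source-side scalings `u`, the optimal `q` is the weighted power mean of the columns of `K^T u` with exponent `1 - reg_m / (reg_m + reg)`. This is our own derivation under those assumptions, not a transcription of POT's code.

```python
import numpy as np


def barycenter_unbalanced_sketch(A, M, reg, reg_m, weights=None, num_iter=2000, tol=1e-10):
    """Hypothetical KL-relaxed entropic barycenter of the columns of A."""
    n, k = A.shape
    w = np.full(k, 1.0 / k) if weights is None else np.asarray(weights, dtype=float)
    w = w / w.sum()                        # the closed-form q update assumes weights summing to 1
    K = np.exp(-M / reg)                   # Gibbs kernel on the common support
    fi = reg_m / (reg_m + reg)
    v = np.ones((n, k))
    q = np.ones(n)
    for _ in range(num_iter):
        q_prev = q
        u = (A / (K @ v)) ** fi            # scale each plan towards its input column
        Ktu = K.T @ u
        # jointly optimal q (and v) for fixed u: weighted power mean with exponent 1 - fi
        q = ((Ktu ** (1 - fi)) @ w) ** (1 / (1 - fi))
        v = (q[:, None] / Ktu) ** fi       # scale each plan towards the barycenter
        if np.max(np.abs(q - q_prev)) < tol:
            break
    return q


x = np.arange(3, dtype=float)
M = (x[:, None] - x[None, :]) ** 2 / 4.0                 # squared distances on a 3-point grid
A = np.column_stack([[1.0, 0.5, 0.1], [0.1, 0.5, 1.0]])  # two mirrored inputs, mass 1.6 each
q = barycenter_unbalanced_sketch(A, M, reg=0.1, reg_m=1.0)
print(q)  # mirror-symmetric: q[0] and q[2] agree up to rounding
```

The exponent `1 / (1 - fi) = (reg_m + reg) / reg` grows as `reg` shrinks, which is one reason a stabilized variant is offered for small regularization.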

## MM Algorithm

```python { .api }
def ot.unbalanced.mm_unbalanced(a, b, M, reg, reg_m, div='kl', G0=None, numItermax=1000, stopThr=1e-6, verbose=False, log=False):
    """
    Solve unbalanced optimal transport using the MM (Majorization-Minimization) algorithm.

    Alternative optimization approach that supports a squared L2 penalty
    as well as the KL divergence for marginal relaxation.

    Parameters:
    - a: array-like, shape (n_samples_source,)
      Source distribution.
    - b: array-like, shape (n_samples_target,)
      Target distribution.
    - M: array-like, shape (n_samples_source, n_samples_target)
      Ground cost matrix.
    - reg: float
      Entropic regularization parameter.
    - reg_m: float or tuple
      Marginal relaxation parameter(s).
    - div: str, default='kl'
      Divergence for marginal relaxation. Options: 'kl', 'l2'
    - G0: array-like, optional
      Initial transport plan.
    - numItermax: int, default=1000
    - stopThr: float, default=1e-6
    - verbose: bool, default=False
    - log: bool, default=False

    Returns:
    - transport_plan: ndarray, shape (n_samples_source, n_samples_target)
    - log: dict (if log=True)
    """

def ot.unbalanced.mm_unbalanced2(a, b, M, reg, reg_m, div='kl', G0=None, numItermax=1000, stopThr=1e-6, verbose=False, log=False):
    """
    MM algorithm for unbalanced OT returning the cost only.

    Parameters: Same as mm_unbalanced()

    Returns:
    - cost: float
      Unbalanced transport cost.
    - log: dict (if log=True)
    """
```
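
For the KL case without an extra regularization term on the plan, the majorization-minimization idea reduces to a multiplicative update. Majorizing each KL marginal term with Jensen's inequality and minimizing the majorizer entrywise gives `G <- G * exp(-M / (2 reg_m)) * sqrt(a / G1) * sqrt(b / G^T 1)` for a shared `reg_m`. The sketch below (hypothetical name `mm_unbalanced_kl_sketch`, assuming `div='kl'` and no `reg` term) records the objective so its monotone decrease can be checked; it illustrates the technique and is not POT's implementation.

```python
import numpy as np


def generalized_kl(x, y):
    # sum(x log(x / y) - x + y), with 0 log 0 = 0
    return np.sum(x * np.log(np.where(x > 0, x / y, 1.0)) - x + y)


def mm_unbalanced_kl_sketch(a, b, M, reg_m, num_iter=500):
    """Hypothetical MM updates for <G, M> + reg_m * KL(G1 | a) + reg_m * KL(G^T 1 | b), G >= 0."""
    G = np.outer(a, b)                       # positive initial plan a b^T
    cost_factor = np.exp(-M / (2 * reg_m))
    history = []
    for _ in range(num_iter):
        r, c = G.sum(axis=1), G.sum(axis=0)  # current plan marginals
        history.append(np.sum(G * M) + reg_m * (generalized_kl(r, a) + generalized_kl(c, b)))
        # entrywise minimizer of the Jensen majorizer built at the current G
        G = G * cost_factor * np.sqrt(a / r)[:, None] * np.sqrt(b / c)[None, :]
    return G, history


a = np.array([0.6, 0.4])
b = np.array([0.2, 0.3, 0.1])
M = np.array([[0.1, 0.5, 0.9], [0.4, 0.2, 0.3]])
G, history = mm_unbalanced_kl_sketch(a, b, M, reg_m=0.5)
print(history[0], history[-1])  # the objective never increases between iterations
```

Because the update is multiplicative, zero entries of `G` stay zero, so the initial plan must be positive wherever mass may be needed; the product plan `a b^T` satisfies this for positive `a` and `b`.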

## L-BFGS-B Methods

```python { .api }
def ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', G0=None, numItermax=1000, numInnerItermax=10, stopThr=1e-6, stopThr2=1e-6, verbose=False, log=False):
    """
    Solve unbalanced optimal transport using the L-BFGS-B optimizer.

    Minimizes the regularized unbalanced OT objective over the transport plan
    with the L-BFGS-B quasi-Newton method, which enforces the nonnegativity
    of the plan through bound constraints.

    Parameters:
    - a: array-like, shape (n_samples_source,)
      Source distribution.
    - b: array-like, shape (n_samples_target,)
      Target distribution.
    - M: array-like, shape (n_samples_source, n_samples_target)
      Ground cost matrix.
    - reg: float
      Entropic regularization parameter.
    - reg_m: float or tuple
      Marginal relaxation parameter(s).
    - c: array-like, optional
      Reference measure for the regularization term. If None, a b^T is used.
    - reg_div: str, default='kl'
      Divergence used for the regularization term on the plan.
    - G0: array-like, optional
      Initial transport plan.
    - numItermax: int, default=1000
      Maximum outer iterations.
    - numInnerItermax: int, default=10
      Maximum inner iterations for line search.
    - stopThr: float, default=1e-6
      Convergence threshold for outer loop.
    - stopThr2: float, default=1e-6
      Convergence threshold for inner loop.
    - verbose: bool, default=False
    - log: bool, default=False

    Returns:
    - transport_plan: ndarray, shape (n_samples_source, n_samples_target)
    - log: dict (if log=True)
      Contains optimization details including L-BFGS-B convergence info.
    """

def ot.unbalanced.lbfgsb_unbalanced2(a, b, M, reg, reg_m, c=None, reg_div='kl', G0=None, numItermax=1000, numInnerItermax=10, stopThr=1e-6, stopThr2=1e-6, verbose=False, log=False):
    """
    L-BFGS-B unbalanced OT returning the cost only.

    Parameters: Same as lbfgsb_unbalanced()

    Returns:
    - cost: float
    - log: dict (if log=True)
    """
```
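
A rough picture of the L-BFGS-B approach: treat the plan entries as variables with lower bound 0 and pass the objective and its gradient to SciPy's `scipy.optimize.minimize(method="L-BFGS-B")`. The sketch assumes the objective `<G, M> + reg * KL(G | a b^T) + reg_m * (KL(G1 | a) + KL(G^T1 | b))` with KL marginal penalties; the helpers `uot_objective_and_grad` and `lbfgsb_unbalanced_sketch` are hypothetical, and POT's `lbfgsb_unbalanced` may use different defaults and divergences.

```python
import numpy as np
from scipy.optimize import minimize

EPS = 1e-16  # keeps the logarithms finite when an entry reaches the bound G = 0


def generalized_kl(x, y):
    # sum(x log(x / y) - x + y), with EPS guarding log(0)
    return np.sum(x * np.log((x + EPS) / y) - x + y)


def uot_objective_and_grad(g, a, b, M, reg, reg_m):
    """Value and gradient of the assumed regularized UOT objective at the flattened plan g."""
    G = g.reshape(M.shape)
    ref = np.outer(a, b)                       # reference measure for the plan regularizer
    r, c = G.sum(axis=1), G.sum(axis=0)        # plan marginals
    value = (np.sum(G * M) + reg * generalized_kl(G, ref)
             + reg_m * (generalized_kl(r, a) + generalized_kl(c, b)))
    # dF/dG_ij = M_ij + reg log(G_ij / ref_ij) + reg_m log(r_i / a_i) + reg_m log(c_j / b_j)
    grad = (M + reg * np.log((G + EPS) / ref)
            + reg_m * np.log((r + EPS) / a)[:, None]
            + reg_m * np.log((c + EPS) / b)[None, :])
    return value, grad.ravel()


def lbfgsb_unbalanced_sketch(a, b, M, reg, reg_m, max_iter=1000):
    g0 = np.outer(a, b).ravel()                # start from the product plan a b^T
    res = minimize(uot_objective_and_grad, g0, args=(a, b, M, reg, reg_m),
                   method="L-BFGS-B", jac=True,
                   bounds=[(0.0, None)] * g0.size,  # nonnegativity of every plan entry
                   options={"maxiter": max_iter})
    return res.x.reshape(M.shape), res


a = np.array([0.6, 0.4])
b = np.array([0.2, 0.3, 0.1])
M = np.array([[0.1, 0.5, 0.9], [0.4, 0.2, 0.3]])
G, res = lbfgsb_unbalanced_sketch(a, b, M, reg=0.1, reg_m=0.5)
print(G)
print(res.fun)
```

The number of optimization variables equals the number of plan entries, so memory and time grow with `n_samples_source * n_samples_target`.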

## Regularization and Divergences

The unbalanced transport framework supports different types of marginal relaxation:

### KL Divergence Relaxation
The most common choice uses the (generalized) Kullback-Leibler divergence as the marginal penalty:
```
KL(π₁|a) = Σᵢ π₁(i) log(π₁(i)/a(i)) - π₁(i) + a(i)
```
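
The `- π₁(i) + a(i)` terms are what make this a generalized KL divergence: it stays nonnegative for unnormalized vectors and penalizes a mismatch in total mass, not only in shape. A small NumPy helper (hypothetical name `generalized_kl`, using the convention `0 log 0 = 0`) makes this concrete.

```python
import numpy as np


def generalized_kl(p, q):
    """sum_i p_i log(p_i / q_i) - p_i + q_i, with 0 log 0 = 0; q must be positive."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    log_term = p * np.log(np.where(p > 0, p / q, 1.0))  # zero wherever p_i == 0
    return float(np.sum(log_term - p + q))


a = np.array([0.6, 0.4])
print(generalized_kl(a, a))      # 0.0: identical marginals
print(generalized_kl(2 * a, a))  # (2 log 2 - 1) * a.sum() > 0: same shape, twice the mass
```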

### Alternative Divergences
- **L2 Penalty**: `div='l2'` - Quadratic penalty on marginal violations
- **Total Variation**: L1 penalty on marginal differences (not one of the `div` options of `mm_unbalanced`, which accepts `'kl'` and `'l2'`)
- **Custom Divergences**: User-defined penalty functions
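
For comparison, the L2 and total variation penalties listed above can be written as small NumPy functions. This sketch fixes its own conventions: the L2 penalty is taken as the half-squared Euclidean distance and the TV penalty as the L1 distance. The helper `marginal_penalty` is hypothetical and not part of POT, and library implementations may scale these terms differently.

```python
import numpy as np


def marginal_penalty(x, y, div):
    """Hypothetical penalty between a plan marginal x and a target marginal y."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if div == "l2":
        return 0.5 * float(np.sum((x - y) ** 2))  # half-squared L2: small gaps cost little
    if div == "tv":
        return float(np.sum(np.abs(x - y)))       # L1: cost grows linearly with the gap
    raise ValueError(f"unsupported div: {div!r}")


marginal = np.array([0.5, 0.5])
target = np.array([0.2, 0.3])
print(marginal_penalty(marginal, target, "l2"))
print(marginal_penalty(marginal, target, "tv"))
```

Both penalties stay finite where a target entry `y_i` is zero, whereas the KL penalty is infinite there unless the marginal also puts no mass on that entry.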

## Usage Examples

### Basic Unbalanced Transport
```python
import ot
import numpy as np

# Create unbalanced distributions
a = np.array([0.6, 0.4])  # Source (sums to 1.0)
b = np.array([0.2, 0.3, 0.1])  # Target (sums to 0.6)

# Cost matrix
M = np.random.rand(2, 3)

# Regularization parameters
reg = 0.1  # Entropic regularization
reg_m = 0.5  # Marginal relaxation

# Solve unbalanced transport
plan_unbalanced = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m, verbose=True)
cost_unbalanced = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg, reg_m)

print("Unbalanced transport plan:")
print(plan_unbalanced)
print(f"Unbalanced cost: {cost_unbalanced}")

# Compare the plan's marginals with the inputs (they need not match)
source_mass = np.sum(plan_unbalanced, axis=1)
target_mass = np.sum(plan_unbalanced, axis=0)
print(f"Source masses: {source_mass} (original: {a})")
print(f"Target masses: {target_mass} (original: {b})")
```

### Different Marginal Regularizations
```python
# Different regularization for source and target
reg_m_source = 0.3
reg_m_target = 0.7
reg_m_tuple = (reg_m_source, reg_m_target)

plan_asym = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_tuple)
print("Asymmetric marginal regularization plan:")
print(plan_asym)
```

### Unbalanced Barycenter
```python
# Multiple unbalanced distributions, stored as the columns of A
A = np.array([[0.6, 0.3, 0.2],
              [0.4, 0.45, 0.3],
              [0.0, 0.75, 0.0]])  # 3 distributions with masses 1.0, 1.5 and 0.5

# Cost matrix for barycenter space
M_bary = ot.dist(np.arange(3, dtype=float).reshape(-1, 1))

# Compute unbalanced barycenter
reg_bary = 0.05
reg_m_bary = 0.2

barycenter = ot.unbalanced.barycenter_unbalanced(A, M_bary, reg_bary, reg_m_bary, verbose=True)

print("Unbalanced barycenter:")
print(barycenter)
print(f"Barycenter mass: {np.sum(barycenter)}")
```

### MM Algorithm with Different Divergences
```python
# Use the L2 divergence for marginal relaxation
# (reg and reg_m are passed by keyword for clarity)
plan_mm_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg=reg, reg_m=reg_m, div='l2')
cost_mm_l2 = ot.unbalanced.mm_unbalanced2(a, b, M, reg=reg, reg_m=reg_m, div='l2')

print(f"MM L2 cost: {cost_mm_l2}")

# Use the KL divergence for comparison
plan_mm_kl = ot.unbalanced.mm_unbalanced(a, b, M, reg=reg, reg_m=reg_m, div='kl')
cost_mm_kl = ot.unbalanced.mm_unbalanced2(a, b, M, reg=reg, reg_m=reg_m, div='kl')

print(f"MM KL cost: {cost_mm_kl}")
```

### Empirical Unbalanced Transport
```python
# Generate unbalanced sample data
np.random.seed(42)
n_source, n_target = 100, 80
X_s = np.random.randn(n_source, 2)
X_t = np.random.randn(n_target, 2) + 1

# Unbalanced weights (don't sum to 1)
a_unbalanced = np.random.exponential(0.8, n_source)
b_unbalanced = np.random.exponential(1.2, n_target)

# Compute cost matrix
M_empirical = ot.dist(X_s, X_t)

# Solve unbalanced transport
reg_emp = 0.1
reg_m_emp = 0.3

plan_emp = ot.unbalanced.sinkhorn_unbalanced(a_unbalanced, b_unbalanced, M_empirical, reg_emp, reg_m_emp)
cost_emp = ot.unbalanced.sinkhorn_unbalanced2(a_unbalanced, b_unbalanced, M_empirical, reg_emp, reg_m_emp)

print(f"Empirical unbalanced cost: {cost_emp}")
print(f"Original source mass: {np.sum(a_unbalanced):.3f}")
print(f"Original target mass: {np.sum(b_unbalanced):.3f}")
print(f"Transported mass: {np.sum(plan_emp):.3f}")
```

### Stabilized Algorithm for Extreme Cases
```python
# Very small regularization combined with large costs: the plain kernel exp(-M/reg) underflows
reg_small = 1e-4
M_large = M * 100

# Use the stabilized version
plan_stable = ot.unbalanced.sinkhorn_stabilized_unbalanced(
    a, b, M_large, reg_small, reg_m, tau=1e2, verbose=True
)

print("Stabilized unbalanced transport completed")
```
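
Stabilization matters because the plain kernel `exp(-M / reg)` underflows to zero when costs are large relative to `reg`. The sketch below swaps in a different stabilization technique than tau-absorption: it runs the same KL-relaxed updates entirely in the log domain, on the potentials `f = reg * log(u)` and `g = reg * log(v)`, using a log-sum-exp. The function names are hypothetical, and the entropic term is assumed to be `reg * sum(P * (log P - 1))`.

```python
import numpy as np


def logsumexp(x, axis):
    """Numerically stable log(sum(exp(x))) along one axis."""
    x_max = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(x_max, axis=axis) + np.log(np.sum(np.exp(x - x_max), axis=axis))


def sinkhorn_unbalanced_log_sketch(a, b, M, reg, reg_m, num_iter=2000):
    """Hypothetical log-domain KL-relaxed Sinkhorn on f = reg*log(u), g = reg*log(v)."""
    fi = reg_m / (reg_m + reg)
    f = np.zeros(len(a))
    g = np.zeros(len(b))
    for _ in range(num_iter):
        # log-domain versions of u = (a / K v)^fi and v = (b / K^T u)^fi
        f = fi * reg * (np.log(a) - logsumexp((g[None, :] - M) / reg, axis=1))
        g = fi * reg * (np.log(b) - logsumexp((f[:, None] - M) / reg, axis=0))
    return np.exp((f[:, None] + g[None, :] - M) / reg)


a = np.array([0.6, 0.4])
b = np.array([0.5, 0.5])
M = np.array([[0.0, 100.0], [100.0, 50.0]])
print(np.exp(-M / 1e-2))  # plain kernel: every entry of the second row underflows to 0
P = sinkhorn_unbalanced_log_sketch(a, b, M, reg=1e-2, reg_m=1.0)
print(P)                  # finite plan; the costly second source point keeps almost no mass
```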

### L-BFGS-B for Large-Scale Problems
```python
# L-BFGS-B on a larger problem (the plan has n_large * n_large entries)
n_large = 500
a_large = np.random.exponential(1.0, n_large)
b_large = np.random.exponential(1.5, n_large)
M_large = np.random.rand(n_large, n_large)

# Use L-BFGS-B solver
plan_lbfgs = ot.unbalanced.lbfgsb_unbalanced(
    a_large, b_large, M_large, reg, reg_m,
    numItermax=100, verbose=True
)
cost_lbfgs = ot.unbalanced.lbfgsb_unbalanced2(
    a_large, b_large, M_large, reg, reg_m
)

print(f"L-BFGS-B unbalanced cost: {cost_lbfgs}")
```

## Applications

### Comparing Unnormalized Data
Unbalanced transport is particularly useful when:
- Comparing histograms or distributions that naturally have different total masses
- Handling data with missing values or outliers
- Robust matching in the presence of noise
- Domain adaptation with different sample sizes

### Mass Creation and Destruction
The relaxed marginal constraints allow:
- **Mass Creation**: Transport plan can have row/column sums exceeding the original marginals
- **Mass Destruction**: Transport plan can have row/column sums below the original marginals
- **Outlier Handling**: Points with no good matches can have reduced mass
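
The outlier-handling behavior can be checked on a toy 1-D problem. The sketch below uses the classic KL-relaxed scaling updates in plain NumPy (entropic term assumed to be `reg * sum(P * (log P - 1))`; the data are made up for illustration): the source point at `x = 10` is far from every target point, so the relaxed plan assigns it almost no mass instead of forcing it to be transported.

```python
import numpy as np

# Made-up 1-D data: the third source point (x = 10) is an outlier
x_s = np.array([0.0, 1.0, 10.0])
x_t = np.array([0.0, 1.0])
a = np.ones(3)
b = np.ones(2)
M = (x_s[:, None] - x_t[None, :]) ** 2   # squared distances; the outlier row costs 100 and 81

reg, reg_m = 1.0, 1.0
K = np.exp(-M / reg)
fi = reg_m / (reg_m + reg)
u, v = np.ones(3), np.ones(2)
for _ in range(1000):
    # KL-relaxed scaling updates
    u = (a / (K @ v)) ** fi
    v = (b / (K.T @ u)) ** fi
P = u[:, None] * K * v[None, :]

row_mass = P.sum(axis=1)
print(row_mass)  # the outlier's row mass is tiny compared with the other two rows
```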

### Computational Advantages
- Always feasible: `a` and `b` can have different total masses without being renormalized first
- Sinkhorn-type unbalanced solvers keep the per-iteration cost of balanced Sinkhorn (matrix-vector products with the kernel)
- Stabilized variants are available when the regularization is small relative to the costs

The `ot.unbalanced` module provides tools for optimal transport applications where exact mass conservation is neither required nor desired.