# Partial Optimal Transport

The `ot.partial` module provides algorithms for partial optimal transport, where only a fraction of the total mass needs to be transported. Relaxing the mass conservation constraint in this way is particularly useful for comparing distributions that contain outliers or noise, or whose total masses differ.

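For reference, the unregularized problem solved by `partial_wasserstein` can be written in the following standard form, where $\gamma$ is the transport plan, $M$ the ground cost, $a$ and $b$ the marginals, and $m$ the transported mass (notation introduced here for exposition, not taken verbatim from the library):

```latex
\gamma^\star = \arg\min_{\gamma \ge 0} \; \langle \gamma, M \rangle_F
\quad \text{s.t.} \quad
\gamma \mathbf{1} \le a, \qquad
\gamma^\top \mathbf{1} \le b, \qquad
\mathbf{1}^\top \gamma \mathbf{1} = m
```
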
## Core Partial Transport Functions

### Basic Partial Transport

```python { .api }
def ot.partial.partial_wasserstein(a, b, M, m=None, numItermax=1000000, log=False, **kwargs):
    """
    Solve the partial optimal transport problem.

    Computes the optimal transport plan that transports at most mass m between
    the source and target distributions, relaxing the total mass conservation
    constraint. This is useful for robust transport in the presence of outliers.

    Parameters:
    - a: array-like, shape (n_samples_source,)
        Source distribution. Must be non-negative.
    - b: array-like, shape (n_samples_target,)
        Target distribution. Must be non-negative.
    - M: array-like, shape (n_samples_source, n_samples_target)
        Ground cost matrix between source and target samples.
    - m: float, optional
        Amount of mass to transport. Must be in (0, min(sum(a), sum(b))].
        If None, defaults to min(sum(a), sum(b)).
    - numItermax: int, default=1000000
        Maximum number of iterations for the underlying solver.
    - log: bool, default=False
        Return optimization log with convergence details.
    - kwargs: dict
        Additional arguments passed to the underlying linear programming solver.

    Returns:
    - transport_plan: ndarray, shape (n_samples_source, n_samples_target)
        Partial optimal transport plan. The sum of the plan equals the transported mass m.
    - log: dict (if log=True)
        Contains 'cost': optimal transport cost, 'u': source dual variables,
        'v': target dual variables, 'result_code': solver status.

    Example:
        a = np.array([0.5, 0.5])
        b = np.array([0.3, 0.4, 0.3])
        M = np.random.rand(2, 3)
        plan = ot.partial.partial_wasserstein(a, b, M, m=0.7)
        print(f"Transported mass: {np.sum(plan)}")
    """

def ot.partial.partial_wasserstein2(a, b, M, m=None, numItermax=1000000, log=False, **kwargs):
    """
    Solve the partial optimal transport problem and return the cost only.

    More efficient than partial_wasserstein() when only the optimal cost is needed.

    Parameters: Same as partial_wasserstein()

    Returns:
    - cost: float
        Partial optimal transport cost (optimal objective value).
    - log: dict (if log=True)
        Optimization information.

    Example:
        cost = ot.partial.partial_wasserstein2(a, b, M, m=0.7)
        print(f"Partial transport cost: {cost}")
    """

def ot.partial.partial_wasserstein_lagrange(a, b, M, reg_m=None, numItermax=1000000, log=False, **kwargs):
    """
    Solve partial optimal transport using a Lagrangian relaxation.

    Alternative formulation where the mass constraint is regularized through
    a Lagrange multiplier instead of being enforced as a hard constraint.

    Parameters:
    - a: array-like, shape (n_samples_source,)
        Source distribution.
    - b: array-like, shape (n_samples_target,)
        Target distribution.
    - M: array-like, shape (n_samples_source, n_samples_target)
        Cost matrix.
    - reg_m: float, optional
        Regularization parameter for the mass constraint. Higher values enforce
        stronger mass conservation.
    - numItermax: int, default=1000000
        Maximum iterations.
    - log: bool, default=False
    - kwargs: dict
        Additional solver arguments.

    Returns:
    - transport_plan: ndarray
        Partial transport plan with regularized mass constraint.
    - log: dict (if log=True)
    """
```
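
`partial_wasserstein_lagrange` has no inline example above, so here is a minimal usage sketch (assuming, as the docstring states, that a larger `reg_m` enforces stronger mass conservation; the numbers are illustrative only):

```python
import numpy as np
import ot

a = np.array([0.5, 0.5])
b = np.array([0.3, 0.4, 0.3])
M = np.random.rand(2, 3)

# A larger reg_m penalizes leaving mass untransported more heavily,
# pushing the solution toward full transport.
plan_lagrange = ot.partial.partial_wasserstein_lagrange(a, b, M, reg_m=10.0)
print(f"Transported mass: {np.sum(plan_lagrange):.3f}")
```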

### Entropic Partial Transport

```python { .api }
def ot.partial.entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, stopThr=1e-9, verbose=False, log=False):
    """
    Solve entropic regularized partial optimal transport.

    Combines partial transport with entropic regularization for better
    computational properties and differentiability.

    Parameters:
    - a: array-like, shape (n_samples_source,)
        Source distribution.
    - b: array-like, shape (n_samples_target,)
        Target distribution.
    - M: array-like, shape (n_samples_source, n_samples_target)
        Cost matrix.
    - reg: float
        Entropic regularization parameter. Higher values give smoother solutions.
    - m: float, optional
        Mass to transport. If None, uses min(sum(a), sum(b)).
    - numItermax: int, default=1000
        Maximum Sinkhorn-like iterations.
    - stopThr: float, default=1e-9
        Convergence threshold on the marginal constraints.
    - verbose: bool, default=False
        Print iteration information.
    - log: bool, default=False
        Return optimization log.

    Returns:
    - transport_plan: ndarray, shape (n_samples_source, n_samples_target)
        Entropic partial transport plan.
    - log: dict (if log=True)
        Contains 'err': convergence errors, 'mass': transported mass,
        'u': source scaling factors, 'v': target scaling factors.

    Example:
        # Entropic partial transport with regularization
        reg = 0.1  # Regularization strength
        plan = ot.partial.entropic_partial_wasserstein(a, b, M, reg, m=0.8)
    """

def ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=None, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    """
    Solve the partial Gromov-Wasserstein problem.

    Extends partial transport to structured data by combining partial transport
    with the Gromov-Wasserstein distance for comparing metric spaces.

    Parameters:
    - C1: array-like, shape (n1, n1)
        Intra-structure cost matrix for the source space.
    - C2: array-like, shape (n2, n2)
        Intra-structure cost matrix for the target space.
    - p: array-like, shape (n1,)
        Distribution over the source space.
    - q: array-like, shape (n2,)
        Distribution over the target space.
    - m: float, optional
        Mass to transport. If None, uses min(sum(p), sum(q)).
    - loss_fun: str or callable, default='square_loss'
        Loss function for structure comparison.
    - alpha: float, default=0.5
        Step size parameter for the optimization.
    - armijo: bool, default=False
        Use Armijo line search for step size adaptation.
    - log: bool, default=False
    - max_iter: int, default=1000
    - tol_rel: float, default=1e-9
        Relative tolerance for convergence.
    - tol_abs: float, default=1e-9
        Absolute tolerance for convergence.
    - kwargs: dict

    Returns:
    - transport_plan: ndarray, shape (n1, n2)
        Partial Gromov-Wasserstein transport plan.
    - log: dict (if log=True)
    """

def ot.partial.partial_gromov_wasserstein2(C1, C2, p, q, m=None, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    """
    Solve the partial Gromov-Wasserstein problem and return the cost only.

    Parameters: Same as partial_gromov_wasserstein()

    Returns:
    - cost: float
        Partial Gromov-Wasserstein cost.
    - log: dict (if log=True)
    """

def ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, loss_fun='square_loss', G0=None, max_iter=1000, tol=1e-9, verbose=False, log=False):
    """
    Solve entropic regularized partial Gromov-Wasserstein.

    Combines entropic regularization with partial Gromov-Wasserstein for
    structured partial transport with computational advantages.

    Parameters:
    - C1: array-like, shape (n1, n1)
        Source structure matrix.
    - C2: array-like, shape (n2, n2)
        Target structure matrix.
    - p: array-like, shape (n1,)
        Source distribution.
    - q: array-like, shape (n2,)
        Target distribution.
    - reg: float
        Entropic regularization parameter.
    - m: float, optional
        Mass to transport.
    - loss_fun: str or callable, default='square_loss'
    - G0: array-like, optional
        Initial transport plan.
    - max_iter: int, default=1000
    - tol: float, default=1e-9
    - verbose: bool, default=False
    - log: bool, default=False

    Returns:
    - transport_plan: ndarray
    - log: dict (if log=True)
    """

def ot.partial.entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, loss_fun='square_loss', G0=None, max_iter=1000, tol=1e-9, verbose=False, log=False):
    """
    Entropic partial Gromov-Wasserstein cost only.

    Parameters: Same as entropic_partial_gromov_wasserstein()

    Returns:
    - cost: float
    - log: dict (if log=True)
    """
```
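
A quick, illustrative sketch of the cost-only entropic variant on random structured data (the regularization value is arbitrary and may need tuning for stable convergence):

```python
import numpy as np
import ot

rng = np.random.RandomState(0)
C1 = ot.dist(rng.randn(15, 2))   # source intra-space structure
C2 = ot.dist(rng.randn(20, 2))   # target intra-space structure
p, q = ot.unif(15), ot.unif(20)

# Cost-only variant: returns the entropic partial GW objective value
cost_epgw = ot.partial.entropic_partial_gromov_wasserstein2(
    C1, C2, p, q, reg=1.0, m=0.8
)
print(f"Entropic partial GW cost: {cost_epgw:.4f}")
```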

## Utility Functions

```python { .api }
def ot.partial.gwgrad_partial(C1, C2, T):
    """
    Compute the gradient for the partial Gromov-Wasserstein problem.

    Specialized gradient computation for the partial GW objective function,
    accounting for the relaxed mass constraints.

    Parameters:
    - C1: array-like, shape (n1, n1)
        Source structure matrix.
    - C2: array-like, shape (n2, n2)
        Target structure matrix.
    - T: array-like, shape (n1, n2)
        Current transport plan.

    Returns:
    - gradient: ndarray, shape (n1, n2)
        Gradient of the partial GW objective with respect to the transport plan.
    """

def ot.partial.gwloss_partial(C1, C2, T):
    """
    Compute the partial Gromov-Wasserstein loss value.

    Evaluates the objective function value of the partial GW problem
    at the given transport plan.

    Parameters:
    - C1: array-like, shape (n1, n1)
        Source structure matrix.
    - C2: array-like, shape (n2, n2)
        Target structure matrix.
    - T: array-like, shape (n1, n2)
        Transport plan.

    Returns:
    - loss: float
        Partial GW loss value.
    """
```
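
These helpers can be used to inspect the partial GW objective at any candidate plan, for example one returned by `partial_gromov_wasserstein` or a hand-built coupling. A minimal sketch, assuming the signatures documented above:

```python
import numpy as np
import ot

rng = np.random.RandomState(0)
C1 = ot.dist(rng.randn(10, 2))
C2 = ot.dist(rng.randn(12, 2))
p, q = ot.unif(10), ot.unif(12)

# Evaluate loss and gradient at the independent coupling scaled to 80% mass
T0 = 0.8 * np.outer(p, q)
loss_val = ot.partial.gwloss_partial(C1, C2, T0)
grad_val = ot.partial.gwgrad_partial(C1, C2, T0)
print(f"Partial GW loss: {loss_val:.4f}, gradient shape: {grad_val.shape}")
```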

## Key Concepts and Applications

### Mass Relaxation Benefits

Partial optimal transport provides several advantages over standard OT:

1. **Outlier Robustness**: Outliers can be left untransported rather than being forced into poor matches
2. **Unequal Masses**: Distributions with different total masses are handled directly, without normalization
3. **Noise Handling**: Noisy or irrelevant parts of a distribution can simply be left out of the coupling
4. **Computational Efficiency**: The relaxed problem can lead to sparser plans and faster algorithms

### Theoretical Properties

- **Optimal Substructure**: The solution is closely related to standard OT; the partial problem can be reduced to a standard OT problem by adding dummy points that absorb the untransported mass (see the sketch below)
- **Metric Properties**: Under suitable conditions, partial transport defines a proper metric between measures
- **Robustness**: Partial transport is more stable to perturbations of the distributions than exact transport

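To make the relation to standard OT concrete, here is a sketch of the dummy-point reduction (illustrative only, not the library's internal implementation): each marginal is extended with a dummy point that absorbs the untransported mass, and a standard EMD is solved on the extended problem.

```python
import numpy as np
import ot

a = np.array([0.5, 0.5])
b = np.array([0.3, 0.4, 0.3])
M = np.array([[1.0, 2.0, 3.0],
              [2.0, 1.0, 4.0]])
m = 0.7   # mass to transport

# Dummy source point absorbs sum(b) - m, dummy target point absorbs sum(a) - m
a_ext = np.append(a, b.sum() - m)
b_ext = np.append(b, a.sum() - m)

# Zero cost to/from the dummies, large cost between the two dummies
M_ext = np.zeros((len(a) + 1, len(b) + 1))
M_ext[:len(a), :len(b)] = M
M_ext[-1, -1] = 1e5

plan_ext = ot.emd(a_ext, b_ext, M_ext)
plan = plan_ext[:len(a), :len(b)]              # discard the dummy row/column
print(f"Transported mass: {plan.sum():.3f}")   # equals m
```
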
## Usage Examples

### Basic Partial Transport

```python
import ot
import numpy as np

# Create distributions with different total masses
np.random.seed(42)
a = np.array([0.6, 0.4])         # Source distribution (total mass 1.0)
b = np.array([0.3, 0.3, 0.2])    # Target distribution (total mass 0.8)

# Cost matrix
M = np.array([[1.0, 2.0, 3.0],
              [2.0, 1.0, 4.0]])

# Standard transport would require normalizing the distributions;
# partial transport handles different total masses directly.

# Transport 0.7 units of mass (out of min(sum(a), sum(b)) = 0.8 available)
m = 0.7

# Compute the partial transport plan and cost
plan_partial, log_partial = ot.partial.partial_wasserstein(a, b, M, m=m, log=True)
cost_partial = ot.partial.partial_wasserstein2(a, b, M, m=m)

print(f"Partial transport plan:\n{plan_partial}")
print(f"Transported mass: {np.sum(plan_partial):.3f}")
print(f"Partial transport cost: {cost_partial:.3f}")
print(f"Optimization info: {log_partial}")
```

### Comparison with Full Transport

```python
# Compare partial vs. full transport for increasing amounts of transported mass
# (m must stay within (0, min(sum(a), sum(b))] = (0, 0.8])
masses_to_test = [0.5, 0.6, 0.7, min(np.sum(a), np.sum(b))]

print("Transported mass -> Partial cost")
for m in masses_to_test:
    cost = ot.partial.partial_wasserstein2(a, b, M, m=m)
    print(f"{m:.1f} -> {cost:.4f}")

# Full transport for comparison (after normalization)
a_norm = a / np.sum(a)
b_norm = b / np.sum(b)
cost_full = ot.emd2(a_norm, b_norm, M)
print(f"Full transport (normalized): {cost_full:.4f}")
```

### Entropic Partial Transport

```python
# Add entropic regularization to partial transport
reg = 0.1   # Regularization parameter
m = 0.8     # Mass to transport

# Entropic partial transport
plan_entropic, log_entropic = ot.partial.entropic_partial_wasserstein(
    a, b, M, reg, m=m, verbose=True, log=True
)

print(f"Entropic partial plan:\n{plan_entropic}")
print(f"Convergence checks logged: {len(log_entropic['err'])}")
print(f"Final marginal error: {log_entropic['err'][-1]:.2e}")
```

### Outlier Detection Example

```python
# Demonstrate outlier robustness:
# create point clouds where one source sample is an outlier

n_source, n_target = 50, 60
X_source = np.random.randn(n_source, 2)
X_target = np.random.randn(n_target, 2) + [1, 1]

# Add an outlier to the source
X_source[-1] = [5, 5]  # Far outlier

# Cost matrix
M_outlier = ot.dist(X_source, X_target)

# Uniform distributions
a_unif = ot.unif(n_source)
b_unif = ot.unif(n_target)

# Full transport (must transport the outlier)
plan_full = ot.emd(a_unif, b_unif, M_outlier)
cost_full = ot.emd2(a_unif, b_unif, M_outlier)

# Partial transport (can ignore the outlier)
m_partial = 0.9  # Transport 90% of the mass
plan_partial = ot.partial.partial_wasserstein(a_unif, b_unif, M_outlier, m=m_partial)
cost_partial = ot.partial.partial_wasserstein2(a_unif, b_unif, M_outlier, m=m_partial)

print(f"Full transport cost: {cost_full:.3f}")
print(f"Partial transport cost: {cost_partial:.3f}")
print(f"Cost reduction: {(cost_full - cost_partial) / cost_full * 100:.1f}%")

# Check how much of the outlier's mass is transported
outlier_transported = np.sum(plan_partial[-1, :])  # Last source sample
print(f"Outlier transported mass: {outlier_transported:.3f}")
```

### Partial Gromov-Wasserstein

```python
# Structured partial transport example
n1, n2 = 20, 25

# Create structure matrices (e.g., pairwise distance matrices)
X1 = np.random.randn(n1, 3)
X2 = np.random.randn(n2, 3)
C1 = ot.dist(X1)
C2 = ot.dist(X2)

# Distributions
p = ot.unif(n1)
q = ot.unif(n2)

# Partial mass
m_gw = 0.8

# Partial Gromov-Wasserstein
plan_pgw, log_pgw = ot.partial.partial_gromov_wasserstein(
    C1, C2, p, q, m=m_gw, verbose=True, log=True
)
cost_pgw = ot.partial.partial_gromov_wasserstein2(C1, C2, p, q, m=m_gw)

print(f"Partial GW cost: {cost_pgw:.4f}")
print(f"PGW transported mass: {np.sum(plan_pgw):.3f}")

# Entropic version
reg_gw = 0.05  # increase reg_gw if you hit numerical warnings
plan_epgw = ot.partial.entropic_partial_gromov_wasserstein(
    C1, C2, p, q, reg_gw, m=m_gw, verbose=True
)

print("Entropic PGW completed")
```

### Mass Selection Strategy

```python
# Explore different mass levels systematically
mass_fractions = np.linspace(0.1, 1.0, 10)
costs = []

for frac in mass_fractions:
    m = frac * min(np.sum(a), np.sum(b))
    cost = ot.partial.partial_wasserstein2(a, b, M, m=m)
    costs.append(cost)

print("Mass fraction -> Cost")
for frac, cost in zip(mass_fractions, costs):
    print(f"{frac:.1f} -> {cost:.4f}")

# Heuristic: pick the mass level after which adding more mass
# increases the cost the least (a rough trade-off between mass and cost)
cost_gradients = np.diff(costs)
optimal_idx = np.argmin(cost_gradients) + 1
optimal_mass = mass_fractions[optimal_idx] * min(np.sum(a), np.sum(b))

print(f"Suggested mass to transport: {optimal_mass:.3f}")
```

### Robustness Analysis

```python
# Test robustness to noise in the cost matrix
noise_levels = [0.0, 0.1, 0.2, 0.5]
costs_full = []
costs_partial = []

for noise in noise_levels:
    # Add noise to the cost matrix (clip to keep costs non-negative)
    M_noisy = np.maximum(M + noise * np.random.randn(*M.shape), 0)

    # Full transport
    cost_full_noisy = ot.emd2(a / np.sum(a), b / np.sum(b), M_noisy)
    costs_full.append(cost_full_noisy)

    # Partial transport (0.7 units of mass)
    cost_partial_noisy = ot.partial.partial_wasserstein2(a, b, M_noisy, m=0.7)
    costs_partial.append(cost_partial_noisy)

print("Noise level -> Full cost | Partial cost")
for noise, cost_f, cost_p in zip(noise_levels, costs_full, costs_partial):
    print(f"{noise:.1f} -> {cost_f:.4f} | {cost_p:.4f}")
```

### Large-Scale Efficiency

```python
import time

# Compare computational efficiency for larger problems
sizes = [50, 100, 200]

for n in sizes:
    # Generate a larger problem
    X_large_s = np.random.randn(n, 5)
    X_large_t = np.random.randn(n, 5)
    M_large = ot.dist(X_large_s, X_large_t)
    M_large /= M_large.max()  # Normalize costs to avoid numerical issues with small reg
    a_large = ot.unif(n)
    b_large = ot.unif(n)

    # Time exact partial transport
    tic = time.time()
    cost_partial_large = ot.partial.partial_wasserstein2(
        a_large, b_large, M_large, m=0.8
    )
    time_partial = time.time() - tic

    # Time entropic partial transport (often faster for large problems)
    tic = time.time()
    plan_entropic_large = ot.partial.entropic_partial_wasserstein(
        a_large, b_large, M_large, reg=0.1, m=0.8
    )
    time_entropic = time.time() - tic

    print(f"Size {n}x{n}: Partial={time_partial:.3f}s, Entropic={time_entropic:.3f}s")
```

## Applications

### Computer Vision
- **Object matching with occlusion**: Handle partially visible objects
- **Image registration**: Align images with different fields of view
- **Shape matching**: Compare shapes with missing parts

### Machine Learning
- **Outlier-robust clustering**: Cluster data while ignoring outliers
- **Domain adaptation**: Handle distributions with different supports
- **Few-shot learning**: Match with limited target samples

### Bioinformatics
- **Cell matching**: Compare cell populations with different sizes
- **Sequence alignment**: Allow for insertions/deletions in sequences
- **Drug discovery**: Match molecular libraries of different sizes

The `ot.partial` module provides powerful tools for robust optimal transport, enabling practical applications where perfect mass conservation is neither required nor desired.