# Gromov-Wasserstein Distances

The `ot.gromov` module provides algorithms for computing Gromov-Wasserstein (GW) distances and their variants. GW does not use a cost between points of two different spaces. Instead, it compares the pairwise relations inside each space, for example distance or adjacency matrices. This makes optimal transport possible between structured objects, such as graphs or point clouds, that do not share a common embedding space.

## Core Gromov-Wasserstein Functions

### Basic GW Distance Computation

```python { .api }
def ot.gromov.gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', armijo=False, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    """
    Compute the Gromov-Wasserstein transport plan between two metric-measure spaces.

    Solves a non-convex quadratic program (a relaxation of the quadratic
    assignment problem) with a conditional-gradient (Frank-Wolfe) solver.
    The plan T minimizes sum_{i,j,k,l} L(C1[i,k], C2[j,l]) * T[i,j] * T[k,l],
    so two points are matched when their pairwise relations agree. The
    problem is non-convex, so the result is a local optimum.

    Parameters:
    - C1: array-like, shape (n1, n1)
        Intra-structure cost matrix of the source space (distances or similarities).
    - C2: array-like, shape (n2, n2)
        Intra-structure cost matrix of the target space.
    - p: array-like, shape (n1,)
        Distribution over the source space. Nonnegative weights summing to 1.
    - q: array-like, shape (n2,)
        Distribution over the target space. Nonnegative weights summing to 1.
    - loss_fun: str, default='square_loss'
        Loss L used to compare pairwise relations: 'square_loss' or 'kl_loss'.
    - armijo: bool, default=False
        Use an Armijo line search instead of the exact line search for the step size.
    - log: bool, default=False
        Also return an optimization log with convergence details.
    - max_iter: int, default=1000
        Maximum number of iterations.
    - tol_rel: float, default=1e-9
        Relative tolerance on the objective for convergence.
    - tol_abs: float, default=1e-9
        Absolute tolerance on the objective for convergence.

    Returns:
    - transport_plan: ndarray, shape (n1, n2)
        GW transport plan between the two spaces.
    - log: dict (if log=True)
        Contains 'gw_dist' (the GW loss at the returned plan) and
        per-iteration convergence information.
    """

def ot.gromov.gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', armijo=False, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    """
    Compute the Gromov-Wasserstein loss (value only).

    Runs the same solver as gromov_wasserstein() but returns the objective
    value at the final plan instead of the plan itself. With 'square_loss',
    this value is the squared-loss GW objective; no square root is taken.
    With a differentiable backend (e.g. PyTorch), the value carries gradients
    with respect to C1, C2, p and q.

    Parameters: Same as gromov_wasserstein()

    Returns:
    - gw_distance: float
        GW loss between the two spaces.
    - log: dict (if log=True)
    """

def ot.gromov.solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg, alpha_min=None, alpha_max=None, nx=None, symmetric=False, **kwargs):
    """
    Exact line search used inside the conditional-gradient GW and FGW solvers.

    Along a descent direction deltaG, the GW/FGW objective is a quadratic
    polynomial in the step size. The best step in [alpha_min, alpha_max]
    therefore has a closed form. This function is a building block of the
    solvers above, not a standalone GW solver.

    Parameters:
    - G: array-like, shape (n1, n2)
        Current transport plan.
    - deltaG: array-like, shape (n1, n2)
        Descent direction (new candidate plan minus G).
    - cost_G: float
        Objective value at G.
    - C1, C2: array-like
        Structure matrices of the source and target spaces.
    - M: array-like, shape (n1, n2)
        Linear (feature) cost term; 0 for pure GW.
    - reg: float
        Weight of the quadratic (GW) term.
    - alpha_min, alpha_max: float, optional
        Bounds on the step size.
    - nx: backend, optional
        Array backend; inferred from the inputs if None.
    - symmetric: bool, default=False
        Whether C1 and C2 are symmetric.

    Returns:
    - alpha: float
        Optimal step size.
    - fc: int
        Number of function evaluations.
    - cost_G: float
        Objective value at the updated plan G + alpha * deltaG.
    """
```
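
To make the objective concrete, here is a small NumPy sketch that evaluates the square-loss GW objective of a given plan by brute force. It is not POT's implementation. The sketch relabels a point cloud and checks that the matching plan costs nothing, which is why GW can compare spaces without a shared embedding. The helper `gw_objective` is a hypothetical name used only for illustration.

```python
import numpy as np


def gw_objective(C1, C2, T):
    """Square-loss GW objective: sum_{i,j,k,l} (C1[i,k] - C2[j,l])**2 * T[i,j] * T[k,l].

    Brute force with O(n1**2 * n2**2) memory and time, so only for small examples.
    """
    # L[i, k, j, l] = (C1[i, k] - C2[j, l]) ** 2
    L = (C1[:, :, None, None] - C2[None, None, :, :]) ** 2
    return float(np.einsum("ikjl,ij,kl->", L, T, T))


rng = np.random.default_rng(0)
n = 6
X = rng.random((n, 2))
C1 = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)  # pairwise distances

# Relabel the same points: the target space is an isometric copy of the source.
perm = rng.permutation(n)
C2 = C1[np.ix_(perm, perm)]
p = q = np.full(n, 1.0 / n)

# Plan sending source point perm[j] to target point j, each with mass 1/n.
T_perm = np.zeros((n, n))
T_perm[perm, np.arange(n)] = 1.0 / n

# Independent (product) plan: feasible, but it ignores the structure.
T_indep = np.outer(p, q)

print(gw_objective(C1, C2, T_perm))   # 0.0
print(gw_objective(C1, C2, T_indep))  # strictly positive
```

A solver such as `gromov_wasserstein()` searches over all plans with marginals `p` and `q` for one with a low value of this objective.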

### Fused Gromov-Wasserstein

```python { .api }
def ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    """
    Compute the Fused Gromov-Wasserstein (FGW) transport plan, combining features and structure.

    Minimizes (1 - α) * <M, T> + α * GW(C1, C2, T). The first term is a
    Wasserstein term on the feature cost M; the second is a Gromov-Wasserstein
    term on the structure costs C1 and C2.

    Parameters:
    - M: array-like, shape (n1, n2)
        Feature cost matrix between source and target samples.
    - C1: array-like, shape (n1, n1)
        Intra-structure cost matrix of the source space.
    - C2: array-like, shape (n2, n2)
        Intra-structure cost matrix of the target space.
    - p: array-like, shape (n1,)
        Source distribution.
    - q: array-like, shape (n2,)
        Target distribution.
    - loss_fun: str, default='square_loss'
        Loss used in the structure term: 'square_loss' or 'kl_loss'.
    - alpha: float in [0, 1], default=0.5
        Trade-off between structure (α) and features (1 - α).
        α=1 gives pure GW, α=0 gives pure Wasserstein.
    - armijo: bool, default=False
        Use an Armijo line search instead of the exact line search.
    - log: bool, default=False
    - max_iter: int, default=1000
    - tol_rel: float, default=1e-9
    - tol_abs: float, default=1e-9

    Returns:
    - transport_plan: ndarray, shape (n1, n2)
        FGW transport plan.
    - log: dict (if log=True)
    """

def ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    """
    Compute the Fused Gromov-Wasserstein loss (value only).

    Parameters: Same as fused_gromov_wasserstein()

    Returns:
    - fgw_distance: float
        FGW objective value at the returned plan.
    - log: dict (if log=True)
    """
```
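
For a fixed plan, the FGW objective is a convex combination of the linear feature cost and the GW term. The NumPy sketch below evaluates it for a few values of alpha and shows the two limiting cases described above. The helper names `gw_term` and `fgw_objective` are hypothetical. The feature cost is the squared Euclidean distance, which is also what `ot.dist` computes by default.

```python
import numpy as np


def gw_term(C1, C2, T):
    """Square-loss GW term for a fixed plan, by brute force (small inputs only)."""
    L = (C1[:, :, None, None] - C2[None, None, :, :]) ** 2
    return float(np.einsum("ikjl,ij,kl->", L, T, T))


def fgw_objective(M, C1, C2, T, alpha):
    """FGW objective (1 - alpha) * <M, T> + alpha * GW term for a fixed plan T."""
    return (1 - alpha) * float(np.sum(M * T)) + alpha * gw_term(C1, C2, T)


rng = np.random.default_rng(1)
n1, n2, d = 4, 5, 3
X1 = rng.random((n1, d))
X2 = rng.random((n2, d))
# Squared Euclidean feature cost between samples of the two spaces.
M = ((X1[:, None, :] - X2[None, :, :]) ** 2).sum(axis=-1)
C1 = rng.random((n1, n1))
C1 = (C1 + C1.T) / 2
C2 = rng.random((n2, n2))
C2 = (C2 + C2.T) / 2
T = np.outer(np.full(n1, 1 / n1), np.full(n2, 1 / n2))

for alpha in (0.0, 0.5, 1.0):
    print(alpha, fgw_objective(M, C1, C2, T, alpha))
```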

## Barycenter Algorithms

```python { .api }
def ot.gromov.gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun='square_loss', max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None, **kwargs):
    """
    Compute a Gromov-Wasserstein barycenter of several metric-measure spaces.

    Finds the structure matrix C of shape (N, N) that minimizes
    sum_s lambdas[s] * GW(C, Cs[s]) for the fixed barycenter distribution p.
    Only C is optimized. The algorithm alternates between computing a GW plan
    to each input space and updating C in closed form.

    Parameters:
    - N: int
        Size of the barycenter space (number of points).
    - Cs: list of arrays
        Intra-structure cost matrices of the input spaces.
        Each Cs[s] has shape (ns, ns).
    - ps: list of arrays
        Distributions of the input spaces.
        Each ps[s] has shape (ns,).
    - p: array-like, shape (N,)
        Distribution of the barycenter space (kept fixed).
    - lambdas: array-like, shape (n_spaces,)
        Nonnegative weights of the input spaces, typically summing to 1.
    - loss_fun: str, default='square_loss'
        'square_loss' or 'kl_loss'.
    - max_iter: int, default=1000
        Maximum number of barycenter updates.
    - tol: float, default=1e-9
        Convergence tolerance.
    - verbose: bool, default=False
        Print optimization information.
    - log: bool, default=False
        Also return an optimization log.
    - init_C: array-like, shape (N, N), optional
        Initial barycenter structure matrix. Random if None.
    - random_state: int, optional
        Random seed for reproducible initialization.

    Returns:
    - barycenter_structure: ndarray, shape (N, N)
        Barycenter intra-structure cost matrix.
    - log: dict (if log=True)
        Contains convergence information and transport plans.
    """

def ot.gromov.fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False, p=None, loss_fun='square_loss', max_iter=100, tol=1e-9, verbose=False, log=False, init_C=None, init_X=None, random_state=None, **kwargs):
    """
    Compute a Fused Gromov-Wasserstein barycenter (features and structure).

    Finds barycenter features X of shape (N, d) and structure C of shape (N, N)
    that minimize the lambda-weighted sum of FGW losses to the input spaces.

    Parameters:
    - N: int
        Barycenter size.
    - Ys: list of arrays
        Feature matrices of the input spaces; Ys[s] has shape (ns, d).
    - Cs: list of arrays
        Structure matrices of the input spaces; Cs[s] has shape (ns, ns).
    - ps: list of arrays
        Distributions of the input spaces.
    - lambdas: array-like
        Weights of the input spaces.
    - alpha: float
        Trade-off between structure (alpha) and features (1 - alpha).
    - fixed_structure: bool, default=False
        Keep the barycenter structure fixed to init_C and update only the features.
    - fixed_features: bool, default=False
        Keep the barycenter features fixed to init_X and update only the structure.
    - p: array-like, shape (N,), optional
        Barycenter distribution; uniform if None.
    - loss_fun: str, default='square_loss'
    - max_iter: int, default=100
    - tol: float, default=1e-9
    - verbose: bool, default=False
    - log: bool, default=False
    - init_C: array-like, shape (N, N), optional
        Initial barycenter structure.
    - init_X: array-like, shape (N, d), optional
        Initial barycenter features.
    - random_state: int, optional

    Returns:
    - barycenter_features: ndarray, shape (N, d)
    - barycenter_structure: ndarray, shape (N, N)
    - log: dict (if log=True)
    """
```
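
For the square loss, the structure update inside the barycenter loop has a closed form. With the plans fixed, the objective is a convex quadratic in the barycenter matrix, and setting its gradient to zero gives C = sum_s lambdas[s] * Ts[s].T @ Cs[s] @ Ts[s] / (p p^T). Below is a NumPy sketch of that update, checked for optimality. It is an illustration, not POT's code. The helper names and the (n_s, N) orientation of the plans are assumptions.

```python
import numpy as np


def update_structure_square_loss(Ts, Cs, lambdas, p):
    """Closed-form barycenter structure for the square loss.

    Assumes Ts[s] has shape (n_s, N): rows index input space s, columns index
    the barycenter, and the column sums equal p.
    """
    lambdas = np.asarray(lambdas, dtype=float)
    weighted = sum(lam * T.T @ C @ T for lam, T, C in zip(lambdas, Ts, Cs))
    return weighted / (np.outer(p, p) * lambdas.sum())


def barycenter_objective(C, Ts, Cs, lambdas):
    """sum_s lambdas[s] * sum_{j,l,i,k} (Cs[s][j,l] - C[i,k])**2 * Ts[s][j,i] * Ts[s][l,k]."""
    total = 0.0
    for lam, T, C_s in zip(lambdas, Ts, Cs):
        L = (C_s[:, :, None, None] - C[None, None, :, :]) ** 2  # L[j, l, i, k]
        total += lam * np.einsum("jlik,ji,lk->", L, T, T)
    return float(total)


rng = np.random.default_rng(2)
N = 4
p = np.full(N, 1 / N)

C_a = rng.random((N, N))
C_a = (C_a + C_a.T) / 2
C_b = rng.random((6, 6))
C_b = (C_b + C_b.T) / 2
Cs = [C_a, C_b]
# Couplings whose column sums are p: identity-like for C_a, product coupling for C_b.
Ts = [np.diag(p), np.outer(np.full(6, 1 / 6), p)]
lambdas = [0.5, 0.5]

C_bar = update_structure_square_loss(Ts, Cs, lambdas, p)
print(C_bar.shape)  # (4, 4)
```

A full barycenter solver would alternate this update with recomputing each plan by a GW solve.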

## Entropic Regularized Methods

```python { .api }
def ot.gromov.entropic_gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', epsilon=0.1, symmetric=None, G0=None, max_iter=1000, tol=1e-9, verbose=False, log=False, **kwargs):
    """
    Compute the entropic regularized Gromov-Wasserstein transport plan.

    Adds an entropic penalty, weighted by epsilon, to the GW objective. Each
    iteration linearizes the GW term at the current plan and solves the
    resulting entropic OT problem with Sinkhorn iterations. The result is a
    dense, smooth plan and a differentiable objective.

    Parameters:
    - C1: array-like, shape (n1, n1)
        Source structure matrix.
    - C2: array-like, shape (n2, n2)
        Target structure matrix.
    - p: array-like, shape (n1,)
        Source distribution.
    - q: array-like, shape (n2,)
        Target distribution.
    - loss_fun: str, default='square_loss'
    - epsilon: float, default=0.1
        Entropic regularization strength; larger values give more diffuse plans.
    - symmetric: bool, optional
        Whether C1 and C2 are symmetric; tested from the inputs if None.
    - G0: array-like, shape (n1, n2), optional
        Initial transport plan.
    - max_iter: int, default=1000
    - tol: float, default=1e-9
    - verbose: bool, default=False
    - log: bool, default=False

    Returns:
    - transport_plan: ndarray, shape (n1, n2)
    - log: dict (if log=True)
    """

def ot.gromov.entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', epsilon=0.1, symmetric=None, G0=None, max_iter=1000, tol=1e-9, verbose=False, log=False, **kwargs):
    """
    Compute the entropic regularized GW loss (value only).

    Parameters: Same as entropic_gromov_wasserstein()

    Returns:
    - gw_distance: float
    - log: dict (if log=True)
    """

def ot.gromov.entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun='square_loss', epsilon=0.1, symmetric=True, max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None):
    """
    Compute an entropic regularized GW barycenter.

    Entropic counterpart of gromov_barycenters(): the plans to the input
    spaces are entropic GW plans.

    Parameters:
    - N: int
        Barycenter size.
    - Cs: list of arrays
        Structure matrices of the input spaces.
    - ps: list of arrays
        Distributions of the input spaces.
    - p: array-like, shape (N,)
        Barycenter distribution.
    - lambdas: array-like
        Weights of the input spaces.
    - loss_fun: str, default='square_loss'
    - epsilon: float, default=0.1
        Entropic regularization strength.
    - symmetric: bool, default=True
        Whether the structure matrices are symmetric.
    - max_iter: int, default=1000
    - tol: float, default=1e-9
    - verbose: bool, default=False
    - log: bool, default=False
    - init_C: array-like, shape (N, N), optional
    - random_state: int, optional

    Returns:
    - barycenter_structure: ndarray, shape (N, N)
    - log: dict (if log=True)
    """

def ot.gromov.entropic_fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, G0=None, max_iter=1000, tol=1e-9, verbose=False, log=False):
    """
    Compute the entropic regularized Fused GW transport plan.

    Parameters:
    - M: array-like, shape (n1, n2)
        Feature cost matrix.
    - C1, C2: array-like
        Structure matrices.
    - p, q: array-like
        Distributions.
    - loss_fun: str, default='square_loss'
    - epsilon: float, default=0.1
        Entropic regularization strength.
    - alpha: float, default=0.5
        Trade-off between structure (alpha) and features (1 - alpha).
    - symmetric: bool, optional
    - G0: array-like, optional
        Initial transport plan.
    - max_iter: int, default=1000
    - tol: float, default=1e-9
    - verbose: bool, default=False
    - log: bool, default=False

    Returns:
    - transport_plan: ndarray, shape (n1, n2)
    - log: dict (if log=True)
    """

def ot.gromov.entropic_fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, G0=None, max_iter=1000, tol=1e-9, verbose=False, log=False):
    """
    Compute the entropic regularized FGW loss (value only).

    Parameters: Same as entropic_fused_gromov_wasserstein()

    Returns:
    - fgw_distance: float
    - log: dict (if log=True)
    """

def ot.gromov.entropic_fused_gromov_barycenters(N, Ys, Cs, ps, lambdas, alpha, epsilon=0.1, fixed_structure=False, fixed_features=False, p=None, loss_fun='square_loss', max_iter=100, tol=1e-9, verbose=False, log=False, init_C=None, init_X=None, random_state=None):
    """
    Compute an entropic regularized FGW barycenter.

    Parameters: Same as fgw_barycenters(), plus epsilon (entropic regularization strength).

    Returns:
    - barycenter_features: ndarray, shape (N, d)
    - barycenter_structure: ndarray, shape (N, N)
    - log: dict (if log=True)
    """
```
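
The entropic solvers replace each linear subproblem of the GW solver with a Sinkhorn problem. Below is a minimal NumPy sketch of that scheme for the square loss, following the iteration of Peyré, Cuturi and Solomon (2016): linearize, then solve an entropic OT problem. It is not POT's implementation. All helper names are hypothetical, and the scaling of `epsilon` may not match the library's.

```python
import numpy as np


def logsumexp(A, axis):
    """Numerically stable log(sum(exp(A))) along one axis."""
    A_max = A.max(axis=axis, keepdims=True)
    return np.squeeze(A_max, axis=axis) + np.log(np.exp(A - A_max).sum(axis=axis))


def sinkhorn_log(p, q, cost, epsilon, n_iter=500):
    """Log-domain Sinkhorn: entropic OT plan between p and q for the given cost."""
    f = np.zeros(len(p))
    g = np.zeros(len(q))
    for _ in range(n_iter):
        # Alternately rescale so that rows sum to p, then columns sum to q.
        f = epsilon * np.log(p) - epsilon * logsumexp((g[None, :] - cost) / epsilon, axis=1)
        g = epsilon * np.log(q) - epsilon * logsumexp((f[:, None] - cost) / epsilon, axis=0)
    return np.exp((f[:, None] + g[None, :] - cost) / epsilon)


def entropic_gw_sketch(C1, C2, p, q, epsilon=0.1, n_outer=20):
    """Entropic GW (square loss): repeatedly linearize the GW term, then run Sinkhorn."""
    # Square-loss decomposition; constC is exact for plans with marginals p and q.
    constC = ((C1 ** 2) @ p)[:, None] + ((C2 ** 2) @ q)[None, :]
    T = np.outer(p, q)
    for _ in range(n_outer):
        # tens[i, j] = sum_{k,l} (C1[i,k] - C2[j,l])**2 * T[k,l]
        tens = constC - 2 * C1 @ T @ C2.T
        T = sinkhorn_log(p, q, tens, epsilon)
    return T


rng = np.random.default_rng(3)
X1 = rng.random((8, 2))
X2 = rng.random((10, 2))
C1 = np.linalg.norm(X1[:, None, :] - X1[None, :, :], axis=-1)
C2 = np.linalg.norm(X2[:, None, :] - X2[None, :, :], axis=-1)
p = np.full(8, 1 / 8)
q = np.full(10, 1 / 10)

T = entropic_gw_sketch(C1, C2, p, q)
print(T.sum(axis=1))  # close to p
print(T.sum(axis=0))  # matches q
```

The log-domain updates avoid the underflow that a plain `exp(-cost / epsilon)` kernel can hit when epsilon is small.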

## Semi-relaxed Methods

```python { .api }
def ot.gromov.semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=None, G0=None, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    """
    Compute the semi-relaxed Gromov-Wasserstein transport plan.

    Only the source marginal is enforced (T @ 1 = p). The target marginal
    T.T @ 1 is left free and is optimized together with the plan. No target
    distribution is needed, so the function takes no q.

    Parameters:
    - C1: array-like, shape (n1, n1)
        Source structure matrix.
    - C2: array-like, shape (n2, n2)
        Target structure matrix.
    - p: array-like, shape (n1,)
        Source distribution, enforced as the row marginal of the plan.
    - loss_fun: str, default='square_loss'
    - symmetric: bool, optional
        Whether C1 and C2 are symmetric; tested from the inputs if None.
    - G0: array-like, shape (n1, n2), optional
        Initial transport plan.
    - log: bool, default=False
    - max_iter: int, default=1000
    - tol_rel: float, default=1e-9
    - tol_abs: float, default=1e-9

    Returns:
    - transport_plan: ndarray, shape (n1, n2)
        Semi-relaxed transport plan; its column sums give the induced target distribution.
    - log: dict (if log=True)
    """

def ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=None, G0=None, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    """
    Compute the semi-relaxed GW loss (value only).

    Parameters: Same as semirelaxed_gromov_wasserstein()

    Returns:
    - sr_gw_distance: float
        Semi-relaxed GW loss at the returned plan.
    - log: dict (if log=True)
    """

def ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', symmetric=None, alpha=0.5, G0=None, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    """
    Compute the semi-relaxed Fused GW transport plan.

    Parameters:
    - M: array-like, shape (n1, n2)
        Feature cost matrix.
    - alpha: float, default=0.5
        Trade-off between structure (alpha) and features (1 - alpha).
    - Other parameters: same as semirelaxed_gromov_wasserstein()

    Returns:
    - transport_plan: ndarray, shape (n1, n2)
    - log: dict (if log=True)
    """

def ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', symmetric=None, alpha=0.5, G0=None, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    """
    Compute the semi-relaxed FGW loss (value only).

    Parameters: Same as semirelaxed_fused_gromov_wasserstein()

    Returns:
    - sr_fgw_distance: float
    - log: dict (if log=True)
    """
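
def ot.gromov.solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p, M, reg, fC2t=None, alpha_min=None, alpha_max=None, nx=None, **kwargs):
    """
    Exact line search used inside the conditional-gradient semi-relaxed GW and FGW solvers.

    As in the balanced case, the objective is quadratic along the direction
    deltaG, so the optimal step size has a closed form.

    Parameters:
    - G: array-like, shape (n1, n2)
        Current transport plan.
    - deltaG: array-like, shape (n1, n2)
        Descent direction.
    - cost_G: float
        Objective value at G.
    - C1, C2: array-like
        Structure matrices of the source and target spaces.
    - ones_p: array-like, shape (n1,)
        Vector of ones.
    - M: array-like, shape (n1, n2)
        Linear (feature) cost term; 0 for pure semi-relaxed GW.
    - reg: float
        Weight of the quadratic (GW) term.
    - fC2t: array-like, optional
        Optional precomputed transform of C2.
    - alpha_min, alpha_max: float, optional
        Bounds on the step size.
    - nx: backend, optional

    Returns:
    - alpha: float
        Optimal step size.
    - fc: int
        Number of function evaluations.
    - cost_G: float
        Objective value at the updated plan.
    """
```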
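
Only T @ 1 = p is enforced here. As a result, the linear subproblem that a conditional-gradient solver must solve at each iteration becomes trivial: it decouples row by row, and each row puts all its mass on its cheapest column. The sketch below illustrates that step in NumPy. `semirelaxed_direction` is a hypothetical helper, not POT's code, and ties are broken by taking the first minimizing column.

```python
import numpy as np


def semirelaxed_direction(grad, p):
    """Minimize <grad, X> over X >= 0 with X.sum(axis=1) == p (no column constraint).

    The problem separates over rows: row i puts all of its mass p[i] on the
    column with the smallest gradient entry.
    """
    X = np.zeros_like(grad, dtype=float)
    X[np.arange(len(p)), np.argmin(grad, axis=1)] = p
    return X


rng = np.random.default_rng(4)
grad = rng.random((5, 3))
p = np.full(5, 0.2)
X = semirelaxed_direction(grad, p)
print(X.sum(axis=1))  # equals p
print(X.sum(axis=0))  # free: wherever the gradient is smallest
```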

## Partial Methods

```python { .api }
def ot.gromov.partial_gromov_wasserstein(C1, C2, p, q, m=None, loss_fun='square_loss', armijo=False, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    """
    Compute the partial Gromov-Wasserstein transport plan.

    Transports only a total mass m <= min(sum(p), sum(q)). The plan satisfies
    T @ 1 <= p, T.T @ 1 <= q and sum(T) = m. This is useful when the two
    measures have different total masses, or when some points (such as
    outliers) should be left unmatched.

    Parameters:
    - C1: array-like, shape (n1, n1)
        Source structure matrix.
    - C2: array-like, shape (n2, n2)
        Target structure matrix.
    - p: array-like, shape (n1,)
        Source distribution.
    - q: array-like, shape (n2,)
        Target distribution.
    - m: float, optional
        Total mass to transport, 0 < m <= min(sum(p), sum(q)).
        Defaults to min(sum(p), sum(q)).
    - loss_fun: str, default='square_loss'
    - armijo: bool, default=False
    - log: bool, default=False
    - max_iter: int, default=1000
    - tol_rel: float, default=1e-9
    - tol_abs: float, default=1e-9

    Returns:
    - transport_plan: ndarray, shape (n1, n2)
        Partial GW transport plan.
    - log: dict (if log=True)
    """

def ot.gromov.partial_gromov_wasserstein2(C1, C2, p, q, m=None, loss_fun='square_loss', armijo=False, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    """
    Compute the partial GW loss (value only).

    Parameters: Same as partial_gromov_wasserstein()

    Returns:
    - partial_gw_distance: float
    - log: dict (if log=True)
    """

def ot.gromov.entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, loss_fun='square_loss', G0=None, max_iter=1000, tol=1e-9, verbose=False, log=False):
    """
    Compute the entropic regularized partial GW transport plan.

    Parameters:
    - C1, C2: array-like
        Structure matrices.
    - p, q: array-like
        Distributions.
    - reg: float
        Entropic regularization strength.
    - m: float, optional
        Total mass to transport, as in partial_gromov_wasserstein().
    - loss_fun: str, default='square_loss'
    - G0: array-like, optional
        Initial transport plan.
    - max_iter: int, default=1000
    - tol: float, default=1e-9
    - verbose: bool, default=False
    - log: bool, default=False

    Returns:
    - transport_plan: ndarray, shape (n1, n2)
    - log: dict (if log=True)
    """

def ot.gromov.entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, loss_fun='square_loss', G0=None, max_iter=1000, tol=1e-9, verbose=False, log=False):
    """
    Compute the entropic regularized partial GW loss (value only).

    Parameters: Same as entropic_partial_gromov_wasserstein()

    Returns:
    - partial_gw_distance: float
    - log: dict (if log=True)
    """
```
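
The partial-transport constraints can be checked directly on any candidate plan. The NumPy helper below encodes them: `is_partial_plan` is a hypothetical name used for illustration. It also works as a quick sanity check on the output of `partial_gromov_wasserstein()`.

```python
import numpy as np


def is_partial_plan(T, p, q, m, atol=1e-9):
    """Check T >= 0, T @ 1 <= p, T.T @ 1 <= q and sum(T) == m, up to atol."""
    return bool(
        np.all(T >= -atol)
        and np.all(T.sum(axis=1) <= p + atol)
        and np.all(T.sum(axis=0) <= q + atol)
        and abs(T.sum() - m) <= atol
    )


p = np.full(4, 0.25)  # total mass 1
q = np.full(5, 0.2)   # total mass 1
m = 0.7

T = m * np.outer(p, q)  # scaled product plan: moves exactly m
print(is_partial_plan(T, p, q, m))               # True
print(is_partial_plan(np.outer(p, q), p, q, m))  # False: moves mass 1.0, not 0.7
```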

## Dictionary Learning Methods

```python { .api }
def ot.gromov.gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0.0, ps=None, q=None, epochs=20, batch_size=32, learning_rate=1.0, proj_sparse_regul=0.1, verbose=False, random_state=None, **kwargs):
    """
    Learn a dictionary of structures with GW distances.

    Learns D atoms, each a structure matrix of shape (nt, nt), such that every
    input structure is well approximated, in the GW sense, by a convex
    combination of the atoms. The atoms are updated by stochastic optimization
    over mini-batches of inputs.

    Parameters:
    - Cs: list of arrays
        Input structure matrices to learn from.
    - D: int
        Dictionary size (number of atoms).
    - nt: int
        Number of points in each atom.
    - reg: float, default=0.0
        Coefficient of the negative quadratic regularization that promotes sparse unmixing weights.
    - ps: list of arrays, optional
        Distributions of the input structures.
    - q: array-like, shape (nt,), optional
        Distribution shared by the dictionary atoms.
    - epochs: int, default=20
        Number of passes over the inputs.
    - batch_size: int, default=32
        Mini-batch size.
    - learning_rate: float, default=1.0
        Learning rate for the atom updates.
    - proj_sparse_regul: float, default=0.1
        Sparsity regularization for projections.
    - verbose: bool, default=False
    - random_state: int, optional

    Returns:
    - dictionary: D structure matrices, each of shape (nt, nt)
        Learned atoms.
    - log: dict
        Learning statistics and convergence information.
    """

def ot.gromov.gromov_wasserstein_linear_unmixing(C, Cdict, reg=0.0, p=None, q=None, tol_outer=1e-6, tol_inner=1e-6, max_iter_outer=20, max_iter_inner=200, verbose=False, **kwargs):
    """
    Express a structure as a combination of learned GW dictionary atoms.

    Finds weights w on the probability simplex such that
    sum_k w[k] * Cdict[k] is as close as possible, in the GW sense, to C.

    Parameters:
    - C: array-like, shape (n, n)
        Structure matrix to unmix.
    - Cdict: list of arrays
        Dictionary atoms, each of shape (nt, nt).
    - reg: float, default=0.0
        Coefficient of the negative quadratic regularization that promotes sparse weights.
    - p: array-like, shape (n,), optional
        Distribution of the input structure.
    - q: array-like, shape (nt,), optional
        Distribution of the dictionary atoms.
    - tol_outer: float, default=1e-6
        Outer loop tolerance.
    - tol_inner: float, default=1e-6
        Inner loop tolerance.
    - max_iter_outer: int, default=20
    - max_iter_inner: int, default=200
    - verbose: bool, default=False

    Returns:
    - coefficients: ndarray, shape (D,)
        Unmixing weights w of the dictionary atoms.
    - log: dict
        Unmixing optimization information.
    """

def ot.gromov.fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, reg=0.0, alpha=0.5, ps=None, q=None, epochs=20, batch_size=32, learning_rate=1.0, proj_sparse_regul=0.1, verbose=False, random_state=None):
    """
    Learn a dictionary for FGW (structure and features).

    Parameters: Same as gromov_wasserstein_dictionary_learning(), plus:
    - Ys: list of arrays
        Feature matrices of the input data.
    - alpha: float, default=0.5
        Trade-off between structure (alpha) and features (1 - alpha).

    Returns:
    - structure_dictionary: list of arrays
    - feature_dictionary: list of arrays
    - log: dict
    """

def ot.gromov.fused_gromov_wasserstein_linear_unmixing(C, Y, Cdict, Ydict, alpha=0.5, reg=0.0, p=None, q=None, tol_outer=1e-6, tol_inner=1e-6, max_iter_outer=20, max_iter_inner=200, verbose=False):
    """
    Unmix FGW data (structure and features) with a learned dictionary.

    Parameters: Same as gromov_wasserstein_linear_unmixing(), plus:
    - Y: array-like
        Feature matrix to unmix.
    - Ydict: list of arrays
        Feature dictionary atoms.
    - alpha: float, default=0.5
        Trade-off between structure (alpha) and features (1 - alpha).

    Returns:
    - coefficients: ndarray
    - log: dict
    """
```
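
In linear unmixing, the input structure is approximated by a convex combination of the atoms, sum_k w[k] * Cdict[k], with w on the probability simplex. A projected-gradient method for w needs a Euclidean projection onto the simplex. The NumPy sketch below shows that projection (the sort-based method of Duchi et al., 2008) and the resulting reconstruction. It illustrates the structure of the problem under these assumptions; it is not POT's solver, and `project_simplex` is a hypothetical helper.

```python
import numpy as np


def project_simplex(v):
    """Euclidean projection of v onto {w : w >= 0, sum(w) = 1} (sort-based method)."""
    u = np.sort(v)[::-1]
    css = np.cumsum(u)
    k = np.arange(1, len(v) + 1)
    # Largest index whose sorted entry stays positive after the shift.
    rho = np.nonzero(u * k > css - 1)[0][-1]
    theta = (css[rho] - 1) / (rho + 1)
    return np.maximum(v - theta, 0.0)


rng = np.random.default_rng(5)
D, nt = 3, 4
Cdict = rng.random((D, nt, nt))
Cdict = (Cdict + Cdict.transpose(0, 2, 1)) / 2  # symmetric atoms

w = project_simplex(np.array([0.9, 0.6, -0.3]))
C_hat = np.tensordot(w, Cdict, axes=1)  # sum_k w[k] * Cdict[k]
print(w)  # approximately [0.65, 0.35, 0.0]
```

Projection onto the simplex naturally produces exact zeros, which is one reason unmixing weights are often sparse.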

## Utility Functions

```python { .api }
def ot.gromov.init_matrix(C1, C2, p, q, loss_fun='square_loss'):
    """
    Precompute the terms used for fast GW loss and gradient evaluation.

    For losses that decompose as L(a, b) = f1(a) + f2(b) - h1(a) * h2(b)
    (true for 'square_loss' and 'kl_loss'), the tensor product L(C1, C2) ⊗ T
    reduces to matrix products (Peyré et al., 2016). This function returns
    the three terms of that decomposition; it does not build a transport plan.

    Parameters:
    - C1: array-like, shape (n1, n1)
    - C2: array-like, shape (n2, n2)
    - p: array-like, shape (n1,)
    - q: array-like, shape (n2,)
    - loss_fun: str, default='square_loss'
        'square_loss' or 'kl_loss'.

    Returns:
    - constC: ndarray, shape (n1, n2)
        Constant term with entries sum_k f1(C1[i,k]) p[k] + sum_l f2(C2[j,l]) q[l].
    - hC1: ndarray, shape (n1, n1)
        h1(C1).
    - hC2: ndarray, shape (n2, n2)
        h2(C2).
    """

def ot.gromov.tensor_product(constC, hC1, hC2, T):
    """
    Compute the tensor-matrix product L(C1, C2) ⊗ T.

    Its entries are sum_{k,l} L(C1[i,k], C2[j,l]) * T[k,l]. With the terms from
    init_matrix() this equals constC - hC1 @ T @ hC2.T, which is valid for
    plans whose marginals are p and q.

    Parameters:
    - constC: ndarray, shape (n1, n2)
        Constant term from init_matrix().
    - hC1: ndarray, shape (n1, n1)
        Source structure term from init_matrix().
    - hC2: ndarray, shape (n2, n2)
        Target structure term from init_matrix().
    - T: ndarray, shape (n1, n2)
        Current transport plan.

    Returns:
    - tensor_prod: ndarray, shape (n1, n2)
        Tensor product result.
    """

def ot.gromov.gwloss(constC, hC1, hC2, T):
    """
    Compute the GW loss sum_{i,j} (L(C1, C2) ⊗ T)[i,j] * T[i,j].

    Parameters:
    - constC: ndarray
    - hC1: ndarray
    - hC2: ndarray
    - T: ndarray
        Transport plan.

    Returns:
    - loss: float
        GW loss value.
    """

def ot.gromov.gwggrad(constC, hC1, hC2, T):
    """
    Compute the gradient of the GW loss with respect to T.

    For symmetric C1 and C2, the gradient is 2 * tensor_product(constC, hC1, hC2, T).

    Parameters:
    - constC: ndarray
    - hC1: ndarray
    - hC2: ndarray
    - T: ndarray

    Returns:
    - gradient: ndarray, shape (n1, n2)
        GW objective gradient.
    """

def ot.gromov.update_barycenter_structure(Ts, Cs, lambdas, p, loss_fun='square_loss'):
    """
    Closed-form update of the barycenter structure matrix for fixed plans.

    For 'square_loss', the update is the weighted average of the input
    structures transported onto the barycenter, divided entrywise by p p^T.

    Parameters:
    - Ts: list of arrays
        Transport plans between the barycenter and the input spaces.
    - Cs: list of arrays
        Input structure matrices.
    - lambdas: array-like
        Weights of the input spaces.
    - p: array-like
        Barycenter distribution.
    - loss_fun: str, default='square_loss'

    Returns:
    - C_barycenter: ndarray, shape (N, N)
        Updated barycenter structure.
    """

def ot.gromov.update_barycenter_feature(Ts, Ys, lambdas, p):
    """
    Closed-form update of the barycenter feature matrix for fixed plans.

    The update is the weighted average of the input features transported
    onto the barycenter points.

    Parameters:
    - Ts: list of arrays
        Transport plans.
    - Ys: list of arrays
        Input feature matrices.
    - lambdas: array-like
        Weights of the input spaces.
    - p: array-like
        Barycenter distribution.

    Returns:
    - Y_barycenter: ndarray, shape (N, d)
        Updated barycenter features.
    """
```
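
For the square loss, the decomposition behind `init_matrix()` comes from (a - b)^2 = a^2 + b^2 - 2ab. This turns an O(n1^2 n2^2) tensor product into a few matrix products. The NumPy sketch below reproduces the idea and checks it against brute force. The helper names are hypothetical, and POT's own functions may differ in details such as backend handling.

```python
import numpy as np


def init_square_loss_terms(C1, C2, p, q):
    """Return (constC, hC1, hC2) for L(a, b) = a**2 + b**2 - 2*a*b."""
    # constC[i, j] = sum_k C1[i,k]**2 p[k] + sum_l C2[j,l]**2 q[l]
    constC = ((C1 ** 2) @ p)[:, None] + ((C2 ** 2) @ q)[None, :]
    hC1 = C1
    hC2 = 2 * C2
    return constC, hC1, hC2


def tensor_product_sketch(constC, hC1, hC2, T):
    """Entries sum_{k,l} L(C1[i,k], C2[j,l]) * T[k,l]; valid when T has marginals p and q."""
    return constC - hC1 @ T @ hC2.T


def gw_loss_sketch(constC, hC1, hC2, T):
    """GW loss: sum_{i,j} (L x T)[i,j] * T[i,j]."""
    return float(np.sum(tensor_product_sketch(constC, hC1, hC2, T) * T))


rng = np.random.default_rng(6)
n1, n2 = 5, 7
C1 = rng.random((n1, n1))
C2 = rng.random((n2, n2))
p = np.full(n1, 1 / n1)
q = np.full(n2, 1 / n2)

# A non-product plan with marginals p and q: add a perturbation whose rows and columns sum to 0.
a = rng.random(n1) - 0.5
a -= a.mean()
b = rng.random(n2) - 0.5
b -= b.mean()
T = np.outer(p, q) + 0.01 * np.outer(a, b)

constC, hC1, hC2 = init_square_loss_terms(C1, C2, p, q)
fast = gw_loss_sketch(constC, hC1, hC2, T)

# Brute-force reference, O(n1**2 * n2**2).
L = (C1[:, :, None, None] - C2[None, None, :, :]) ** 2
brute = float(np.einsum("ikjl,ij,kl->", L, T, T))
print(fast, brute)  # equal up to floating-point rounding
```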

## Usage Examples

The snippets below build on each other. Later examples reuse `np`, `ot`, `rng`, `n1`, `n2`, `C1`, `C2`, `p` and `q` from the first one.

### Basic Gromov-Wasserstein

```python
import numpy as np
import ot

rng = np.random.default_rng(0)

# Create symmetric structure matrices (e.g., distance or similarity matrices)
n1, n2 = 10, 15
C1 = rng.random((n1, n1))
C1 = (C1 + C1.T) / 2  # Make symmetric
C2 = rng.random((n2, n2))
C2 = (C2 + C2.T) / 2

# Uniform distributions over the points of each space
p = ot.unif(n1)
q = ot.unif(n2)

# Transport plan, then the GW loss (verbose is passed to the solver through **kwargs)
gw_plan = ot.gromov.gromov_wasserstein(C1, C2, p, q, verbose=True)
gw_dist = ot.gromov.gromov_wasserstein2(C1, C2, p, q)

print(f"GW distance: {gw_dist}")
print(f"Transport plan shape: {gw_plan.shape}")
```

### Fused Gromov-Wasserstein

```python
# Feature cost matrix (ot.dist uses the squared Euclidean distance by default)
d = 3
X1 = rng.standard_normal((n1, d))
X2 = rng.standard_normal((n2, d))
M = ot.dist(X1, X2)

# Structure-feature trade-off
alpha = 0.7  # More weight on structure

# Compute FGW
fgw_plan = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, alpha=alpha)
fgw_dist = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, alpha=alpha)

print(f"FGW distance: {fgw_dist}")
```

### GW Barycenter

```python
# Several structures with 8 points each
n_spaces = 5
Cs = [rng.random((8, 8)) for _ in range(n_spaces)]
Cs = [(C + C.T) / 2 for C in Cs]  # Make symmetric

ps = [ot.unif(8) for _ in range(n_spaces)]
lambdas = ot.unif(n_spaces)  # Equal weights

# Barycenter parameters
N = 6  # Barycenter size
p_barycenter = ot.unif(N)

# Compute barycenter (only the structure matrix is returned when log=False)
C_barycenter = ot.gromov.gromov_barycenters(N, Cs, ps, p_barycenter, lambdas, verbose=True)

print(f"Barycenter structure shape: {C_barycenter.shape}")
```

### Entropic GW

```python
# Entropic regularization strength
epsilon = 0.05

# Compute entropic GW
egw_plan = ot.gromov.entropic_gromov_wasserstein(C1, C2, p, q, epsilon=epsilon)
egw_dist = ot.gromov.entropic_gromov_wasserstein2(C1, C2, p, q, epsilon=epsilon)

print(f"Entropic GW distance: {egw_dist}")
```

### Partial GW for Outlier Robustness

```python
# Transport only 70% of the mass (p and q each have total mass 1)
m = 0.7

# Compute partial GW
pgw_plan = ot.gromov.partial_gromov_wasserstein(C1, C2, p, q, m=m)
pgw_dist = ot.gromov.partial_gromov_wasserstein2(C1, C2, p, q, m=m)

print(f"Partial GW distance: {pgw_dist}")
print(f"Transported mass: {np.sum(pgw_plan)}")  # close to m
```

## Quantized and Sampling Methods

These methods target large problems. They reduce cost by partitioning the graphs (quantization), by constraining the transport plan to low rank, or by estimating the GW gradient from samples.

```python { .api }
def ot.gromov.quantized_fused_gromov_wasserstein(C1, C2, Y1, Y2, a=None, b=None, alpha=0.5, reg=0.1, num_node_class=8, **kwargs):
    """
    Solve quantized FGW, using graph partitioning to reduce the problem size.

    Parameters:
    - C1, C2: array-like, structure matrices
    - Y1, Y2: array-like, feature matrices
    - a, b: array-like, distributions
    - alpha: float, structure/feature trade-off
    - reg: float, regularization parameter
    - num_node_class: int, number of partitions

    Returns:
    - quantized transport plan
    """

def ot.gromov.lowrank_gromov_wasserstein_samples(X_s, X_t, a=None, b=None, reg=0.0, rank=10, numItermax=100, stopThr=1e-5, log=False):
    """
    Solve GW with a low-rank factorization of the transport plan, for large-scale problems.
    """

def ot.gromov.sampled_gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', nb_samples_grad=100, log=False, **kwargs):
    """
    Solve GW with stochastic gradient estimates computed from sampled index pairs.
    """

def ot.gromov.get_graph_partition(C1, num_node_class=8, part_method='louvain'):
    """
    Partition a graph into num_node_class parts for the quantized methods.
    """
```
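
The sampling approach rests on a simple fact. Each entry of the GW gradient is an expectation over pairs (k, l) drawn from the current plan, so it can be estimated from sampled pairs instead of a full O(n1^2 n2^2) sum. The NumPy sketch below illustrates that estimator with hypothetical helper names. It is not the exact scheme used by `sampled_gromov_wasserstein()`.

```python
import numpy as np


def exact_grad_term(C1, C2, T):
    """tens[i, j] = sum_{k,l} (C1[i,k] - C2[j,l])**2 * T[k,l], by brute force."""
    L = (C1[:, :, None, None] - C2[None, None, :, :]) ** 2
    return np.einsum("ikjl,kl->ij", L, T)


def sampled_grad_term(C1, C2, T, n_samples, rng):
    """Unbiased Monte Carlo estimate of the same matrix, with (k, l) drawn with probability T[k, l] / T.sum()."""
    flat = rng.choice(T.size, size=n_samples, p=(T / T.sum()).ravel())
    ks, ls = np.unravel_index(flat, T.shape)
    # diffs[i, j, s] = (C1[i, ks[s]] - C2[j, ls[s]])**2
    diffs = (C1[:, ks][:, None, :] - C2[:, ls][None, :, :]) ** 2
    return T.sum() * diffs.mean(axis=2)


rng = np.random.default_rng(7)
n1, n2 = 6, 8
C1 = rng.random((n1, n1))
C2 = rng.random((n2, n2))
T = rng.random((n1, n2))
T /= T.sum()

exact = exact_grad_term(C1, C2, T)
approx = sampled_grad_term(C1, C2, T, n_samples=20000, rng=rng)
print(np.abs(exact - approx).max())  # small; shrinks roughly like 1 / sqrt(n_samples)
```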

## Unbalanced Methods

Unbalanced variants replace the hard marginal constraints with penalties, so the two input measures may have different total masses.

```python { .api }
def ot.gromov.fused_unbalanced_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', epsilon=0.1, alpha=0.5, rho=1.0, rho2=1.0, **kwargs):
    """
    Solve unbalanced FGW with penalties on the marginal constraints.

    Parameters:
    - M: array-like, feature cost matrix
    - C1, C2: array-like, structure matrices
    - p, q: array-like, measures (may have different total masses)
    - epsilon: float, entropic regularization
    - alpha: float, structure/feature trade-off
    - rho, rho2: float, weights of the marginal relaxation penalties

    Returns:
    - unbalanced transport plan
    """

def ot.gromov.unbalanced_co_optimal_transport(X_s, X_t, C1, C2, p, q, epsilon=0.1, rho=1.0, rho2=1.0, **kwargs):
    """
    Solve unbalanced co-optimal transport.
    """
```
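
Unbalanced formulations keep every plan feasible and instead charge for marginal mismatch. The sketch below assumes the common choice of generalized Kullback-Leibler penalties, with `rho` weighting the row marginal and `rho2` the column marginal, and computes the penalty term in NumPy. It illustrates the penalty only; it is not POT's solver, and the library's exact divergence and scaling may differ. The helper names are hypothetical.

```python
import numpy as np


def kl_div(a, b):
    """Generalized KL divergence sum(a * log(a / b)) - sum(a) + sum(b) for nonnegative vectors."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    pos = a > 0  # convention: 0 * log(0) = 0
    return float(np.sum(a[pos] * np.log(a[pos] / b[pos])) - a.sum() + b.sum())


def marginal_penalty(T, p, q, rho, rho2):
    """rho * KL(T @ 1 | p) + rho2 * KL(T.T @ 1 | q): soft version of T @ 1 = p, T.T @ 1 = q."""
    return rho * kl_div(T.sum(axis=1), p) + rho2 * kl_div(T.sum(axis=0), q)


p = np.array([0.5, 0.5])       # total mass 1.0
q = np.array([0.6, 0.6, 0.3])  # total mass 1.5

# No plan can match both marginals exactly; this one matches p and scales q down.
T = np.outer(p, q) / q.sum()
print(marginal_penalty(T, p, q, rho=1.0, rho2=1.0))  # log(2/3) + 0.5, about 0.0945
```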

## Import Statements

```python
import ot.gromov
from ot.gromov import gromov_wasserstein, gromov_wasserstein2
from ot.gromov import fused_gromov_wasserstein, fused_gromov_wasserstein2
from ot.gromov import gromov_barycenters, fgw_barycenters
from ot.gromov import entropic_gromov_wasserstein, entropic_fused_gromov_wasserstein
from ot.gromov import semirelaxed_gromov_wasserstein, partial_gromov_wasserstein
from ot.gromov import gromov_wasserstein_dictionary_learning, quantized_fused_gromov_wasserstein
```

The `ot.gromov` module covers optimal transport between structured data such as graphs, point clouds and other metric spaces. In these settings the two inputs live in different spaces, so classical optimal transport has no meaningful cost between their points.