
# Gromov-Wasserstein Distances

The `ot.gromov` module provides algorithms for computing Gromov-Wasserstein (GW) distances and their variants. These methods enable optimal transport between structured data by comparing the internal geometry of metric spaces rather than requiring a common embedding space.
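
For reference, take intra-space cost matrices $C_1$ and $C_2$, weights $p$ and $q$, and a loss $L$. For `loss_fun='square_loss'`, $L(a, b) = (a - b)^2$. The standard discrete GW problem is:

$$
\mathrm{GW}(C_1, C_2, p, q) = \min_{T \in \Pi(p, q)} \sum_{i,k} \sum_{j,l} L\big((C_1)_{ik}, (C_2)_{jl}\big)\, T_{ij}\, T_{kl},
\qquad
\Pi(p, q) = \{ T \ge 0 : T \mathbf{1} = p,\; T^\top \mathbf{1} = q \}.
$$

A plan $T$ is cheap when every source pair $(i, k)$ is matched to a target pair $(j, l)$ with a similar internal cost. Only costs within each space enter the objective.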

## Core Gromov-Wasserstein Functions

### Basic GW Distance Computation

```python { .api }
def ot.gromov.gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', armijo=False, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    """
    Compute the Gromov-Wasserstein transport plan between two metric spaces.

    Solves a non-convex quadratic program (a relaxation of the quadratic
    assignment problem) with a conditional-gradient method. The result is a
    soft correspondence between the two spaces that best preserves pairwise
    distances.

    Parameters:
    - C1: array-like, shape (n1, n1)
        Intra-structure cost matrix for the source space (distances/similarities).
    - C2: array-like, shape (n2, n2)
        Intra-structure cost matrix for the target space.
    - p: array-like, shape (n1,)
        Distribution over the source space. Must be positive and sum to 1.
    - q: array-like, shape (n2,)
        Distribution over the target space. Must be positive and sum to 1.
    - loss_fun: str, default='square_loss'
        Loss used to compare intra-structure costs: 'square_loss' or 'kl_loss'.
    - armijo: bool, default=False
        Use an Armijo line search for the step size instead of the exact line search.
    - log: bool, default=False
        Also return an optimization log with convergence details.
    - max_iter: int, default=1e4
        Maximum number of iterations.
    - tol_rel: float, default=1e-9
        Relative tolerance for convergence.
    - tol_abs: float, default=1e-9
        Absolute tolerance for convergence.

    Returns:
    - transport_plan: ndarray, shape (n1, n2)
        Optimal GW transport plan between the two spaces.
    - log: dict (if log=True)
        Contains 'gw_dist' (the GW loss at the returned plan) and convergence information.
    """

def ot.gromov.gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', armijo=False, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    """
    Compute the Gromov-Wasserstein loss (value only).

    Runs the same solver as gromov_wasserstein(). It returns the value of the
    GW objective at the optimal plan instead of the plan itself.

    Parameters: Same as gromov_wasserstein()

    Returns:
    - gw_distance: float
        Gromov-Wasserstein loss between the two spaces.
    - log: dict (if log=True)
    """

def ot.gromov.solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg, alpha_min=None, alpha_max=None, nx=None, symmetric=False, **kwargs):
    """
    Exact line search used inside the conditional-gradient (Frank-Wolfe)
    iterations of the GW solvers when armijo=False.

    The GW objective is quadratic in the plan, so the optimal step along the
    direction deltaG has a closed form.

    Parameters:
    - G: array-like, shape (n1, n2)
        Transport plan at the current iteration.
    - deltaG: array-like, shape (n1, n2)
        Difference between the plan found by the linearized subproblem and G.
    - cost_G: float
        Value of the objective at G.
    - C1, C2: array-like
        Transformed structure matrices. For 'square_loss' and 'kl_loss' these
        are the hC1 and hC2 returned by init_matrix().
    - M: array-like, shape (n1, n2)
        Feature cost matrix (linear term of the objective).
    - reg: float
        Weight of the quadratic (structure) term.
    - alpha_min: float, optional
        Minimum step size.
    - alpha_max: float, optional
        Maximum step size.
    - nx: backend, optional
    - symmetric: bool, default=False
        Whether C1 and C2 are assumed symmetric.

    Returns:
    - alpha: float
        Optimal step size.
    - fc: int
        Number of function calls.
    - cost_G: float
        Objective value at the updated plan.
    """
```
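
Evaluating the objective for a few fixed plans shows what "preserving pairwise distances" means in practice. The NumPy sketch below is independent of POT and assumes the square loss; `gw_square_loss` is a hypothetical helper. It compares two plans against an isometric (reshuffled) copy of the source. The first matches each point to its copy. The second is the independent coupling p qᵀ, which is always feasible but ignores the geometry.

```python
import numpy as np


def gw_square_loss(C1, C2, T):
    """Naive evaluation of sum_{i,j,k,l} (C1[i, k] - C2[j, l])**2 * T[i, j] * T[k, l]."""
    # diff[i, k, j, l] = C1[i, k] - C2[j, l]  (4-D tensor, fine for small examples only)
    diff = C1[:, :, None, None] - C2[None, None, :, :]
    return float(np.einsum("ikjl,ij,kl->", diff ** 2, T, T))


rng = np.random.default_rng(0)
n = 5
X = rng.normal(size=(n, 2))
# Source structure: pairwise Euclidean distances of a small point cloud
C1 = np.sqrt(((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1))

# Target structure: the same points listed in a shuffled order (an isometric copy)
perm = rng.permutation(n)
C2 = C1[np.ix_(perm, perm)]

p = np.full(n, 1.0 / n)
q = np.full(n, 1.0 / n)

# Plan that sends source point perm[j] to target point j
T_match = np.zeros((n, n))
T_match[perm, np.arange(n)] = 1.0 / n

# Independent coupling p q^T: always feasible, but ignores the geometry
T_indep = np.outer(p, q)

print(gw_square_loss(C1, C2, T_match))  # 0.0: all pairwise distances are preserved
print(gw_square_loss(C1, C2, T_indep))  # strictly positive
```

The naive version builds an n1 × n1 × n2 × n2 tensor, so it only scales to toy sizes. The solvers avoid that tensor by using the decomposition computed by `init_matrix`.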

### Fused Gromov-Wasserstein

```python { .api }
def ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    """
    Compute the Fused Gromov-Wasserstein transport plan, combining structure and features.

    The objective mixes a standard optimal transport term on the feature cost M
    with a Gromov-Wasserstein term on the structure costs C1 and C2.

    Parameters:
    - M: array-like, shape (n1, n2)
        Feature cost matrix between source and target samples.
    - C1: array-like, shape (n1, n1)
        Intra-structure cost matrix for the source space.
    - C2: array-like, shape (n2, n2)
        Intra-structure cost matrix for the target space.
    - p: array-like, shape (n1,)
        Source distribution.
    - q: array-like, shape (n2,)
        Target distribution.
    - loss_fun: str or callable, default='square_loss'
        Loss function for structure preservation.
    - alpha: float, default=0.5
        Trade-off parameter between structure (α) and features (1-α).
        α=1 gives pure GW, α=0 gives pure Wasserstein.
    - armijo: bool, default=False
        Use an Armijo line search.
    - log: bool, default=False
    - max_iter: int, default=1e4
    - tol_rel: float, default=1e-9
    - tol_abs: float, default=1e-9

    Returns:
    - transport_plan: ndarray, shape (n1, n2)
        Optimal FGW transport plan.
    - log: dict (if log=True)
    """

def ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    """
    Compute the Fused Gromov-Wasserstein loss (value only).

    Parameters: Same as fused_gromov_wasserstein()

    Returns:
    - fgw_distance: float
    - log: dict (if log=True)
    """
```
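
The α convention above becomes concrete if you evaluate the FGW objective for a fixed plan. This NumPy sketch assumes the square loss and the form (1 - α)⟨M, T⟩ + α Σ L(C1, C2) T T. `fgw_objective` is a hypothetical helper, not part of POT. It checks that α = 0 keeps only the feature term and α = 1 keeps only the structure term.

```python
import numpy as np


def gw_square_loss(C1, C2, T):
    # sum_{i,j,k,l} (C1[i, k] - C2[j, l])**2 * T[i, j] * T[k, l]
    diff = C1[:, :, None, None] - C2[None, None, :, :]
    return float(np.einsum("ikjl,ij,kl->", diff ** 2, T, T))


def fgw_objective(M, C1, C2, T, alpha):
    """(1 - alpha) * <M, T> + alpha * (square-loss GW term)."""
    feature_term = float(np.sum(M * T))
    structure_term = gw_square_loss(C1, C2, T)
    return (1 - alpha) * feature_term + alpha * structure_term


rng = np.random.default_rng(1)
n1, n2 = 4, 6
C1 = rng.random((n1, n1))
C1 = (C1 + C1.T) / 2
C2 = rng.random((n2, n2))
C2 = (C2 + C2.T) / 2
M = rng.random((n1, n2))  # feature cost between source and target samples
T = np.outer(np.full(n1, 1 / n1), np.full(n2, 1 / n2))  # a feasible plan

for alpha in (0.0, 0.7, 1.0):
    print(alpha, fgw_objective(M, C1, C2, T, alpha))
```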

## Barycenter Algorithms

```python { .api }
def ot.gromov.gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun='square_loss', max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None, **kwargs):
    """
    Compute the Gromov-Wasserstein barycenter of multiple metric spaces.

    Finds the N x N structure matrix that minimizes the weighted sum (weights
    lambdas) of GW losses to all input spaces. The barycenter distribution p
    is fixed, so only the structure matrix is optimized. The algorithm
    alternates between solving GW problems to each input space and updating
    the structure.

    Parameters:
    - N: int
        Size of the barycenter space (number of points).
    - Cs: list of arrays
        List of intra-structure cost matrices for the input spaces.
        Each Cs[i] has shape (ni, ni).
    - ps: list of arrays
        List of distributions for the input spaces.
        Each ps[i] has shape (ni,).
    - p: array-like, shape (N,)
        Distribution for the barycenter space.
    - lambdas: array-like, shape (n_spaces,)
        Weights for combining the input spaces in the barycenter.
    - loss_fun: str or callable, default='square_loss'
        Loss function for GW computation.
    - max_iter: int, default=1000
        Maximum iterations for barycenter optimization.
    - tol: float, default=1e-9
        Convergence tolerance.
    - verbose: bool, default=False
        Print optimization information.
    - log: bool, default=False
        Return optimization log.
    - init_C: array-like, shape (N, N), optional
        Initial barycenter structure matrix. Random if None.
    - random_state: int, optional
        Random seed for reproducible initialization.

    Returns:
    - barycenter_structure: ndarray, shape (N, N)
        Optimal barycenter intra-structure cost matrix.
    - log: dict (if log=True)
        Contains convergence information and transport plans.
    """

def ot.gromov.fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False, p=None, loss_fun='square_loss', max_iter=100, tol=1e-9, verbose=False, log=False, init_C=None, init_X=None, random_state=None, **kwargs):
    """
    Compute the Fused Gromov-Wasserstein barycenter with features and structure.

    Parameters:
    - N: int
        Barycenter size.
    - Ys: list of arrays
        List of feature matrices for the input spaces.
    - Cs: list of arrays
        List of structure matrices for the input spaces.
    - ps: list of arrays
        List of distributions for the input spaces.
    - lambdas: array-like
        Weights for combining the spaces.
    - alpha: float
        Trade-off between structure and features.
    - fixed_structure: bool, default=False
        Whether to fix the barycenter structure.
    - fixed_features: bool, default=False
        Whether to fix the barycenter features.
    - p: array-like, optional
        Barycenter distribution.
    - loss_fun: str or callable, default='square_loss'
    - max_iter: int, default=100
    - tol: float, default=1e-9
    - verbose: bool, default=False
    - log: bool, default=False
    - init_C: array-like, optional
        Initial barycenter structure.
    - init_X: array-like, optional
        Initial barycenter features.
    - random_state: int, optional

    Returns:
    - barycenter_features: ndarray, shape (N, d)
    - barycenter_structure: ndarray, shape (N, N)
    - log: dict (if log=True)
    """
```
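
With the square loss, the structure step of a GW barycenter iteration has a closed form (Peyré, Cuturi and Solomon, 2016). With plans Ts[s] of shape (n_s, N) to each input space, the update is C = Σ_s λ_s T_sᵀ C_s T_s, divided entrywise by p pᵀ. This is the kind of update `update_barycenter_structure` performs. The NumPy sketch below is a hypothetical re-implementation, not POT's code. It checks that averaging identical spaces, each coupled to the barycenter by the identity matching, returns that space.

```python
import numpy as np


def update_square_loss_structure(Ts, Cs, lambdas, p):
    """Closed-form square-loss update of the barycenter structure.

    Ts[s] has shape (n_s, N) and couples input space s with the barycenter.
    Returns sum_s lambdas[s] * Ts[s].T @ Cs[s] @ Ts[s], divided entrywise by p p^T.
    """
    acc = sum(lam * T.T @ C @ T for lam, T, C in zip(lambdas, Ts, Cs))
    return acc / np.outer(p, p)


rng = np.random.default_rng(2)
N = 4
p = np.full(N, 1.0 / N)
C0 = rng.random((N, N))
C0 = (C0 + C0.T) / 2

# Two copies of the same space, each coupled to the barycenter by the identity matching
Ts = [np.diag(p), np.diag(p)]
Cs = [C0, C0]
lambdas = [0.5, 0.5]

C_bar = update_square_loss_structure(Ts, Cs, lambdas, p)
print(np.allclose(C_bar, C0))  # True
```

In a full barycenter loop this update alternates with recomputing each Ts[s] as a GW plan between Cs[s] and the current C_bar.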

## Entropic Regularized Methods

```python { .api }
def ot.gromov.entropic_gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', epsilon=0.1, symmetric=None, G0=None, max_iter=1000, tol=1e-9, verbose=False, log=False, **kwargs):
    """
    Compute the entropic regularized Gromov-Wasserstein transport plan.

    Adds an entropic penalty, weighted by epsilon, to the GW objective. Each
    iteration can then be solved with Sinkhorn's algorithm, and the resulting
    transport plan is dense.

    Parameters:
    - C1: array-like, shape (n1, n1)
        Source structure matrix.
    - C2: array-like, shape (n2, n2)
        Target structure matrix.
    - p: array-like, shape (n1,)
        Source distribution.
    - q: array-like, shape (n2,)
        Target distribution.
    - loss_fun: str or callable, default='square_loss'
    - epsilon: float, default=0.1
        Entropic regularization parameter. Larger values give blurrier plans.
    - symmetric: bool, optional
        Whether C1 and C2 are assumed symmetric. If None, symmetry is tested on the inputs.
    - G0: array-like, optional
        Initial transport plan.
    - max_iter: int, default=1000
    - tol: float, default=1e-9
    - verbose: bool, default=False
    - log: bool, default=False

    Returns:
    - transport_plan: ndarray, shape (n1, n2)
    - log: dict (if log=True)
    """

def ot.gromov.entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', epsilon=0.1, symmetric=None, G0=None, max_iter=1000, tol=1e-9, verbose=False, log=False, **kwargs):
    """
    Compute entropic regularized GW distance (cost only).

    Parameters: Same as entropic_gromov_wasserstein()

    Returns:
    - gw_distance: float
    - log: dict (if log=True)
    """

def ot.gromov.entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun='square_loss', epsilon=0.1, symmetric=True, max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None):
    """
    Compute entropic regularized GW barycenters.

    Parameters:
    - N: int
        Barycenter size.
    - Cs: list of arrays
        Structure matrices.
    - ps: list of arrays
        Input distributions.
    - p: array-like
        Barycenter distribution.
    - lambdas: array-like
        Combination weights.
    - loss_fun: str or callable, default='square_loss'
    - epsilon: float, default=0.1
        Entropic regularization.
    - symmetric: bool, default=True
    - max_iter: int, default=1000
    - tol: float, default=1e-9
    - verbose: bool, default=False
    - log: bool, default=False
    - init_C: array-like, optional
    - random_state: int, optional

    Returns:
    - barycenter_structure: ndarray, shape (N, N)
    - log: dict (if log=True)
    """

def ot.gromov.entropic_fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, G0=None, max_iter=1000, tol=1e-9, verbose=False, log=False):
    """
    Compute the entropic regularized Fused GW transport plan.

    Parameters:
    - M: array-like, shape (n1, n2)
        Feature cost matrix.
    - C1, C2: array-like
        Structure matrices.
    - p, q: array-like
        Distributions.
    - loss_fun: str or callable, default='square_loss'
    - epsilon: float, default=0.1
        Entropic regularization.
    - alpha: float, default=0.5
        Structure/feature trade-off.
    - symmetric: bool, optional
        Whether C1 and C2 are assumed symmetric.
    - G0: array-like, optional
        Initial transport plan.
    - max_iter: int, default=1000
    - tol: float, default=1e-9
    - verbose: bool, default=False
    - log: bool, default=False

    Returns:
    - transport_plan: ndarray, shape (n1, n2)
    - log: dict (if log=True)
    """

def ot.gromov.entropic_fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, G0=None, max_iter=1000, tol=1e-9, verbose=False, log=False):
    """
    Compute entropic regularized FGW distance (cost only).

    Parameters: Same as entropic_fused_gromov_wasserstein()

    Returns:
    - fgw_distance: float
    - log: dict (if log=True)
    """

def ot.gromov.entropic_fused_gromov_barycenters(N, Ys, Cs, ps, lambdas, alpha, epsilon=0.1, fixed_structure=False, fixed_features=False, p=None, loss_fun='square_loss', max_iter=100, tol=1e-9, verbose=False, log=False, init_C=None, init_X=None, random_state=None):
    """
    Compute entropic regularized FGW barycenters.

    Parameters: Similar to fgw_barycenters() with additional epsilon parameter

    Returns:
    - barycenter_features: ndarray
    - barycenter_structure: ndarray
    - log: dict (if log=True)
    """
```
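
One classical way to solve entropic GW (Peyré, Cuturi and Solomon, 2016) is a projected-gradient loop. Each step computes the GW gradient at the current plan, then solves an entropic OT problem with Sinkhorn using that gradient as the cost. The NumPy sketch below follows this scheme for the square loss and symmetric structure matrices. It is not POT's implementation: `sinkhorn` and `entropic_gw_square` are hypothetical helpers, and POT's stopping rules and scaling of epsilon may differ.

```python
import numpy as np


def sinkhorn(a, b, cost, eps, n_iter=1000):
    """Plain Sinkhorn scaling: entropic OT plan with marginals a and b for the given cost."""
    K = np.exp(-(cost - cost.min()) / eps)  # shifting the cost does not change the plan
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)
        v = b / (K.T @ u)
    return u[:, None] * K * v[None, :]


def entropic_gw_square(C1, C2, p, q, eps=0.1, outer_iter=30):
    """Projected-gradient sketch of entropic GW with the square loss."""
    # Part of the linearized cost that does not depend on T
    constC = np.outer((C1 ** 2) @ p, np.ones_like(q)) + np.outer(np.ones_like(p), (C2 ** 2) @ q)
    T = np.outer(p, q)  # start from the independent coupling
    for _ in range(outer_iter):
        # Gradient of the GW objective at T (valid for symmetric C1, C2)
        grad = 2 * (constC - 2 * C1 @ T @ C2.T)
        T = sinkhorn(p, q, grad, eps)  # entropic OT step with the gradient as cost
    return T


rng = np.random.default_rng(3)
n1, n2 = 5, 6
C1 = rng.random((n1, n1))
C1 = (C1 + C1.T) / 2
C2 = rng.random((n2, n2))
C2 = (C2 + C2.T) / 2
p = np.full(n1, 1 / n1)
q = np.full(n2, 1 / n2)

T = entropic_gw_square(C1, C2, p, q, eps=0.1)
print(T.sum(axis=1), T.sum(axis=0))  # close to p and q
```

Every iterate is an entropic OT plan, which is why the result is dense. Smaller `eps` gives sharper plans but slows Sinkhorn down.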

## Semi-relaxed Methods

```python { .api }
def ot.gromov.semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=None, G0=None, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    """
    Compute the semi-relaxed Gromov-Wasserstein transport plan.

    Keeps the source marginal constraint (the rows of the plan sum to p) and
    drops the target marginal constraint. The target weights are therefore
    optimized rather than prescribed, which is why no q is passed.

    Parameters:
    - C1: array-like, shape (n1, n1)
        Source structure matrix.
    - C2: array-like, shape (n2, n2)
        Target structure matrix.
    - p: array-like, shape (n1,)
        Source distribution (enforced as the first marginal of the plan).
    - loss_fun: str or callable, default='square_loss'
    - symmetric: bool, optional
        Whether C1 and C2 are assumed symmetric.
    - G0: array-like, optional
        Initial transport plan.
    - log: bool, default=False
    - max_iter: int, default=1000
    - tol_rel: float, default=1e-9
    - tol_abs: float, default=1e-9

    Returns:
    - transport_plan: ndarray, shape (n1, n2)
        Semi-relaxed transport plan. Its column sums are the optimized target weights.
    - log: dict (if log=True)
    """

def ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=None, G0=None, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    """
    Compute the semi-relaxed GW loss (value only).

    Parameters: Same as semirelaxed_gromov_wasserstein()

    Returns:
    - sr_gw_distance: float
    - log: dict (if log=True)
    """

def ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', symmetric=None, alpha=0.5, G0=None, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    """
    Compute the semi-relaxed Fused GW transport plan.

    Parameters:
    - M: array-like, shape (n1, n2)
        Feature cost matrix.
    - alpha: float, default=0.5
        Trade-off between structure (α) and features (1-α).
    - Other parameters same as semirelaxed_gromov_wasserstein()

    Returns:
    - transport_plan: ndarray, shape (n1, n2)
    - log: dict (if log=True)
    """

def ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', symmetric=None, alpha=0.5, G0=None, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    """
    Compute semi-relaxed FGW distance (cost only).

    Parameters: Same as semirelaxed_fused_gromov_wasserstein()

    Returns:
    - sr_fgw_distance: float
    - log: dict (if log=True)
    """

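def ot.gromov.solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p, M, reg, fC2t=None, alpha_min=None, alpha_max=None, nx=None, **kwargs):
    """
    Exact line search used inside the conditional-gradient iterations of the
    semi-relaxed GW solvers.

    Parameters: Analogous to solve_gromov_linesearch(), plus
    - ones_p: array-like, shape (n1, 1)
        Array of ones of size n1.
    - fC2t: array-like, shape (n2, n2), optional
        Transformed target structure matrix.

    Returns:
    - alpha: float
        Optimal step size.
    - fc: int
        Number of function calls.
    - cost_G: float
        Objective value at the updated plan.
    """
```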
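
Only the source marginal is constrained, so the linear subproblem of a conditional-gradient (Frank-Wolfe) scheme becomes trivial. The feasible set {T ≥ 0, T 1 = p} is a product of scaled simplices, one per row. Minimizing ⟨grad, T⟩ over it sends all of row i's mass to the column with the smallest gradient entry. The NumPy sketch below shows only that step, assuming a Frank-Wolfe solver like the GW one. `semirelaxed_lmo` is a hypothetical helper.

```python
import numpy as np


def semirelaxed_lmo(grad, p):
    """Minimize <grad, S> over {S >= 0, S.sum(axis=1) == p}; columns are unconstrained."""
    n1, n2 = grad.shape
    S = np.zeros((n1, n2))
    # Each row puts its whole mass p[i] on its cheapest column
    S[np.arange(n1), np.argmin(grad, axis=1)] = p
    return S


rng = np.random.default_rng(4)
p = np.array([0.2, 0.3, 0.5])
grad = rng.normal(size=(3, 4))
S = semirelaxed_lmo(grad, p)
print(S.sum(axis=1))  # equals p: the source marginal is enforced
print(S.sum(axis=0))  # arbitrary: the target marginal is free
```

By contrast, the fully constrained GW problem needs an exact OT solve at this step.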

## Partial Methods

```python { .api }
def ot.gromov.partial_gromov_wasserstein(C1, C2, p, q, m=None, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    """
    Compute partial Gromov-Wasserstein distance.

    Transports only part of the total mass, which is useful when the spaces
    have different sizes or contain outliers.

    Parameters:
    - C1: array-like, shape (n1, n1)
        Source structure matrix.
    - C2: array-like, shape (n2, n2)
        Target structure matrix.
    - p: array-like, shape (n1,)
        Source distribution.
    - q: array-like, shape (n2,)
        Target distribution.
    - m: float, optional
        Amount of mass to transport. Must not exceed min(sum(p), sum(q)),
        which is also the default.
    - loss_fun: str or callable, default='square_loss'
    - alpha: float, default=0.5
    - armijo: bool, default=False
    - log: bool, default=False
    - max_iter: int, default=1000
    - tol_rel: float, default=1e-9
    - tol_abs: float, default=1e-9

    Returns:
    - transport_plan: ndarray, shape (n1, n2)
        Partial GW transport plan.
    - log: dict (if log=True)
    """

def ot.gromov.partial_gromov_wasserstein2(C1, C2, p, q, m=None, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    """
    Compute partial GW distance (cost only).

    Parameters: Same as partial_gromov_wasserstein()

    Returns:
    - partial_gw_distance: float
    - log: dict (if log=True)
    """

def ot.gromov.entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, loss_fun='square_loss', G0=None, max_iter=1000, tol=1e-9, verbose=False, log=False):
    """
    Compute entropic regularized partial GW distance.

    Parameters:
    - C1, C2: array-like
        Structure matrices.
    - p, q: array-like
        Distributions.
    - reg: float
        Entropic regularization parameter.
    - m: float, optional
        Mass to transport.
    - loss_fun: str or callable, default='square_loss'
    - G0: array-like, optional
    - max_iter: int, default=1000
    - tol: float, default=1e-9
    - verbose: bool, default=False
    - log: bool, default=False

    Returns:
    - transport_plan: ndarray
    - log: dict (if log=True)
    """

def ot.gromov.entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, loss_fun='square_loss', G0=None, max_iter=1000, tol=1e-9, verbose=False, log=False):
    """
    Compute entropic regularized partial GW distance (cost only).

    Parameters: Same as entropic_partial_gromov_wasserstein()

    Returns:
    - partial_gw_distance: float
    - log: dict (if log=True)
    """
```
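
In partial transport the plan must satisfy inequality constraints on both marginals plus an equality on the total mass. This is also why `m` cannot exceed min(sum(p), sum(q)). The NumPy sketch below spells out those constraints and builds one feasible plan. `is_partial_plan` is a hypothetical helper, not part of POT.

```python
import numpy as np


def is_partial_plan(T, p, q, m, atol=1e-9):
    """Check the partial-transport constraints: T >= 0, T 1 <= p, T^T 1 <= q, total mass == m."""
    return bool(
        np.all(T >= -atol)
        and np.all(T.sum(axis=1) <= p + atol)
        and np.all(T.sum(axis=0) <= q + atol)
        and abs(T.sum() - m) <= atol
    )


p = np.full(4, 0.25)
q = np.full(5, 0.2)
m = 0.7

# m cannot exceed min(sum(p), sum(q)), otherwise no plan satisfies the constraints
assert m <= min(p.sum(), q.sum())

# A rescaled independent coupling moves exactly m while staying below both marginals
T = m * np.outer(p, q) / (p.sum() * q.sum())

print(is_partial_plan(T, p, q, m))               # True
print(is_partial_plan(np.outer(p, q), p, q, m))  # False: it moves all the mass, not 0.7
```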

## Dictionary Learning Methods

```python { .api }
def ot.gromov.gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0.0, ps=None, q=None, epochs=20, batch_size=32, learning_rate=1.0, proj_sparse_regul=0.1, verbose=False, random_state=None, **kwargs):
    """
    Learn a dictionary of structures using GW distances.

    Learns a dictionary of prototype structures (atoms) so that each input
    structure is approximated, in the GW sense, by a weighted combination of
    the atoms.

    Parameters:
    - Cs: list of arrays
        Input structure matrices to learn from.
    - D: int
        Dictionary size (number of atoms).
    - nt: int
        Size of each dictionary atom.
    - reg: float, default=0.0
        Coefficient of the negative quadratic penalty -reg * ||w||^2 on the
        unmixing weights, which promotes sparse weights.
    - ps: list of arrays, optional
        Distributions for input structures.
    - q: array-like, optional
        Distribution for dictionary atoms.
    - epochs: int, default=20
        Number of learning epochs.
    - batch_size: int, default=32
        Mini-batch size for learning.
    - learning_rate: float, default=1.0
        Learning rate for dictionary updates.
    - proj_sparse_regul: float, default=0.1
        Sparsity regularization for projections.
    - verbose: bool, default=False
    - random_state: int, optional

    Returns:
    - dictionary: ndarray, shape (D, nt, nt)
        Learned dictionary of structure matrices.
    - log: dict
        Learning statistics and convergence information.
    """

def ot.gromov.gromov_wasserstein_linear_unmixing(C, Cdict, reg=0.0, p=None, q=None, tol_outer=1e-6, tol_inner=1e-6, max_iter_outer=20, max_iter_inner=200, verbose=False, **kwargs):
    """
    Unmix a structure using a learned GW dictionary.

    Finds the weights w for which sum_d w[d] * Cdict[d] best reconstructs C
    in the GW sense.

    Parameters:
    - C: array-like, shape (n, n)
        Structure matrix to unmix.
    - Cdict: array-like, shape (D, nt, nt)
        Dictionary of structure atoms.
    - reg: float, default=0.0
        Coefficient of the negative quadratic penalty -reg * ||w||^2 that
        promotes sparse weights.
    - p: array-like, optional
        Distribution for input structure.
    - q: array-like, optional
        Distribution for dictionary atoms.
    - tol_outer: float, default=1e-6
        Outer loop tolerance.
    - tol_inner: float, default=1e-6
        Inner loop tolerance.
    - max_iter_outer: int, default=20
    - max_iter_inner: int, default=200
    - verbose: bool, default=False

    Returns:
    - w: ndarray, shape (D,)
        Unmixing weights of C on the dictionary atoms.
    - C_embedded: ndarray, shape (nt, nt)
        Embedded structure sum_d w[d] * Cdict[d].
    - T: ndarray, shape (n, nt)
        GW transport plan between C and its embedding.
    - reconstruction: float
        Reconstruction error.
    """

def ot.gromov.fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, reg=0.0, alpha=0.5, ps=None, q=None, epochs=20, batch_size=32, learning_rate=1.0, proj_sparse_regul=0.1, verbose=False, random_state=None):
    """
    Learn dictionary for FGW (structure + features).

    Parameters: Extends gromov_wasserstein_dictionary_learning() with:
    - Ys: list of arrays
        Feature matrices for input data.
    - alpha: float, default=0.5
        Structure/feature trade-off.

    Returns:
    - structure_dictionary: list of arrays
    - feature_dictionary: list of arrays
    - log: dict
    """

def ot.gromov.fused_gromov_wasserstein_linear_unmixing(C, Y, Cdict, Ydict, alpha=0.5, reg=0.0, p=None, q=None, tol_outer=1e-6, tol_inner=1e-6, max_iter_outer=20, max_iter_inner=200, verbose=False):
    """
    Unmix FGW data using a learned dictionary.

    Parameters: Extends gromov_wasserstein_linear_unmixing() with:
    - Y: array-like
        Feature matrix to unmix.
    - Ydict: list of arrays
        Feature dictionary atoms.
    - alpha: float, default=0.5

    Returns:
    - w: ndarray, shape (D,)
        Unmixing weights.
    - C_embedded: ndarray
        Embedded structure.
    - Y_embedded: ndarray
        Embedded features.
    - T: ndarray
        FGW transport plan between the input and its embedding.
    - reconstruction: float
        Reconstruction error.
    """
```
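
For a fixed plan, the quantity that linear unmixing minimizes can be written down directly. It is the GW reconstruction error of C by a weighted sum of atoms, minus a quadratic sparsity term. The sketch below assumes the square loss and a -reg·‖w‖² penalty. `unmixing_objective` is a hypothetical helper, and the actual solvers also optimize T and w.

```python
import numpy as np


def gw_square_loss(C1, C2, T):
    # sum_{i,j,k,l} (C1[i, k] - C2[j, l])**2 * T[i, j] * T[k, l]
    diff = C1[:, :, None, None] - C2[None, None, :, :]
    return float(np.einsum("ikjl,ij,kl->", diff ** 2, T, T))


def unmixing_objective(C, Cdict, w, T, reg=0.0):
    """Square-loss GW error between C and sum_d w[d] * Cdict[d] for a fixed plan T, minus reg * ||w||^2."""
    C_embedded = np.tensordot(w, Cdict, axes=1)  # weighted sum of the atoms, shape (nt, nt)
    return gw_square_loss(C, C_embedded, T) - reg * float(np.dot(w, w))


rng = np.random.default_rng(6)
D, nt = 2, 4
Cdict = rng.random((D, nt, nt))
Cdict = (Cdict + Cdict.transpose(0, 2, 1)) / 2  # symmetric atoms
p = np.full(nt, 1.0 / nt)
T = np.diag(p)  # identity matching between two spaces of the same size

C = Cdict[0].copy()  # a structure that is exactly the first atom
w_sparse = np.array([1.0, 0.0])
w_mixed = np.array([0.5, 0.5])

print(unmixing_objective(C, Cdict, w_sparse, T, reg=0.1))  # -0.1: perfect reconstruction minus the penalty
print(unmixing_objective(C, Cdict, w_mixed, T, reg=0.1))
```

Because the penalty subtracts reg·‖w‖², weights concentrated on few atoms lower the objective more than spread-out weights with the same reconstruction error.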

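## Utility Functions

```python { .api }
def ot.gromov.init_matrix(C1, C2, p, q, loss_fun='square_loss', nx=None):
    """
    Precompute the matrices used to evaluate the GW loss and its gradient.

    The loss is decomposed as L(a, b) = f1(a) + f2(b) - h1(a) * h2(b)
    (for 'square_loss': f1(a) = a**2, f2(b) = b**2, h1(a) = a, h2(b) = 2 * b).
    With this decomposition, tensor_product(), gwloss() and gwggrad() never
    build the 4-dimensional loss tensor.

    Parameters:
    - C1: array-like, shape (n1, n1)
    - C2: array-like, shape (n2, n2)
    - p: array-like, shape (n1,)
    - q: array-like, shape (n2,)
    - loss_fun: str, default='square_loss'
        'square_loss' or 'kl_loss'.
    - nx: backend, optional

    Returns:
    - constC: ndarray, shape (n1, n2)
        Constant term of the decomposition.
    - hC1: ndarray, shape (n1, n1)
        h1(C1).
    - hC2: ndarray, shape (n2, n2)
        h2(C2).
    """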

def ot.gromov.tensor_product(constC, hC1, hC2, T):
    """
    Compute the tensor product used for the GW loss and gradient:
    constC - hC1 @ T @ hC2.T.

    Parameters:
    - constC: ndarray
        Constant term in GW formulation.
    - hC1: ndarray
        Source structure term.
    - hC2: ndarray
        Target structure term.
    - T: ndarray
        Current transport plan.

    Returns:
    - tensor_prod: ndarray
        Tensor product result.
    """

def ot.gromov.gwloss(constC, hC1, hC2, T):
    """
    Compute the Gromov-Wasserstein loss value,
    sum(tensor_product(constC, hC1, hC2, T) * T).

    Parameters:
    - constC: ndarray
    - hC1: ndarray
    - hC2: ndarray
    - T: ndarray
        Transport plan.

    Returns:
    - loss: float
        GW loss value.
    """

def ot.gromov.gwggrad(constC, hC1, hC2, T):
    """
    Compute Gromov-Wasserstein gradient.

    Parameters:
    - constC: ndarray
    - hC1: ndarray
    - hC2: ndarray
    - T: ndarray

    Returns:
    - gradient: ndarray
        GW objective gradient.
    """

def ot.gromov.update_barycenter_structure(Ts, Cs, lambdas, p, loss_fun='square_loss'):
    """
    Update barycenter structure matrix.

    Parameters:
    - Ts: list of arrays
        Transport plans to input spaces.
    - Cs: list of arrays
        Input structure matrices.
    - lambdas: array-like
        Barycenter weights.
    - p: array-like
        Barycenter distribution.
    - loss_fun: str or callable, default='square_loss'

    Returns:
    - C_barycenter: ndarray
        Updated barycenter structure.
    """

def ot.gromov.update_barycenter_feature(Ts, Ys, lambdas, p):
    """
    Update barycenter feature matrix.

    Parameters:
    - Ts: list of arrays
        Transport plans.
    - Ys: list of arrays
        Input feature matrices.
    - lambdas: array-like
    - p: array-like

    Returns:
    - Y_barycenter: ndarray
        Updated barycenter features.
    """
```
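
The decomposition behind `init_matrix` is what lets `tensor_product` and `gwloss` evaluate the objective in O(n1² n2 + n1 n2²) instead of O(n1² n2²). The NumPy sketch below re-implements it for the square loss only and checks it against the naive four-index sum. The names are hypothetical and this is not POT's code. The identity holds when T has marginals p and q, which the sketch assumes.

```python
import numpy as np


def init_matrix_square(C1, C2, p, q):
    """Square loss: L(a, b) = a**2 + b**2 - a * (2 * b); returns constC, h1(C1), h2(C2)."""
    constC = np.outer((C1 ** 2) @ p, np.ones_like(q)) + np.outer(np.ones_like(p), (C2 ** 2) @ q)
    return constC, C1, 2 * C2


def tensor_product_np(constC, hC1, hC2, T):
    # Two matrix products instead of a 4-D tensor contraction
    return constC - hC1 @ T @ hC2.T


def gwloss_np(constC, hC1, hC2, T):
    return float(np.sum(tensor_product_np(constC, hC1, hC2, T) * T))


def gw_naive(C1, C2, T):
    # Reference: sum_{i,j,k,l} (C1[i, k] - C2[j, l])**2 * T[i, j] * T[k, l]
    diff = C1[:, :, None, None] - C2[None, None, :, :]
    return float(np.einsum("ikjl,ij,kl->", diff ** 2, T, T))


rng = np.random.default_rng(5)
C1 = rng.random((4, 4))
C2 = rng.random((6, 6))
p = rng.random(4)
p /= p.sum()
q = rng.random(6)
q /= q.sum()
T = np.outer(p, q)  # the decomposition assumes T has marginals p and q

constC, hC1, hC2 = init_matrix_square(C1, C2, p, q)
print(np.isclose(gwloss_np(constC, hC1, hC2, T), gw_naive(C1, C2, T)))  # True
```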

## Usage Examples

### Basic Gromov-Wasserstein

```python
import numpy as np
import ot

np.random.seed(0)

# Structure matrices: pairwise Euclidean distances within two point clouds
# that live in different spaces (2-D and 3-D), rescaled to [0, 1]
n1, n2 = 10, 15
pts1 = np.random.randn(n1, 2)
pts2 = np.random.randn(n2, 3)
C1 = ot.dist(pts1, pts1, metric='euclidean')
C2 = ot.dist(pts2, pts2, metric='euclidean')
C1 /= C1.max()
C2 /= C2.max()

# Create distributions
p = ot.unif(n1)
q = ot.unif(n2)

# Compute the GW plan and the GW loss (verbose is passed on to the solver)
gw_plan = ot.gromov.gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True)
gw_dist = ot.gromov.gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss')

print(f"GW distance: {gw_dist}")
print(f"Transport plan shape: {gw_plan.shape}")
```

### Fused Gromov-Wasserstein

```python
# Reuses n1, n2, C1, C2, p and q from the previous example
d = 3
X1 = np.random.randn(n1, d)
X2 = np.random.randn(n2, d)

# Feature cost matrix (squared Euclidean by default), rescaled to [0, 1]
M = ot.dist(X1, X2)
M /= M.max()

# Structure-feature trade-off
alpha = 0.7  # More weight on structure

# Compute FGW
fgw_plan = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, alpha=alpha)
fgw_dist = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, alpha=alpha)

print(f"FGW distance: {fgw_dist}")
```

### GW Barycenter

```python
# Multiple structures
n_spaces = 5
Cs = [np.random.rand(8, 8) for _ in range(n_spaces)]
Cs = [(C + C.T) / 2 for C in Cs]  # Make symmetric

ps = [ot.unif(8) for _ in range(n_spaces)]
lambdas = ot.unif(n_spaces)

# Barycenter parameters
N = 6  # Barycenter size
p_barycenter = ot.unif(N)

# Compute barycenter
C_barycenter = ot.gromov.gromov_barycenters(N, Cs, ps, p_barycenter, lambdas, verbose=True)

print(f"Barycenter structure shape: {C_barycenter.shape}")
```

### Entropic GW

```python
# Add entropic regularization
epsilon = 0.05

# Compute entropic GW
egw_plan = ot.gromov.entropic_gromov_wasserstein(C1, C2, p, q, epsilon=epsilon)
egw_dist = ot.gromov.entropic_gromov_wasserstein2(C1, C2, p, q, epsilon=epsilon)

print(f"Entropic GW distance: {egw_dist}")
```

### Partial GW for Outlier Robustness

```python
# Transport a total mass of m = 0.7 (70% of the mass, since p and q each sum to 1)
m = 0.7

# Compute partial GW
pgw_plan = ot.gromov.partial_gromov_wasserstein(C1, C2, p, q, m=m)
pgw_dist = ot.gromov.partial_gromov_wasserstein2(C1, C2, p, q, m=m)

print(f"Partial GW distance: {pgw_dist}")
print(f"Transported mass: {np.sum(pgw_plan)}")  # close to m
```

## Quantized and Sampling Methods

Large-scale methods using graph partitioning, quantization, and sampling approaches.

```python { .api }
def quantized_fused_gromov_wasserstein(C1, C2, Y1, Y2, a=None, b=None, alpha=0.5, reg=0.1, num_node_class=8, **kwargs):
    """
    Solve quantized FGW using graph partitioning for computational efficiency.

    Parameters:
    - C1, C2: array-like, structure matrices
    - Y1, Y2: array-like, feature matrices
    - a, b: array-like, distributions
    - alpha: float, structure/feature weight
    - reg: float, regularization parameter
    - num_node_class: int, number of partitions

    Returns:
    - quantized transport plan
    """

def lowrank_gromov_wasserstein_samples(X_s, X_t, a=None, b=None, reg=0.0, rank=10, numItermax=100, stopThr=1e-5, log=False):
    """
    Solve GW using low-rank factorization for large-scale problems.
    """

def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', nb_samples_grad=100, log=False, **kwargs):
    """
    Solve GW using sampling for gradient computation.
    """

def get_graph_partition(C1, num_node_class=8, part_method='louvain'):
    """
    Partition graph for quantized methods.
    """
```

## Unbalanced Methods

Unbalanced variants allowing different total masses.

```python { .api }
def fused_unbalanced_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', epsilon=0.1, alpha=0.5, rho=1.0, rho2=1.0, **kwargs):
    """
    Solve unbalanced FGW with marginal relaxation penalties.

    Parameters:
    - M: array-like, feature cost matrix
    - C1, C2: array-like, structure matrices
    - p, q: array-like, measures (can have different masses)
    - epsilon: float, entropic regularization
    - alpha: float, structure/feature trade-off
    - rho, rho2: float, marginal relaxation penalties

    Returns:
    - unbalanced transport plan
    """

def unbalanced_co_optimal_transport(X_s, X_t, C1, C2, p, q, epsilon=0.1, rho=1.0, rho2=1.0, **kwargs):
    """
    Solve unbalanced co-optimal transport.
    """
```
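
The penalties `rho` and `rho2` control how much the plan's marginals may deviate from p and q. The sketch below assumes each penalty weights a generalized Kullback-Leibler divergence, the usual choice in unbalanced OT. Check the function's own reference for the divergence it actually uses. `generalized_kl` and `marginal_penalty` are hypothetical NumPy helpers.

```python
import numpy as np


def generalized_kl(a, b):
    """Generalized KL divergence: sum(a * log(a / b)) - sum(a) + sum(b), with 0 * log(0) = 0."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    mask = a > 0
    return float(np.sum(a[mask] * np.log(a[mask] / b[mask])) - a.sum() + b.sum())


def marginal_penalty(T, p, q, rho, rho2):
    """rho * KL(T 1 | p) + rho2 * KL(T^T 1 | q): cost of deviating from the input measures."""
    return rho * generalized_kl(T.sum(axis=1), p) + rho2 * generalized_kl(T.sum(axis=0), q)


p = np.array([0.3, 0.7])       # total mass 1.0
q = np.array([0.4, 0.4, 0.4])  # total mass 1.2, so no plan can match both exactly

# Row sums equal p; column sums add up to 1.0 instead of 1.2
T_balanced = np.outer(p, q) / q.sum()
print(marginal_penalty(T_balanced, p, q, rho=1.0, rho2=1.0))  # only the column term is positive
```

Larger `rho` and `rho2` make deviations more expensive and push the solution toward balanced transport. Smaller values let the plan create or destroy mass more freely.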

## Import Statements

```python
import ot.gromov
from ot.gromov import gromov_wasserstein, gromov_wasserstein2
from ot.gromov import fused_gromov_wasserstein, fused_gromov_wasserstein2
from ot.gromov import gromov_barycenters, fgw_barycenters
from ot.gromov import entropic_gromov_wasserstein, entropic_fused_gromov_wasserstein
from ot.gromov import semirelaxed_gromov_wasserstein, partial_gromov_wasserstein
from ot.gromov import gromov_wasserstein_dictionary_learning, quantized_fused_gromov_wasserstein
```

The `ot.gromov` module provides tools for structured optimal transport. It compares data with internal geometric structure, such as graphs, point clouds, and other metric spaces. It applies when samples do not share a common space, so the cross-domain cost matrix that standard optimal transport requires is not available.