
# Entropic Regularized Transport

The `ot.bregman` module provides solvers for entropic regularized optimal transport based on Bregman projections. The Sinkhorn algorithm and its variants are the core methods. Each iteration only needs matrix-vector products with the kernel `K = exp(-M/reg)`, which makes them much cheaper than exact linear programming on large problems. The price is an approximation bias controlled by `reg`.

## Core Sinkhorn Algorithms

### Standard Sinkhorn Algorithm

```python { .api }
def ot.bregman.sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, warn=True, warmstart=None, **kwargs):
    """
    Solve the entropic regularized optimal transport problem with the Sinkhorn algorithm.

    The Sinkhorn algorithm alternately rescales the rows and columns of the kernel
    K = exp(-M/reg) until both marginal constraints hold. The resulting plan solves

        min_P <P, M>_F + reg * sum_ij P_ij log(P_ij)   subject to   P 1 = a,  P^T 1 = b,  P >= 0

    which is equivalent to minimizing reg * KL(P | K) under the same constraints.

    Parameters:
    - a: array-like, shape (n_samples_source,)
        Source distribution (histogram). Must be positive and sum to 1.
    - b: array-like, shape (n_samples_target,)
        Target distribution (histogram). Must be positive and sum to 1.
    - M: array-like, shape (n_samples_source, n_samples_target)
        Ground cost matrix.
    - reg: float
        Regularization parameter (> 0). Smaller values give plans closer to the exact
        (EMD) solution, but slow down convergence and can make K underflow.
    - method: str, default='sinkhorn'
        Algorithm variant: 'sinkhorn', 'sinkhorn_log', 'sinkhorn_stabilized',
        'sinkhorn_epsilon_scaling' or 'greenkhorn'. Screenkhorn is provided by the
        separate function screenkhorn().
    - numItermax: int, default=1000
        Maximum number of iterations.
    - stopThr: float, default=1e-9
        Stopping threshold on the marginal violation error.
    - verbose: bool, default=False
        Print iteration information.
    - log: bool, default=False
        Also return a log dictionary with convergence details.
    - warn: bool, default=True
        Warn if the algorithm does not converge within numItermax iterations.
    - warmstart: tuple, default=None
        Tuple (u, v) of initial dual potentials, i.e. the logarithms of the
        Sinkhorn scaling vectors.

    Returns:
    - transport_plan: ndarray, shape (n_samples_source, n_samples_target)
        Entropic optimal transport plan.
    - log: dict (if log=True)
        Contains 'err' (convergence errors), 'niter' (iterations used),
        'u' (source scaling) and 'v' (target scaling).
    """

def ot.bregman.sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, warn=True, **kwargs):
    """
    Solve entropic regularized OT and return the transport loss instead of the plan.

    Use it when only the value is needed. The returned value is the linear cost
    <P*, M>_F of the entropic plan P*; the entropy term is not included.

    Parameters: Same as sinkhorn()

    Returns:
    - cost: float
        Transport loss <P*, M>_F of the entropic plan.
    - log: dict (if log=True)
    """

def ot.bregman.sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs):
    """
    Sinkhorn-Knopp algorithm for entropic optimal transport.

    Classic formulation with diagonal scaling: the plan is diag(u) K diag(v), and the
    scaling vectors are updated multiplicatively as u = a / (K v) and v = b / (K^T u).

    Parameters: Similar to sinkhorn()

    Returns:
    - transport_plan: ndarray
    - log: dict (if log=True)
    """

```
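
The multiplicative updates behind `sinkhorn_knopp()` can be sketched in a few lines of NumPy. This is an illustrative re-implementation, not POT's code: the helper name `sinkhorn_knopp_sketch` and its stopping rule are our own assumptions. It also checks the note on `reg` using the 2 x 2 problem from the usage examples, whose exact (EMD) loss is 0.2, since 0.2 units of mass must move from the first source point to the second target point at cost 1.

```python
import numpy as np


def sinkhorn_knopp_sketch(a, b, M, reg, num_iter_max=1000, stop_thr=1e-9):
    """Hypothetical helper: entropic OT plan via Sinkhorn-Knopp scaling updates."""
    K = np.exp(-M / reg)                 # Gibbs kernel K = exp(-M/reg)
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(num_iter_max):
        u = a / (K @ v)                  # rescale rows to match a
        v = b / (K.T @ u)                # rescale columns to match b
        # columns are now exact; stop once the row marginals are also satisfied
        if np.abs(u * (K @ v) - a).sum() < stop_thr:
            break
    return u[:, None] * K * v[None, :]   # P = diag(u) K diag(v)


a = np.array([0.5, 0.5])
b = np.array([0.3, 0.7])
M = np.array([[0.0, 1.0],
              [1.0, 0.0]])

# Linear loss <P, M> of the entropic plan as reg decreases (exact EMD value: 0.2)
costs = []
for reg in [1.0, 0.5, 0.1, 0.05]:
    P = sinkhorn_knopp_sketch(a, b, M, reg)
    costs.append(float(np.sum(P * M)))
    print(f"reg={reg}: cost={costs[-1]:.4f}")
```

The loss of the entropic plan decreases toward 0.2 as `reg` shrinks. Pushing `reg` much lower makes `exp(-M/reg)` underflow, which is the problem the log-domain and stabilized variants address.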

### Advanced Sinkhorn Variants

```python { .api }
def ot.bregman.sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs):
    """
    Sinkhorn algorithm in the log domain for numerical stability.

    Updates the dual potentials with log-sum-exp operations instead of multiplying
    scaling vectors, which avoids overflow/underflow when the regularization
    parameter is small or the cost matrix has large values.

    Parameters: Same as sinkhorn()

    Returns:
    - transport_plan: ndarray
    - log: dict (if log=True)
    """

def ot.bregman.sinkhorn_stabilized(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, log=False, tau=1e3, **kwargs):
    """
    Stabilized Sinkhorn algorithm with absorption.

    Runs the usual multiplicative updates, but whenever a scaling factor exceeds tau
    it absorbs the scalings into log-domain dual potentials and recomputes the kernel.
    This prevents overflow while keeping most iterations as cheap as plain Sinkhorn.

    Parameters:
    - tau: float, default=1e3
        Absorption threshold. When scaling factors exceed tau, the algorithm
        absorbs them into the dual variables.
    - Other parameters same as sinkhorn()

    Returns:
    - transport_plan: ndarray
    - log: dict (if log=True)
    """

def ot.bregman.sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInnerItermax=100, stopThr=1e-9, verbose=False, log=False, **kwargs):
    """
    Epsilon-scaling Sinkhorn for better convergence with small regularization.

    Starts with a large regularization parameter and progressively decreases it
    to the target value, warm-starting each scale from the previous solution.

    Parameters:
    - epsilon0: float, default=1e4
        Initial (large) regularization parameter.
    - numInnerItermax: int, default=100
        Maximum number of iterations per epsilon scale.
    - Other parameters same as sinkhorn()

    Returns:
    - transport_plan: ndarray
    - log: dict (if log=True)
    """

def ot.bregman.greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=False):
    """
    Greenkhorn algorithm: greedy coordinate variant of Sinkhorn.

    Instead of rescaling all rows and then all columns, each iteration rescales only
    the single row or column whose marginal constraint is most violated. Iterations
    are much cheaper, so more of them are needed; the result is still a dense
    entropic transport plan.

    Parameters: Same as sinkhorn(), typically with a larger numItermax

    Returns:
    - transport_plan: ndarray
    - log: dict (if log=True)
    """

def ot.bregman.screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, restricted=True, maxiter=10000, maxfun=10000, pgtol=1e-09, verbose=False, log=False):
    """
    Screenkhorn algorithm for large-scale entropic optimal transport.

    Uses a screening test to fix the dual variables of samples deemed inactive, then
    solves the remaining, smaller problem with L-BFGS-B. This can substantially reduce
    the computational cost on large problems, at the price of an approximate solution.

    Parameters:
    - ns_budget: int, optional
        Number of source samples kept active (default: half of the source samples).
    - nt_budget: int, optional
        Number of target samples kept active (default: half of the target samples).
    - uniform: bool, default=False
        If True, assume that a and b are uniform distributions.
    - restricted: bool, default=True
        If True, warm-start L-BFGS-B with a few iterations of a restricted Sinkhorn.
    - maxiter: int, default=10000
        Maximum number of L-BFGS-B iterations.
    - maxfun: int, default=10000
        Maximum number of function evaluations.
    - pgtol: float, default=1e-09
        Projected gradient tolerance.

    Returns:
    - transport_plan: ndarray
    - log: dict (if log=True)
    """

```
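
The log-domain and epsilon-scaling ideas can also be illustrated in NumPy. The sketch below is an illustration with hypothetical helper names, not POT's `sinkhorn_log()` or `sinkhorn_epsilon_scaling()`. It updates the dual potentials f, g with log-sum-exp and follows a geometric schedule of our own choosing (from reg0 = 10 down to the target), warm-starting each stage from the previous potentials. At reg = 1e-3 the plain kernel exp(-M/reg) underflows to exactly zero off the diagonal, so multiplicative updates could never move mass there. This sketch still satisfies both marginals.

```python
import numpy as np


def logsumexp(X, axis):
    """Numerically stable log(sum(exp(X))) along one axis."""
    m = X.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(X - m).sum(axis=axis, keepdims=True))).squeeze(axis)


def sinkhorn_log_sketch(a, b, M, reg, f=None, g=None, num_iter_max=1000, stop_thr=1e-9):
    """Hypothetical helper: Sinkhorn on potentials, P_ij = exp((f_i + g_j - M_ij) / reg)."""
    f = np.zeros_like(a) if f is None else f
    g = np.zeros_like(b) if g is None else g
    for _ in range(num_iter_max):
        # each update makes one marginal of P exact, entirely in the log domain
        f = reg * np.log(a) - reg * logsumexp((g[None, :] - M) / reg, axis=1)
        g = reg * np.log(b) - reg * logsumexp((f[:, None] - M) / reg, axis=0)
        P = np.exp((f[:, None] + g[None, :] - M) / reg)
        if np.abs(P.sum(axis=1) - a).sum() < stop_thr:
            break
    return P, f, g


def epsilon_scaling_sketch(a, b, M, reg, reg0=10.0, num_scales=20):
    """Hypothetical helper: solve for geometrically decreasing reg, warm-starting each stage."""
    f = g = None
    for eps in np.geomspace(reg0, reg, num_scales)[:-1]:
        _, f, g = sinkhorn_log_sketch(a, b, M, eps, f, g, num_iter_max=200)
    # final, fully converged solve at the target regularization
    P, _, _ = sinkhorn_log_sketch(a, b, M, reg, f, g)
    return P


a = np.array([0.5, 0.5])
b = np.array([0.3, 0.7])
M = np.array([[0.0, 1.0],
              [1.0, 0.0]])
reg = 1e-3

K = np.exp(-M / reg)
print("off-diagonal kernel entries:", K[0, 1], K[1, 0])   # both underflow to 0.0

P = epsilon_scaling_sketch(a, b, M, reg)
print("row sums:", P.sum(axis=1), "column sums:", P.sum(axis=0))
print("loss <P, M>:", np.sum(P * M))
```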

## Barycenter Algorithms

### Standard Barycenters

```python { .api }
def ot.bregman.barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, stopThr=1e-4, verbose=False, log=False, **kwargs):
    """
    Compute the entropic regularized Wasserstein barycenter of histograms on a fixed support.

    Finds the histogram that minimizes the weighted sum of entropic regularized
    transport costs to all input distributions, using iterative Bregman projections.
    The inputs and the barycenter share the same support, so one cost matrix M is used.

    Parameters:
    - A: array-like, shape (n_samples, n_distributions)
        Input distributions as columns of matrix A.
    - M: array-like, shape (n_samples, n_samples)
        Ground cost matrix on the common support.
    - reg: float
        Entropic regularization parameter.
    - weights: array-like, shape (n_distributions,), optional
        Weights of the input distributions (summing to 1). Default is uniform.
    - method: str, default="sinkhorn"
        Solver used for the Bregman projections, e.g. "sinkhorn",
        "sinkhorn_stabilized" or "sinkhorn_log".
    - numItermax: int, default=10000
        Maximum number of iterations.
    - stopThr: float, default=1e-4
        Convergence threshold.
    - verbose: bool, default=False
        Print iteration information.
    - log: bool, default=False
        Also return a log dictionary.

    Returns:
    - barycenter: ndarray, shape (n_samples,)
        Wasserstein barycenter distribution.
    - log: dict (if log=True)
        Contains convergence information.
    """

def ot.bregman.barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False):
    """
    Compute the barycenter with iterative Bregman projections using plain Sinkhorn
    updates (the solver used by barycenter(method="sinkhorn")).

    Parameters: Same as barycenter()

    Returns:
    - barycenter: ndarray
    - log: dict (if log=True)
    """

def ot.bregman.barycenter_stabilized(A, M, reg, tau=1e3, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False):
    """
    Compute the barycenter with the stabilized (log-domain absorption) Sinkhorn algorithm.

    Parameters:
    - tau: float, default=1e3
        Absorption threshold for stabilization.
    - Other parameters same as barycenter()

    Returns:
    - barycenter: ndarray
    - log: dict (if log=True)
    """

def ot.bregman.barycenter_debiased(A, M, reg, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False):
    """
    Compute the debiased Wasserstein barycenter.

    Applies a debiasing correction that reduces the blurring introduced by
    entropic regularization.

    Parameters: Same as barycenter()

    Returns:
    - barycenter: ndarray
    - log: dict (if log=True)
    """

```
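
The iterative Bregman projection scheme can be sketched for histograms stored as columns of `A`, following the convention above. This is a hypothetical illustration (the helper name, initialization and stopping rule are ours), not the code behind `barycenter()`. It alternates between matching each input histogram and replacing the shared column marginal by the weighted geometric mean of the current ones.

```python
import numpy as np


def barycenter_ibp_sketch(A, M, reg, weights=None, num_iter_max=5000, stop_thr=1e-9):
    """Hypothetical helper: fixed-support entropic barycenter by iterative Bregman projections.

    Column k of A is the k-th input histogram. The plan P_k = diag(U[:, k]) K diag(V[:, k])
    is alternately projected onto {P_k 1 = A[:, k]} and onto {P_k^T 1 = barycenter}.
    """
    n, n_dist = A.shape
    w = np.full(n_dist, 1.0 / n_dist) if weights is None else np.asarray(weights, dtype=float)
    K = np.exp(-M / reg)
    V = np.ones((n, n_dist))
    bary = np.full(n, 1.0 / n)
    for _ in range(num_iter_max):
        U = A / (K @ V)                          # match each input histogram
        KtU = K.T @ U
        # shared column marginal = weighted geometric mean of the current column marginals
        new_bary = np.exp(np.log(V * KtU) @ w)
        V = new_bary[:, None] / KtU              # match the barycenter
        converged = np.abs(new_bary - bary).sum() < stop_thr
        bary = new_bary
        if converged:
            break
    return bary


# Two mirror-image bumps on a 1D grid; their barycenter should sit in the middle
x = np.linspace(0.0, 1.0, 50)
h1 = np.exp(-(x - 0.25) ** 2 / (2 * 0.05 ** 2))
h1 /= h1.sum()
h2 = h1[::-1].copy()
A = np.stack([h1, h2], axis=1)
M = (x[:, None] - x[None, :]) ** 2

bary = barycenter_ibp_sketch(A, M, reg=1e-2)
print("peak location:", x[np.argmax(bary)], "total mass:", bary.sum())
```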

### Free Support Barycenters

```python { .api }
def ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=False, **kwargs):
    """
    Compute a free-support Wasserstein barycenter using the Sinkhorn algorithm.

    Optimizes the locations of the barycenter support points while keeping their
    weights b fixed, unlike fixed-support methods that optimize the weights on a
    given support.

    Parameters:
    - measures_locations: list of arrays
        Support points of each input measure.
    - measures_weights: list of arrays
        Weights of each input measure.
    - X_init: array-like, shape (k, d)
        Initial barycenter support points.
    - reg: float
        Entropic regularization parameter.
    - b: array-like, shape (k,), optional
        Weights of the barycenter support points (kept fixed). Default is uniform.
    - weights: array-like, optional
        Weights for combining the input measures. Default is uniform.
    - numItermax: int, default=100
        Maximum number of iterations.
    - stopThr: float, default=1e-7
        Stopping threshold on the displacement of the support points.
    - verbose: bool, default=False
    - log: bool, default=False

    Returns:
    - X: ndarray, shape (k, d)
        Optimized barycenter support points.
    - log: dict (if log=True)
    """

def ot.bregman.jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, stopThr=1e-6, verbose=False, log=False, **kwargs):
    """
    Joint Class Proportion and Optimal Transport (JCPOT) for multi-source domain
    adaptation under target shift.

    Estimates the class proportions h of the unlabeled target domain by solving a
    Wasserstein barycenter problem over the labeled source domains.

    Parameters:
    - Xs: list of arrays, each of shape (n_k, d)
        Samples of each source domain.
    - Ys: list of arrays, each of shape (n_k,)
        Class labels of the samples of each source domain.
    - Xt: array-like, shape (n_t, d)
        Samples of the target domain.
    - reg: float
        Entropic regularization parameter.
    - metric: str, default='sqeuclidean'
        Ground metric used to compute the cost matrices.
    - numItermax: int, default=100
    - stopThr: float, default=1e-6
    - verbose: bool, default=False
    - log: bool, default=False

    Returns:
    - h: ndarray, shape (n_classes,)
        Estimated class proportions in the target domain.
    - log: dict (if log=True)
    """

```
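
The location update can be sketched as a fixed-point iteration: each barycenter point moves to the weighted average, over input measures, of its barycentric projection `P_k Y_k / b`. The NumPy sketch below assumes this common scheme with a squared Euclidean cost and uses hypothetical helper names; it is not a transcription of POT's implementation.

```python
import numpy as np


def sinkhorn_plan(p, q, M, reg, num_iter=1000):
    """Minimal Sinkhorn-Knopp loop returning P = diag(u) K diag(v) with marginals p, q."""
    K = np.exp(-M / reg)
    v = np.ones_like(q)
    for _ in range(num_iter):
        u = p / (K @ v)
        v = q / (K.T @ u)        # last update: column sums of P equal q exactly
    return u[:, None] * K * v[None, :]


def free_support_barycenter_sketch(locations, masses, X_init, reg, b=None, weights=None, num_iter=20):
    """Hypothetical helper: move the k barycenter points while keeping their weights b fixed."""
    X = np.array(X_init, dtype=float)
    k = X.shape[0]
    b = np.full(k, 1.0 / k) if b is None else np.asarray(b, dtype=float)
    if weights is None:
        weights = np.full(len(locations), 1.0 / len(locations))
    for _ in range(num_iter):
        X_new = np.zeros_like(X)
        for w, Y, a in zip(weights, locations, masses):
            M = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)   # squared Euclidean cost
            P = sinkhorn_plan(b, a, M, reg)       # rows: barycenter points, columns: points of Y
            X_new += w * (P @ Y) / b[:, None]     # barycentric projection onto Y
        X = X_new
    return X


rng = np.random.default_rng(0)
Y1 = rng.normal(size=(30, 2))
Y2 = Y1 + np.array([4.0, 0.0])           # same cloud shifted by 4 along x
uniform = np.full(30, 1.0 / 30)
X_init = rng.normal(size=(10, 2))

X_bar = free_support_barycenter_sketch([Y1, Y2], [uniform, uniform], X_init, reg=1.0)
print("barycenter mean:", X_bar.mean(axis=0))
print("input means:", Y1.mean(axis=0), Y2.mean(axis=0))
```

Because each plan's column sums equal the measure's weights, the b-weighted mean of the barycenter points equals the weighted average of the input means after every update. Here that is the midpoint between the two clouds.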

### Convolutional Barycenters

```python { .api }
def ot.bregman.convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, stabThr=1e-30, verbose=False, log=False):
    """
    Compute the entropic regularized Wasserstein barycenter of 2D images.

    Specialized algorithm for histograms on a regular 2D grid. With the squared
    Euclidean ground cost, the Gibbs kernel exp(-M/reg) is a Gaussian that factorizes
    along the two image axes, so each kernel application reduces to two small matrix
    products with 1D kernels instead of a product with a dense (h*w) x (h*w) matrix.

    Parameters:
    - A: array-like, shape (n_images, h, w)
        Stack of 2D images/distributions.
    - reg: float
        Entropic regularization parameter.
    - weights: array-like, shape (n_images,), optional
        Barycenter weights.
    - numItermax: int, default=10000
    - stopThr: float, default=1e-9
    - stabThr: float, default=1e-30
        Numerical stability threshold.
    - verbose: bool, default=False
    - log: bool, default=False

    Returns:
    - barycenter: ndarray, shape (h, w)
        2D convolutional Wasserstein barycenter.
    - log: dict (if log=True)
    """

def ot.bregman.convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000, stopThr=1e-9, stabThr=1e-30, verbose=False, log=False):
    """
    Compute the debiased 2D convolutional Wasserstein barycenter.

    Applies debiasing to reduce the regularization bias of convolutional barycenters.

    Parameters: Same as convolutional_barycenter2d()

    Returns:
    - barycenter: ndarray, shape (h, w)
    - log: dict (if log=True)
    """

```
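
The separability that makes the convolutional approach cheap can be checked numerically. In this sketch the grid size, `reg`, and the normalization of pixel coordinates to [0, 1] are arbitrary assumptions. It compares applying the Gibbs kernel as two small matrix products with applying the dense kernel over all pixel pairs.

```python
import numpy as np

# Arbitrary grid and regularization, chosen only for illustration
h, w, reg = 6, 5, 0.05
ty = np.linspace(0.0, 1.0, h)
tx = np.linspace(0.0, 1.0, w)

# 1D Gibbs kernels along each image axis
K_y = np.exp(-(ty[:, None] - ty[None, :]) ** 2 / reg)
K_x = np.exp(-(tx[:, None] - tx[None, :]) ** 2 / reg)

rng = np.random.default_rng(0)
img = rng.random((h, w))

# Separable application: two small matrix products
fast = K_y @ img @ K_x.T

# Dense reference: full (h*w) x (h*w) kernel on flattened pixels
Y, X = np.meshgrid(ty, tx, indexing="ij")
pts = np.stack([Y.ravel(), X.ravel()], axis=1)
M = ((pts[:, None, :] - pts[None, :, :]) ** 2).sum(axis=-1)
slow = (np.exp(-M / reg) @ img.ravel()).reshape(h, w)

print(np.allclose(fast, slow))  # True
```

The equality holds because exp(-(dy^2 + dx^2)/reg) = exp(-dy^2/reg) * exp(-dx^2/reg). The separable form therefore costs O(h^2 w + h w^2) per application instead of O(h^2 w^2).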

## Empirical Methods

```python { .api }
def ot.bregman.empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
    """
    Compute the Sinkhorn transport plan between empirical distributions.

    Convenience wrapper that computes the cost matrix from the sample coordinates
    and applies the Sinkhorn algorithm.

    Parameters:
    - X_s: array-like, shape (n_samples_source, n_features)
        Source samples.
    - X_t: array-like, shape (n_samples_target, n_features)
        Target samples.
    - reg: float
        Entropic regularization parameter.
    - a: array-like, shape (n_samples_source,), optional
        Source sample weights. Default is uniform.
    - b: array-like, shape (n_samples_target,), optional
        Target sample weights. Default is uniform.
    - metric: str, default='sqeuclidean'
        Ground metric used to compute the cost matrix.
    - numIterMax: int, default=10000
    - stopThr: float, default=1e-9
    - verbose: bool, default=False
    - log: bool, default=False

    Returns:
    - transport_plan: ndarray, shape (n_samples_source, n_samples_target)
    - log: dict (if log=True)
    """

def ot.bregman.empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
    """
    Compute only the Sinkhorn transport loss between empirical distributions.

    Parameters: Same as empirical_sinkhorn()

    Returns:
    - cost: float
        Empirical Sinkhorn transport loss.
    - log: dict (if log=True)
    """

def ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
    """
    Compute the (debiased) Sinkhorn divergence between empirical distributions.

    With alpha the source measure (X_s, a), beta the target measure (X_t, b) and
    W_reg the empirical_sinkhorn2() loss, the divergence is

        S(alpha, beta) = W_reg(alpha, beta) - 0.5 * W_reg(alpha, alpha) - 0.5 * W_reg(beta, beta)

    The correction terms remove the entropic bias, so that S(alpha, alpha) = 0, and
    give a better approximation of the Wasserstein distance than W_reg alone.

    Parameters: Same as empirical_sinkhorn()

    Returns:
    - divergence: float
        Sinkhorn divergence value.
    - log: dict (if log=True)
    """

def ot.bregman.empirical_sinkhorn2_geomloss(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numItermax=10000, stopThr=1e-9, verbose=False, log=False):
    """
    GeomLoss-compatible implementation of empirical Sinkhorn.

    Parameters: Same as empirical_sinkhorn()

    Returns:
    - cost: float
    - log: dict (if log=True)
    """

def ot.bregman.geomloss(X_s, X_t, a=None, b=None, loss='sinkhorn', p=2, blur=0.1, reach=1.0, diameter=1.0, scaling=0.5, truncate=5, cost=None, kernel=None, cluster_scale=None, debias=True, potentials=False, verbose=False, backend='auto'):
    """
    GeomLoss wrapper for various optimal transport losses.

    Unified interface supporting Wasserstein, energy, Hausdorff and Sinkhorn losses
    with automatic differentiation support.

    Parameters:
    - X_s: array-like, shape (n_s, d)
        Source samples.
    - X_t: array-like, shape (n_t, d)
        Target samples.
    - a: array-like, shape (n_s,), optional
        Source weights.
    - b: array-like, shape (n_t,), optional
        Target weights.
    - loss: str, default='sinkhorn'
        Loss type: 'wasserstein', 'sinkhorn', 'energy', 'hausdorff'
    - p: int, default=2
        Ground metric exponent.
    - blur: float, default=0.1
        Regularization/smoothing parameter.
    - reach: float, default=1.0
        Kernel reach parameter.
    - diameter: float, default=1.0
        Point cloud diameter estimate.
    - scaling: float, default=0.5
        Multi-scale algorithm parameter.
    - truncate: int, default=5
        Kernel truncation parameter.
    - cost: callable, optional
        Custom cost function.
    - kernel: callable, optional
        Custom kernel function.
    - cluster_scale: float, optional
        Clustering scale for acceleration.
    - debias: bool, default=True
        Apply debiasing for Sinkhorn loss.
    - potentials: bool, default=False
        Return dual potentials.
    - verbose: bool, default=False
    - backend: str, default='auto'
        Computation backend.

    Returns:
    - loss_value: float
        Computed loss value.
    - potentials: tuple (if potentials=True)
        Dual potentials (f, g).
    """

```
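
The debiasing formula can be tried on a toy example. The sketch below uses hypothetical helpers with uniform weights and a plain Sinkhorn loop (not POT's implementation). It takes W to be the linear loss <P, M> of the entropic plan under a squared Euclidean cost.

```python
import numpy as np


def sinkhorn_loss_sketch(X, Y, reg, num_iter=2000):
    """Hypothetical helper: linear loss <P, M> of the entropic plan,
    uniform weights and squared Euclidean cost."""
    a = np.full(len(X), 1.0 / len(X))
    b = np.full(len(Y), 1.0 / len(Y))
    M = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)
    K = np.exp(-M / reg)
    v = np.ones_like(b)
    for _ in range(num_iter):
        u = a / (K @ v)
        v = b / (K.T @ u)
    P = u[:, None] * K * v[None, :]
    return float(np.sum(P * M))


def sinkhorn_divergence_sketch(X, Y, reg):
    """S(X, Y) = W(X, Y) - 0.5 * W(X, X) - 0.5 * W(Y, Y)."""
    return (sinkhorn_loss_sketch(X, Y, reg)
            - 0.5 * sinkhorn_loss_sketch(X, X, reg)
            - 0.5 * sinkhorn_loss_sketch(Y, Y, reg))


rng = np.random.default_rng(0)
X_s = rng.normal(size=(40, 2))
X_t = X_s + np.array([3.0, 0.0])          # the same cloud translated by s = (3, 0)

reg = 1.0
print("W(X_s, X_s):", sinkhorn_loss_sketch(X_s, X_s, reg))        # positive: entropic bias
print("S(X_s, X_s):", sinkhorn_divergence_sketch(X_s, X_s, reg))  # zero by construction
print("S(X_s, X_t):", sinkhorn_divergence_sketch(X_s, X_t, reg))  # close to |s|^2 = 9
```

For a pure translation by s, the entropic plan of (X, X + s) coincides with that of (X, X), because the extra terms in the cost only rescale rows and columns of the kernel. The divergence therefore equals |s|^2 = 9 here, while the uncorrected loss W(X_s, X_t) overestimates it by W(X_s, X_s).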

## Utility Functions

```python { .api }
def ot.bregman.geometricBar(weights, alldistribT):
    """
    Compute the weighted geometric mean of distributions, exp(sum_k w_k log(a_k)),
    i.e. their barycenter for the generalized Kullback-Leibler divergence.

    Parameters:
    - weights: array-like
        Barycenter combination weights.
    - alldistribT: array-like
        Matrix of input distributions (columns).

    Returns:
    - barycenter: ndarray
        Weighted geometric mean.
    """

def ot.bregman.geometricMean(alldistribT):
    """
    Compute the (unweighted) geometric mean of distributions.

    Parameters:
    - alldistribT: array-like
        Matrix of distributions (columns).

    Returns:
    - geometric_mean: ndarray
    """

def ot.bregman.projR(gamma, p):
    """
    Project a transport matrix onto the row constraints (KL projection):
    each row is rescaled so that the row sums equal p.

    Parameters:
    - gamma: array-like
        Transport matrix.
    - p: array-like
        Row marginal constraints.

    Returns:
    - projected_gamma: ndarray
    """

def ot.bregman.projC(gamma, q):
    """
    Project a transport matrix onto the column constraints (KL projection):
    each column is rescaled so that the column sums equal q.

    Parameters:
    - gamma: array-like
        Transport matrix.
    - q: array-like
        Column marginal constraints.

    Returns:
    - projected_gamma: ndarray
    """

def ot.bregman.unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, stopThr=1e-3, verbose=False, log=False):
    """
    Solve the optimal transport unmixing problem with a prior.

    Decomposes a distribution as a convex combination of dictionary atoms,
    using optimal transport as the fidelity term.

    Parameters:
    - a: array-like, shape (n,)
        Distribution to unmix.
    - D: array-like, shape (n, n_atoms)
        Dictionary of atoms (columns).
    - M: array-like, shape (n, n)
        Cost matrix for the fidelity term.
    - M0: array-like, shape (n_atoms, n_atoms)
        Cost matrix between atoms, used by the prior term.
    - h0: array-like, shape (n_atoms,)
        Prior on the dictionary coefficients.
    - reg: float
        Entropic regularization of the fidelity term.
    - reg0: float
        Entropic regularization of the prior term.
    - alpha: float, in [0, 1]
        Trade-off between data fidelity and the prior.
    - numItermax: int, default=1000
    - stopThr: float, default=1e-3
    - verbose: bool, default=False
    - log: bool, default=False

    Returns:
    - h: ndarray, shape (n_atoms,)
        Dictionary coefficients.
    - log: dict (if log=True)
    """

```
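
These helpers are short enough to sketch directly in NumPy. The versions below are illustrative re-implementations under our own naming (the `_sketch` suffix marks them as hypothetical). A production version would also guard against zero row or column sums.

```python
import numpy as np


def geometric_bar_sketch(weights, alldistribT):
    """Weighted geometric mean of the columns: exp(log(alldistribT) @ weights)."""
    return np.exp(np.log(alldistribT) @ weights)


def geometric_mean_sketch(alldistribT):
    """Unweighted geometric mean of the columns."""
    return np.exp(np.log(alldistribT).mean(axis=1))


def proj_rows_sketch(gamma, p):
    """KL projection onto {gamma 1 = p}: rescale each row to its target sum."""
    return gamma * (p / gamma.sum(axis=1))[:, None]


def proj_cols_sketch(gamma, q):
    """KL projection onto {gamma^T 1 = q}: rescale each column to its target sum."""
    return gamma * (q / gamma.sum(axis=0))[None, :]


D = np.array([[0.2, 0.8],
              [0.8, 0.2]])          # two distributions stored as columns
print(geometric_bar_sketch(np.array([0.5, 0.5]), D))   # [0.4 0.4]
print(geometric_mean_sketch(D))                        # [0.4 0.4]

gamma = np.array([[1.0, 2.0],
                  [3.0, 4.0]])
p = np.array([0.5, 0.5])
print(proj_rows_sketch(gamma, p).sum(axis=1))          # [0.5 0.5]
```

Alternating `proj_rows_sketch` and `proj_cols_sketch` on the kernel `exp(-M/reg)` is exactly the Sinkhorn iteration, written as Bregman (KL) projections onto the two marginal constraints.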

## Usage Examples

### Basic Sinkhorn Algorithm

```python
import ot
import numpy as np

# Define distributions
a = np.array([0.5, 0.5])
b = np.array([0.3, 0.7])

# Cost matrix
M = np.array([[0.0, 1.0],
              [1.0, 0.0]])

# Regularization parameter
reg = 0.1

# Compute regularized transport
plan_sinkhorn = ot.bregman.sinkhorn(a, b, M, reg)
cost_sinkhorn = ot.bregman.sinkhorn2(a, b, M, reg)

print("Sinkhorn plan:", plan_sinkhorn)
print("Sinkhorn cost:", cost_sinkhorn)

```

### Barycenter Computation

```python
import numpy as np
import ot

# Three distributions on a 2-point support, stored as the columns of A
A = np.array([[0.6, 0.2, 0.4],
              [0.4, 0.8, 0.6]])

# Cost matrix: squared Euclidean distance between the support points 0 and 1
M = ot.dist(np.arange(2, dtype=float).reshape(-1, 1))

# Regularization
reg = 0.05

# Compute barycenter
barycenter = ot.bregman.barycenter(A, M, reg)
print("Barycenter:", barycenter)

# With custom weights
weights = np.array([0.5, 0.3, 0.2])
weighted_barycenter = ot.bregman.barycenter(A, M, reg, weights=weights)
print("Weighted barycenter:", weighted_barycenter)

```

### Empirical Sinkhorn

```python
import numpy as np
import ot

# Generate sample data: 100 source points and 80 target points shifted by +1
np.random.seed(42)
X_s = np.random.randn(100, 2)
X_t = np.random.randn(80, 2) + 1

# Regularization
reg = 0.1

# Compute empirical transport (uniform weights, squared Euclidean cost by default)
plan = ot.bregman.empirical_sinkhorn(X_s, X_t, reg)
cost = ot.bregman.empirical_sinkhorn2(X_s, X_t, reg)

print("Empirical transport cost:", cost)
print("Transport plan shape:", plan.shape)

# Sinkhorn divergence (debiased)
divergence = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, reg)
print("Sinkhorn divergence:", divergence)

```

### Stabilized Sinkhorn for Small Regularization

```python
import numpy as np
import ot

# Same problem as in the basic example
a = np.array([0.5, 0.5])
b = np.array([0.3, 0.7])
M = np.array([[0.0, 1.0],
              [1.0, 0.0]])

# Small regularization: exp(-M/reg) underflows to 0 for the off-diagonal entries
reg_small = 1e-3

# Use stabilized version to avoid numerical issues
plan_stable = ot.bregman.sinkhorn_stabilized(a, b, M, reg_small, verbose=True)
print("Stabilized Sinkhorn plan:", plan_stable)

# Or use epsilon scaling
plan_eps = ot.bregman.sinkhorn_epsilon_scaling(a, b, M, reg_small, verbose=True)
print("Epsilon-scaled plan:", plan_eps)

```

The `ot.bregman` solvers accept a bias controlled by `reg` in exchange for iterations built only from matrix-vector products. These iterations are cheap, parallelize well, and can be differentiated through. This is why the Sinkhorn algorithm and its variants are widely used for large-scale problems and for differentiable optimal transport losses.