# Unbalanced Optimal Transport

The `ot.unbalanced` module provides algorithms for unbalanced optimal transport, in which the marginal constraints are relaxed so that the source and target distributions may have different total masses. This is particularly useful for data with outliers or noise, and for comparing distributions whose masses naturally differ.

## Core Unbalanced Methods

### Sinkhorn-based Unbalanced Transport

```python { .api }
def ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, stopThr=1e-6, verbose=False, log=False, warn=True, **kwargs):
    """
    Solve unbalanced optimal transport using Sinkhorn algorithm with KL relaxation.

    Solves the unbalanced optimal transport problem:
        min <P,M> + reg*KL(P|K) + reg_m*KL(P1|a) + reg_m*KL(P^T1|b)
    where the marginal constraints are relaxed using KL divergences.

    Parameters:
    - a: array-like, shape (n_samples_source,)
        Source distribution. Need not sum to 1.
    - b: array-like, shape (n_samples_target,)
        Target distribution. Need not sum to 1.
    - M: array-like, shape (n_samples_source, n_samples_target)
        Ground cost matrix.
    - reg: float
        Entropic regularization parameter (>0).
    - reg_m: float or tuple of floats
        Marginal relaxation parameter(s). If float, uses same value for both
        marginals. If tuple (reg_m1, reg_m2), uses different values.
    - method: str, default='sinkhorn'
        Algorithm variant. Options: 'sinkhorn', 'sinkhorn_stabilized',
        'sinkhorn_translation_invariant'
    - numItermax: int, default=1000
        Maximum number of iterations.
    - stopThr: float, default=1e-6
        Convergence threshold on marginal violation.
    - verbose: bool, default=False
        Print iteration information.
    - log: bool, default=False
        Return optimization log.
    - warn: bool, default=True
        Warn if algorithm doesn't converge.

    Returns:
    - transport_plan: ndarray, shape (n_samples_source, n_samples_target)
        Unbalanced optimal transport plan.
    - log: dict (if log=True)
        Contains 'err': convergence errors, 'mass_source': final source mass,
        'mass_target': final target mass, 'u': source scaling, 'v': target scaling.
    """

def ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, stopThr=1e-6, verbose=False, log=False, warn=True, **kwargs):
    """
    Solve unbalanced optimal transport and return cost only.

    More efficient than sinkhorn_unbalanced() when only the cost is needed.

    Parameters: Same as sinkhorn_unbalanced()

    Returns:
    - cost: float
        Unbalanced optimal transport cost.
    - log: dict (if log=True)
    """

def ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs):
    """
    Unbalanced Sinkhorn-Knopp algorithm with multiplicative updates.

    Classic formulation using diagonal scaling matrices for unbalanced case.

    Parameters: Same as sinkhorn_unbalanced()

    Returns:
    - transport_plan: ndarray
    - log: dict (if log=True)
    """

def ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e3, numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs):
    """
    Stabilized unbalanced Sinkhorn algorithm.

    Uses tau-absorption technique to prevent numerical overflow while
    handling unbalanced marginals.

    Parameters:
    - a, b, M, reg, reg_m: Same as sinkhorn_unbalanced()
    - tau: float, default=1e3
        Absorption threshold for numerical stability.
    - Other parameters same as sinkhorn_unbalanced()

    Returns:
    - transport_plan: ndarray
    - log: dict (if log=True)
    """

def ot.unbalanced.sinkhorn_unbalanced_translation_invariant(a, b, M, reg, reg_m, c=None, rescale_plan=True, numItermax=1000, stopThr=1e-6, verbose=False, log=False):
    """
    Translation-invariant unbalanced Sinkhorn algorithm.

    Uses a translation-invariant formulation that can be more numerically
    stable and allows for better initialization strategies.

    Parameters:
    - a, b, M, reg, reg_m: Same as sinkhorn_unbalanced()
    - c: array-like, optional
        Translation vector for numerical stability.
    - rescale_plan: bool, default=True
        Whether to rescale the final transport plan.
    - Other parameters same as sinkhorn_unbalanced()

    Returns:
    - transport_plan: ndarray
    - log: dict (if log=True)
    """
```
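For intuition, the KL-relaxed Sinkhorn iteration can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the library's implementation: the function name `unbalanced_sinkhorn_sketch` is hypothetical, the loop runs a fixed number of iterations with no convergence check or stabilization, and it uses the standard scaling updates raised to the exponent `reg_m / (reg_m + reg)`.

```python
import numpy as np

def unbalanced_sinkhorn_sketch(a, b, M, reg, reg_m, n_iter=500):
    """Hypothetical helper: plain KL-relaxed Sinkhorn scaling iterations."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    K = np.exp(-np.asarray(M, dtype=float) / reg)  # Gibbs kernel
    fi = reg_m / (reg_m + reg)  # exponent below 1 relaxes the marginals
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = (a / (K @ v)) ** fi    # damped update of the source scaling
        v = (b / (K.T @ u)) ** fi  # damped update of the target scaling
    return u[:, None] * K * v[None, :]

a = np.array([0.6, 0.4])        # total mass 1.0
b = np.array([0.2, 0.3, 0.1])   # total mass 0.6
M = np.array([[0.0, 1.0, 2.0],
              [1.0, 0.0, 1.0]])

plan = unbalanced_sinkhorn_sketch(a, b, M, reg=0.1, reg_m=0.5)
print("row sums:", plan.sum(axis=1))
print("column sums:", plan.sum(axis=0))
```

As `reg_m` grows, the exponent approaches 1 and the updates approach the usual balanced Sinkhorn updates; small values of `reg_m` let the row and column sums drift away from `a` and `b`.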

### Unbalanced Barycenters

```python { .api }
def ot.unbalanced.barycenter_unbalanced(A, M, reg, reg_m, weights=None, method='sinkhorn', numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs):
    """
    Compute unbalanced Wasserstein barycenter.

    Finds the barycenter that minimizes the sum of unbalanced transport costs
    to all input distributions, allowing for mass creation/destruction.

    Parameters:
    - A: array-like, shape (n_samples, n_distributions)
        Input distributions as columns. Need not be normalized.
    - M: array-like, shape (n_samples, n_samples)
        Ground cost matrix on barycenter support.
    - reg: float
        Entropic regularization parameter.
    - reg_m: float
        Marginal relaxation parameter.
    - weights: array-like, shape (n_distributions,), optional
        Weights for barycenter combination. Default is uniform.
    - method: str, default='sinkhorn'
        Algorithm variant for unbalanced transport computation.
    - numItermax: int, default=1000
        Maximum iterations for barycenter computation.
    - stopThr: float, default=1e-6
        Convergence threshold.
    - verbose: bool, default=False
    - log: bool, default=False

    Returns:
    - barycenter: ndarray, shape (n_samples,)
        Unbalanced Wasserstein barycenter (may not sum to 1).
    - log: dict (if log=True)
        Contains convergence information and transport plans.
    """

def ot.unbalanced.barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, numItermax=1000, stopThr=1e-6, verbose=False, log=False):
    """
    Compute unbalanced barycenter using Sinkhorn algorithm.

    Alternative implementation with explicit Sinkhorn iterations.

    Parameters: Same as barycenter_unbalanced()

    Returns:
    - barycenter: ndarray
    - log: dict (if log=True)
    """

def ot.unbalanced.barycenter_unbalanced_stabilized(A, M, reg, reg_m, tau=1e3, weights=None, numItermax=1000, stopThr=1e-6, verbose=False, log=False):
    """
    Compute unbalanced barycenter using stabilized algorithm.

    Parameters:
    - A, M, reg, reg_m, weights: Same as barycenter_unbalanced()
    - tau: float, default=1e3
        Stabilization parameter.
    - Other parameters same as barycenter_unbalanced()

    Returns:
    - barycenter: ndarray
    - log: dict (if log=True)
    """
```

## MM Algorithm

```python { .api }
def ot.unbalanced.mm_unbalanced(a, b, M, reg, reg_m, div='kl', G0=None, numItermax=1000, stopThr=1e-6, verbose=False, log=False):
    """
    Solve unbalanced optimal transport using MM (Majorization-Minimization) algorithm.

    Alternative optimization approach that can handle different divergences
    for marginal relaxation beyond KL divergence.

    Parameters:
    - a: array-like, shape (n_samples_source,)
        Source distribution.
    - b: array-like, shape (n_samples_target,)
        Target distribution.
    - M: array-like, shape (n_samples_source, n_samples_target)
        Ground cost matrix.
    - reg: float
        Entropic regularization parameter.
    - reg_m: float or tuple
        Marginal relaxation parameter(s).
    - div: str, default='kl'
        Divergence for marginal relaxation. Options: 'kl', 'l2', 'tv'
    - G0: array-like, optional
        Initial transport plan.
    - numItermax: int, default=1000
    - stopThr: float, default=1e-6
    - verbose: bool, default=False
    - log: bool, default=False

    Returns:
    - transport_plan: ndarray, shape (n_samples_source, n_samples_target)
    - log: dict (if log=True)
    """

def ot.unbalanced.mm_unbalanced2(a, b, M, reg, reg_m, div='kl', G0=None, numItermax=1000, stopThr=1e-6, verbose=False, log=False):
    """
    MM algorithm for unbalanced OT returning cost only.

    Parameters: Same as mm_unbalanced()

    Returns:
    - cost: float
        Unbalanced transport cost.
    - log: dict (if log=True)
    """
```
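To make the majorization-minimization idea concrete, the sketch below shows a multiplicative update for the KL case with no entropic term. It is an illustration under stated assumptions rather than the library's code: the function name `mm_kl_sketch` is hypothetical, the plan is initialized with the outer product of `a` and `b`, a single scalar `reg_m` is used for both marginals, and the constant `1e-16` only guards against division by zero.

```python
import numpy as np

def mm_kl_sketch(a, b, M, reg_m, n_iter=1000):
    """Hypothetical helper: multiplicative MM updates for KL-penalized marginals."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    K = np.exp(-np.asarray(M, dtype=float) / (2.0 * reg_m))
    G = np.outer(a, b)  # strictly positive starting plan
    for _ in range(n_iter):
        # Ratios between the target marginals and the current plan marginals
        u = np.sqrt(a / (G.sum(axis=1) + 1e-16))
        v = np.sqrt(b / (G.sum(axis=0) + 1e-16))
        # Multiplicative update keeps every entry nonnegative
        G = G * K * u[:, None] * v[None, :]
    return G

a = np.array([0.6, 0.4])
b = np.array([0.2, 0.3, 0.1])
M = np.array([[0.5, 1.0, 2.0],
              [1.0, 0.5, 1.0]])

plan = mm_kl_sketch(a, b, M, reg_m=0.5)
print("transported mass:", plan.sum())
```

At a fixed point of this update, `exp(-M/reg_m) * a_i * b_j` equals the product of the plan's row and column sums on the support, which is the first-order optimality condition of the KL-penalized problem.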

## L-BFGS-B Methods

```python { .api }
def ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', G0=None, numItermax=1000, numInnerItermax=10, stopThr=1e-6, stopThr2=1e-6, verbose=False, log=False):
    """
    Solve unbalanced optimal transport using L-BFGS-B optimizer.

    Uses quasi-Newton optimization method for solving the dual formulation
    of unbalanced optimal transport, which can be more efficient for
    large-scale problems.

    Parameters:
    - a: array-like, shape (n_samples_source,)
        Source distribution.
    - b: array-like, shape (n_samples_target,)
        Target distribution.
    - M: array-like, shape (n_samples_source, n_samples_target)
        Ground cost matrix.
    - reg: float
        Entropic regularization parameter.
    - reg_m: float or tuple
        Marginal relaxation parameter(s).
    - c: array-like, optional
        Translation vector for numerical stability.
    - reg_div: str, default='kl'
        Divergence type for marginal regularization.
    - G0: array-like, optional
        Initial transport plan.
    - numItermax: int, default=1000
        Maximum outer iterations.
    - numInnerItermax: int, default=10
        Maximum inner iterations for line search.
    - stopThr: float, default=1e-6
        Convergence threshold for outer loop.
    - stopThr2: float, default=1e-6
        Convergence threshold for inner loop.
    - verbose: bool, default=False
    - log: bool, default=False

    Returns:
    - transport_plan: ndarray, shape (n_samples_source, n_samples_target)
    - log: dict (if log=True)
        Contains optimization details including L-BFGS-B convergence info.
    """

def ot.unbalanced.lbfgsb_unbalanced2(a, b, M, reg, reg_m, c=None, reg_div='kl', G0=None, numItermax=1000, numInnerItermax=10, stopThr=1e-6, stopThr2=1e-6, verbose=False, log=False):
    """
    L-BFGS-B unbalanced OT returning cost only.

    Parameters: Same as lbfgsb_unbalanced()

    Returns:
    - cost: float
    - log: dict (if log=True)
    """
```

## Regularization and Divergences

The unbalanced transport framework supports different types of marginal relaxation:

### KL Divergence Relaxation

The most common choice uses the Kullback-Leibler divergence for the marginal penalties. Here π₁ denotes the source marginal of the transport plan (its row sums):

```
KL(π₁|a) = Σᵢ [ π₁(i) log(π₁(i)/a(i)) - π₁(i) + a(i) ]
```
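The formula above translates directly into NumPy. The helper below is a small sketch for checking penalty values by hand; the name `generalized_kl` is hypothetical, and the code assumes `a` is strictly positive and uses the convention `0 * log(0) = 0`.

```python
import numpy as np

def generalized_kl(p, q):
    """Hypothetical helper: generalized KL divergence between nonnegative vectors."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    # Only entries with p > 0 contribute to the logarithmic term
    pos = p > 0
    log_term = np.sum(p[pos] * np.log(p[pos] / q[pos]))
    # The extra "- p + q" terms make the divergence valid for unequal masses
    return log_term - p.sum() + q.sum()

a = np.array([0.6, 0.4])
kl_same = generalized_kl(a, a)             # marginal matches a exactly
kl_half = generalized_kl(0.5 * a, a)       # half of the mass is destroyed
kl_none = generalized_kl(np.zeros(2), a)   # all of the mass is destroyed
print(kl_same, kl_half, kl_none)
```

The penalty is zero only when the plan's marginal equals `a`, and it stays finite even when all mass is removed, which is what allows the solver to discard poorly matched points.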

### Alternative Divergences

- **L2 Penalty**: `div='l2'` - Quadratic penalty on marginal violations
- **Total Variation**: `div='tv'` - L1 penalty on marginal differences
- **Custom Divergences**: User-defined penalty functions
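As a rough illustration of how the first two penalties behave, the sketch below evaluates them on a marginal that deviates from its target. The helper names are hypothetical, and the constants are conventions assumed here (a factor of 1/2 on the quadratic penalty, no factor on the L1 penalty); a given solver may scale them differently.

```python
import numpy as np

def l2_penalty(p, q):
    """Hypothetical helper: half squared Euclidean distance between marginals."""
    d = np.asarray(p, dtype=float) - np.asarray(q, dtype=float)
    return 0.5 * np.sum(d ** 2)

def tv_penalty(p, q):
    """Hypothetical helper: L1 distance between marginals."""
    d = np.asarray(p, dtype=float) - np.asarray(q, dtype=float)
    return np.sum(np.abs(d))

a = np.array([0.6, 0.4])
plan_marginal = np.array([0.5, 0.1])  # illustrative row sums of some plan

l2_value = l2_penalty(plan_marginal, a)
tv_value = tv_penalty(plan_marginal, a)
print(l2_value, tv_value)
```

The quadratic penalty punishes large deviations disproportionately, while the L1 penalty grows linearly with the amount of created or destroyed mass.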

## Usage Examples

### Basic Unbalanced Transport

```python
import ot
import numpy as np

# Create unbalanced distributions
a = np.array([0.6, 0.4])       # Source (sums to 1.0)
b = np.array([0.2, 0.3, 0.1])  # Target (sums to 0.6)

# Cost matrix
M = np.random.rand(2, 3)

# Regularization parameters
reg = 0.1    # Entropic regularization
reg_m = 0.5  # Marginal relaxation

# Solve unbalanced transport
plan_unbalanced = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m, verbose=True)
cost_unbalanced = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg, reg_m)

print("Unbalanced transport plan:")
print(plan_unbalanced)
print(f"Unbalanced cost: {cost_unbalanced}")

# Check mass conservation
source_mass = np.sum(plan_unbalanced, axis=1)
target_mass = np.sum(plan_unbalanced, axis=0)
print(f"Source masses: {source_mass} (original: {a})")
print(f"Target masses: {target_mass} (original: {b})")
```

### Different Marginal Regularizations

```python
# Different regularization for source and target
reg_m_source = 0.3
reg_m_target = 0.7
reg_m_tuple = (reg_m_source, reg_m_target)

plan_asym = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_tuple)
print("Asymmetric marginal regularization plan:")
print(plan_asym)
```
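One way to read asymmetric values: in the KL-relaxed scaling iteration, each marginal's parameter enters through an exponent `reg_m / (reg_m + reg)` applied to the usual Sinkhorn update. The helper below is a hypothetical sketch (not a library function) that computes these two exponents for a scalar or tuple `reg_m`.

```python
def marginal_exponents(reg, reg_m):
    """Hypothetical helper: scaling exponents for the source and target marginals."""
    if isinstance(reg_m, (int, float)):
        reg_m1 = reg_m2 = float(reg_m)   # same relaxation on both sides
    else:
        reg_m1, reg_m2 = (float(r) for r in reg_m)
    return reg_m1 / (reg_m1 + reg), reg_m2 / (reg_m2 + reg)

exp_sym = marginal_exponents(0.1, 0.5)
exp_asym = marginal_exponents(0.1, (0.3, 0.7))
print(exp_sym)
print(exp_asym)
```

An exponent close to 1 means the corresponding marginal is enforced almost as strictly as in balanced Sinkhorn; an exponent close to 0 means it is barely enforced. With `(0.3, 0.7)` the target marginal is therefore held more tightly than the source marginal.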

### Unbalanced Barycenter

```python
# Multiple distributions, stored as the columns of A
# (with these values every column sums to 1.0; rescale a column
# to give that distribution a different total mass)
A = np.array([[0.6, 0.2, 0.4],
              [0.4, 0.3, 0.6],
              [0.0, 0.5, 0.0]])

# Cost matrix for barycenter space
M_bary = ot.dist(np.arange(3, dtype=float).reshape(-1, 1))

# Compute unbalanced barycenter
reg_bary = 0.05
reg_m_bary = 0.2

barycenter = ot.unbalanced.barycenter_unbalanced(A, M_bary, reg_bary, reg_m_bary, verbose=True)

print("Unbalanced barycenter:")
print(barycenter)
print(f"Barycenter mass: {np.sum(barycenter)}")
```

### MM Algorithm with Different Divergences

```python
# reg and reg_m are passed by keyword here: their positional order in the
# MM solvers may not match the Sinkhorn solvers in every POT release.

# Use L2 divergence for marginal relaxation
plan_mm_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, reg=reg, div='l2')
cost_mm_l2 = ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, reg=reg, div='l2')

print(f"MM L2 cost: {cost_mm_l2}")

# Use Total Variation divergence
# (some POT releases may accept only 'kl' and 'l2' for `div`; check your installed version)
plan_mm_tv = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, reg=reg, div='tv')
cost_mm_tv = ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, reg=reg, div='tv')

print(f"MM TV cost: {cost_mm_tv}")
```

### Empirical Unbalanced Transport

```python
# Generate unbalanced sample data
np.random.seed(42)
n_source, n_target = 100, 80
X_s = np.random.randn(n_source, 2)
X_t = np.random.randn(n_target, 2) + 1

# Unbalanced weights (don't sum to 1)
a_unbalanced = np.random.exponential(0.8, n_source)
b_unbalanced = np.random.exponential(1.2, n_target)

# Compute cost matrix
M_empirical = ot.dist(X_s, X_t)

# Solve unbalanced transport
reg_emp = 0.1
reg_m_emp = 0.3

plan_emp = ot.unbalanced.sinkhorn_unbalanced(a_unbalanced, b_unbalanced, M_empirical, reg_emp, reg_m_emp)
cost_emp = ot.unbalanced.sinkhorn_unbalanced2(a_unbalanced, b_unbalanced, M_empirical, reg_emp, reg_m_emp)

print(f"Empirical unbalanced cost: {cost_emp}")
print(f"Original source mass: {np.sum(a_unbalanced):.3f}")
print(f"Original target mass: {np.sum(b_unbalanced):.3f}")
print(f"Transported mass: {np.sum(plan_emp):.3f}")
```

### Stabilized Algorithm for Extreme Cases

```python
# Very small regularization or large costs
reg_small = 1e-4
M_large = M * 100

# Use stabilized version
plan_stable = ot.unbalanced.sinkhorn_stabilized_unbalanced(
    a, b, M_large, reg_small, reg_m, tau=1e2, verbose=True
)

print("Stabilized unbalanced transport completed")
```
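The reason stabilization matters can be seen without the library: when `M / reg` is large, the Gibbs kernel `exp(-M / reg)` underflows to zero and the plain scaling updates divide by zero. The sketch below shows the underflow and one common remedy, iterating on dual potentials in the log domain. Note that this is a different technique from the tau-absorption scheme named above; the function names are hypothetical, the parameters are milder than in the example above so that a fixed iteration count suffices, and no convergence check is performed.

```python
import numpy as np

def logsumexp(x, axis):
    """Numerically stable log(sum(exp(x))) along one axis."""
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))

def unbalanced_sinkhorn_log(a, b, M, reg, reg_m, n_iter=5000):
    """Hypothetical helper: KL-relaxed Sinkhorn on dual potentials f and g."""
    fi = reg_m / (reg_m + reg)
    log_a, log_b = np.log(a), np.log(b)
    f = np.zeros_like(a)
    g = np.zeros_like(b)
    for _ in range(n_iter):
        # Same updates as the scaling form, written for f = reg * log(u), g = reg * log(v)
        f = fi * reg * (log_a - logsumexp((g[None, :] - M) / reg, axis=1))
        g = fi * reg * (log_b - logsumexp((f[:, None] - M) / reg, axis=0))
    return np.exp((f[:, None] + g[None, :] - M) / reg)

a = np.array([0.6, 0.4])
b = np.array([0.5, 0.5])
M = np.array([[1.0, 2.0],
              [2.0, 1.0]])
reg, reg_m = 1e-3, 0.5

K = np.exp(-M / reg)  # every entry underflows to 0.0
plan = unbalanced_sinkhorn_log(a, b, M, reg, reg_m)
print("naive kernel:", K)
print("log-domain plan:", plan)
```

Because the exponent `reg_m / (reg_m + reg)` is very close to 1 when `reg` is small, the total transported mass settles slowly, which is why the sketch uses a generous iteration count.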

### L-BFGS-B for Large-Scale Problems

```python
# For larger problems, L-BFGS-B can be more efficient
n_large = 500
a_large = np.random.exponential(1.0, n_large)
b_large = np.random.exponential(1.5, n_large)
M_large = np.random.rand(n_large, n_large)

# Use L-BFGS-B solver
plan_lbfgs = ot.unbalanced.lbfgsb_unbalanced(
    a_large, b_large, M_large, reg, reg_m,
    numItermax=100, verbose=True
)
cost_lbfgs = ot.unbalanced.lbfgsb_unbalanced2(
    a_large, b_large, M_large, reg, reg_m
)

print(f"L-BFGS-B unbalanced cost: {cost_lbfgs}")
```

## Applications

### Comparing Unnormalized Data

Unbalanced transport is particularly useful when:

- Comparing histograms or distributions that naturally have different total masses
- Handling data with missing values or outliers
- Robust matching in the presence of noise
- Domain adaptation with different sample sizes

### Mass Creation and Destruction

The relaxed marginal constraints allow:

- **Mass Creation**: Transport plan can have row/column sums exceeding the original marginals
- **Mass Destruction**: Transport plan can have row/column sums below the original marginals
- **Outlier Handling**: Points with no good matches can have reduced mass
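These effects can be read off a plan by comparing its row and column sums with the input marginals. The sketch below uses a hand-written plan for illustration (it is not a solver output), and `mass_change` is a hypothetical helper name.

```python
import numpy as np

def mass_change(plan, a, b):
    """Hypothetical helper: plan marginals minus the input marginals."""
    plan = np.asarray(plan, dtype=float)
    row_diff = plan.sum(axis=1) - np.asarray(a, dtype=float)  # per source point
    col_diff = plan.sum(axis=0) - np.asarray(b, dtype=float)  # per target point
    return row_diff, col_diff

a = np.array([0.6, 0.4])
b = np.array([0.2, 0.3, 0.1])
plan = np.array([[0.15, 0.25, 0.05],
                 [0.05, 0.10, 0.15]])  # illustrative values only

row_diff, col_diff = mass_change(plan, a, b)
print("source change:", row_diff)  # negative entries: mass destroyed
print("target change:", col_diff)  # positive entries: mass created
```

Negative entries mark points whose mass was partly discarded, positive entries mark points that received more mass than their original weight, and a source point with a strongly negative entry is a candidate outlier.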

### Computational Advantages

- More robust convergence than balanced transport
- Better numerical stability with extreme regularization parameters
- Natural handling of datasets with different cardinalities

The `ot.unbalanced` module provides essential tools for real-world optimal transport applications where perfect mass conservation is not required or desired, offering both theoretical flexibility and computational advantages.