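# Smooth Optimal Transport

The `ot.smooth` module solves regularized (smooth) optimal transport problems through their dual and semi-dual formulations. It supports negative-entropy (KL) regularization, which produces dense plans as in Sinkhorn, and squared L2 regularization, which produces smooth yet sparse plans. A sparsity-constrained variant of the squared L2 regularizer also caps the number of non-zero entries in each column of the plan.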
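
## Capabilities

### Dual Formulation Solvers

These functions solve the regularized problem by maximizing its dual over the potentials `alpha` (source) and `beta` (target) with a SciPy optimizer. `smooth_ot_dual` returns the transport plan directly. `solve_dual` returns the dual variables, and `get_plan_from_dual` recovers the plan from them.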

```python { .api }
def smooth_ot_dual(a, b, M, reg, reg_type='l2', method='L-BFGS-B', stopThr=1e-9,
                   numItermax=500, verbose=False, log=False, max_nz=None):
    """
    Solve smooth optimal transport using the dual formulation.

    Parameters:
    - a: array-like, shape (n,), source distribution
    - b: array-like, shape (m,), target distribution
    - M: array-like, shape (n, m), cost matrix
    - reg: float, regularization strength
    - reg_type: str, 'kl' (negative entropy), 'l2' (squared L2) or
      'sparsity_constrained' (squared L2 with a per-column non-zero budget)
    - method: str, scipy.optimize.minimize method ('L-BFGS-B', 'SLSQP', etc.)
    - stopThr: float, stopping tolerance
    - numItermax: int, maximum number of iterations
    - verbose: bool, print optimization information
    - log: bool, also return an optimization log
    - max_nz: int or None, maximum number of non-zero entries per column of the plan
      (required when reg_type='sparsity_constrained')

    Returns:
    - transport plan of shape (n, m), or (plan, log) if log=True
    """

def solve_dual(a, b, C, regul, method='L-BFGS-B', tol=1e-3, max_iter=500, verbose=False):
    """
    Solve the dual optimal transport problem for a given regularizer.

    Parameters:
    - a, b: array-like, source and target distributions
    - C: array-like, shape (n, m), cost matrix
    - regul: Regularization, regularization instance (NegEntropy, SquaredL2, SparsityConstrained)
    - method: str, scipy.optimize.minimize method
    - tol: float, optimization tolerance
    - max_iter: int, maximum iterations
    - verbose: bool, print optimization info

    Returns:
    - alpha, beta: dual variables of shapes (n,) and (m,)
    """
```
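
The following sketch makes the dual concrete for the squared-L2 regularizer `Omega(t) = (gamma / 2) * ||t||^2`. It is written from the standard smooth-OT dual and is not POT's implementation. The function name `l2_smooth_ot_dual` is hypothetical, and the code assumes NumPy and SciPy. For this regularizer the conjugate is `||max(x, 0)||^2 / (2 * gamma)`, so the dual gradient reduces to the marginal violation of the candidate plan `T = max(alpha_i + beta_j - C_ij, 0) / gamma`.

```python
import numpy as np
from scipy.optimize import minimize


def l2_smooth_ot_dual(a, b, C, gamma, max_iter=500):
    """Hypothetical sketch: squared-L2 smooth OT solved in the dual with L-BFGS-B."""
    n, m = C.shape

    def neg_dual(ab):
        alpha, beta = ab[:n], ab[n:]
        # Candidate plan: maximizer of the conjugate at alpha_i + beta_j - C_ij
        T = np.maximum(alpha[:, None] + beta[None, :] - C, 0.0) / gamma
        # Dual: <alpha, a> + <beta, b> - sum_j ||max(x_j, 0)||^2 / (2 gamma),
        # and ||max(x, 0)||^2 / (2 gamma) equals (gamma / 2) * ||T column||^2
        obj = alpha @ a + beta @ b - 0.5 * gamma * np.sum(T ** 2)
        # Gradient of the dual = violation of the row and column marginals
        grad = np.concatenate([a - T.sum(axis=1), b - T.sum(axis=0)])
        return -obj, -grad  # SciPy minimizes, so negate the concave dual

    res = minimize(neg_dual, np.zeros(n + m), jac=True, method="L-BFGS-B",
                   options={"maxiter": max_iter, "ftol": 1e-14, "gtol": 1e-10})
    alpha, beta = res.x[:n], res.x[n:]
    T = np.maximum(alpha[:, None] + beta[None, :] - C, 0.0) / gamma
    return alpha, beta, T


rng = np.random.default_rng(0)
a = np.full(5, 1 / 5)
b = np.full(4, 1 / 4)
C = rng.random((5, 4))
alpha, beta, T = l2_smooth_ot_dual(a, b, C, gamma=0.1)
print(np.round(T, 3))
print("zero entries:", np.sum(T == 0))
```

The clipping at zero is the source of sparsity. Any pair with `alpha_i + beta_j <= C_ij` receives exactly zero mass, which entropic regularization never produces.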
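
### Semi-Dual Formulation Solvers

These functions solve the problem through the semi-dual, which optimizes over the source potential `alpha` only (n variables instead of n + m). The target marginal `b` is enforced exactly, column by column. As a result, each objective evaluation solves one small maximization over a scaled simplex for every column of the cost matrix. `smooth_ot_semi_dual` returns the plan, and `solve_semi_dual` returns `alpha`.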

```python { .api }
def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', max_nz=None, method='L-BFGS-B',
                        stopThr=1e-9, numItermax=500, verbose=False, log=False):
    """
    Solve smooth optimal transport using the semi-dual formulation.

    Parameters:
    - a: array-like, shape (n,), source distribution
    - b: array-like, shape (m,), target distribution
    - M: array-like, shape (n, m), cost matrix
    - reg: float, regularization strength
    - reg_type: str, 'kl' (negative entropy), 'l2' (squared L2) or
      'sparsity_constrained' (squared L2 with a per-column non-zero budget)
    - max_nz: int or None, maximum number of non-zero entries per column of the plan
      (required when reg_type='sparsity_constrained')
    - method: str, scipy.optimize.minimize method
    - stopThr: float, stopping tolerance
    - numItermax: int, maximum iterations
    - verbose: bool, print optimization information
    - log: bool, also return an optimization log

    Returns:
    - transport plan of shape (n, m), or (plan, log) if log=True
    """

def solve_semi_dual(a, b, C, regul, method='L-BFGS-B', tol=1e-3, max_iter=500, verbose=False):
    """
    Solve the semi-dual optimal transport problem for a given regularizer.

    Parameters:
    - a, b: array-like, source and target distributions
    - C: array-like, shape (n, m), cost matrix
    - regul: Regularization, regularization instance
    - method: str, scipy.optimize.minimize method
    - tol: float, optimization tolerance
    - max_iter: int, maximum iterations
    - verbose: bool, print optimization info

    Returns:
    - alpha: dual variable for the source, shape (n,)
    """
```
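
For the negative-entropy regularizer `Omega(t) = gamma * sum(t * log(t))`, the per-column maximization in the semi-dual has a closed form: column j of the plan is `b_j * softmax((alpha - C[:, j]) / gamma)`. The following sketch uses only NumPy and SciPy and optimizes over `alpha` alone. `entropic_semi_dual` is a hypothetical name, and the code is not POT's implementation. The column marginals hold by construction, so the optimizer only has to match the row marginals.

```python
import numpy as np
from scipy.optimize import minimize
from scipy.special import logsumexp


def entropic_semi_dual(a, b, C, gamma, max_iter=500):
    """Hypothetical sketch: negative-entropy smooth OT solved in the semi-dual."""

    def plan(alpha):
        # Column j = b_j * softmax((alpha - C[:, j]) / gamma), computed stably
        Z = (alpha[:, None] - C) / gamma
        return b[None, :] * np.exp(Z - logsumexp(Z, axis=0, keepdims=True))

    def neg_semi_dual(alpha):
        Z = (alpha[:, None] - C) / gamma
        # Semi-dual: <alpha, a> - sum_j gamma * b_j * logsumexp(Z[:, j]),
        # dropping the constant sum_j gamma * b_j * log(b_j)
        obj = alpha @ a - gamma * b @ logsumexp(Z, axis=0)
        grad = a - plan(alpha).sum(axis=1)  # row-marginal violation
        return -obj, -grad  # SciPy minimizes, so negate the concave objective

    res = minimize(neg_semi_dual, np.zeros(len(a)), jac=True, method="L-BFGS-B",
                   options={"maxiter": max_iter, "ftol": 1e-14, "gtol": 1e-10})
    return res.x, plan(res.x)


rng = np.random.default_rng(1)
a = np.full(6, 1 / 6)
b = np.full(3, 1 / 3)
C = rng.random((6, 3))
alpha, T = entropic_semi_dual(a, b, C, gamma=0.1)
print(np.round(T, 3))
```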
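
### Plan Recovery

These functions map dual variables back to a transport plan. The plan is read off from the maximizers that the regularizer's `delta_Omega` (dual) or `max_Omega` (semi-dual) computes at the dual variables. For this reason, pass the same `Regularization` instance that the solver used.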

```python { .api }
def get_plan_from_dual(alpha, beta, C, regul):
    """
    Recover the transport plan from dual variables (output of solve_dual).

    Parameters:
    - alpha: array-like, shape (n,), dual variable for the source
    - beta: array-like, shape (m,), dual variable for the target
    - C: array-like, shape (n, m), cost matrix
    - regul: Regularization, the instance used to compute alpha and beta

    Returns:
    - transport plan of shape (n, m)
    """

def get_plan_from_semi_dual(alpha, b, C, regul):
    """
    Recover the transport plan from the semi-dual variable (output of solve_semi_dual).

    Parameters:
    - alpha: array-like, shape (n,), dual variable for the source
    - b: array-like, shape (m,), target distribution
    - C: array-like, shape (n, m), cost matrix
    - regul: Regularization, the instance used to compute alpha

    Returns:
    - transport plan of shape (n, m)
    """
```
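
For the two closed-form regularizers, the dual-to-plan map is elementwise in `X = alpha_i + beta_j - C_ij`, so it can be sketched in a few lines of NumPy. `plan_from_dual_sketch` is a hypothetical helper. The formulas assume `Omega(t) = gamma * sum(t * log(t))` for `"kl"` and `Omega(t) = (gamma / 2) * ||t||^2` for `"l2"`. The dual values below are arbitrary and serve only to show the map, so the resulting plans do not satisfy any marginals. With identical inputs, the entropic plan is strictly positive, while the squared-L2 plan has exact zeros.

```python
import numpy as np


def plan_from_dual_sketch(alpha, beta, C, gamma, kind):
    """Hypothetical helper: plan = maximizer of the conjugate at alpha_i + beta_j - C_ij."""
    X = alpha[:, None] + beta[None, :] - C
    if kind == "kl":
        # argmax_{t >= 0} t * x - gamma * t * log(t) = exp(x / gamma - 1)
        return np.exp(X / gamma - 1.0)
    if kind == "l2":
        # argmax_{t >= 0} t * x - (gamma / 2) * t**2 = max(x, 0) / gamma
        return np.maximum(X, 0.0) / gamma
    raise ValueError(f"unknown kind: {kind!r}")


alpha = np.array([0.5, 0.2])
beta = np.array([0.1, 0.3, 0.0])
C = np.array([[0.2, 1.0, 0.9],
              [0.8, 0.1, 0.5]])
T_kl = plan_from_dual_sketch(alpha, beta, C, gamma=0.5, kind="kl")
T_l2 = plan_from_dual_sketch(alpha, beta, C, gamma=0.5, kind="l2")
print(np.round(T_l2, 3))  # only two non-zero entries, both 0.8
```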

### Utility Functions

`projection_simplex` computes the Euclidean projection onto the probability simplex scaled by `z`. The squared-L2 regularizer relies on it in its semi-dual `max_Omega`.

```python { .api }
def projection_simplex(V, z=1, axis=None):
    """
    Project V onto the simplex scaled by z, i.e. the set {y >= 0, sum(y) = z}.

    Parameters:
    - V: array-like, input array to project
    - z: float or array, scaling factor (one value per projected row or column if an array)
    - axis: None or int, projection axis (None: flatten, 1: row-wise, 0: column-wise)

    Returns:
    - projected array; same shape as V for axis=0 or axis=1, flattened for axis=None
    """
```
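
The projection has a well-known sort-based solution. First, sort each row in decreasing order. Next, find the largest support size `rho` for which the shifted entries stay positive. Finally, subtract a single threshold `theta` and clip at zero. The following NumPy sketch covers the row-wise case (`axis=1` in the API above) with a per-row scale `z`, and the column-wise case is the same function applied to `V.T`. The name `project_rows_onto_simplex` is hypothetical, and the sketch is not claimed to be POT's code.

```python
import numpy as np


def project_rows_onto_simplex(V, z=1.0):
    """Hypothetical sketch: project each row V[i] onto {y >= 0, sum(y) = z[i]}."""
    V = np.atleast_2d(np.asarray(V, dtype=float))
    n_rows, n_cols = V.shape
    z = np.broadcast_to(np.asarray(z, dtype=float), (n_rows,))
    U = -np.sort(-V, axis=1)                  # each row sorted in decreasing order
    cssv = np.cumsum(U, axis=1) - z[:, None]  # cumulative sums minus the target sum
    ind = np.arange(1, n_cols + 1)
    cond = U - cssv / ind > 0                 # True on a prefix of each row
    rho = np.count_nonzero(cond, axis=1)      # support size of each projection
    theta = cssv[np.arange(n_rows), rho - 1] / rho  # threshold that fixes the sum
    return np.maximum(V - theta[:, None], 0.0)


V = np.array([[0.5, 2.0, -1.0],
              [0.1, 0.2, 0.3]])
P_rows = project_rows_onto_simplex(V, z=[1.0, 2.0])           # like axis=1
P_cols = project_rows_onto_simplex(V.T, z=[1.0, 1.0, 1.0]).T  # like axis=0
print(P_rows)
```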

## Regularization Classes

This section covers the base class and the concrete regularizers accepted by `solve_dual`, `solve_semi_dual` and the plan-recovery functions.
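
```python { .api }
class Regularization:
    """
    Base class for regularization in smooth optimal transport.

    Parameters:
    - gamma: float, regularization strength

    Methods:
    - Omega(T): regularization value of a transport plan T
    - delta_Omega(X): for each column x of X, the value of max_{y >= 0} y^T x - Omega(y)
      and its maximizer (used by the dual)
    - max_Omega(X, b): for each column j, the value of max over the probability simplex of
      y^T X[:, j] - Omega(b[j] * y) / b[j] and its maximizer (used by the semi-dual)
    """

class NegEntropy(Regularization):
    """
    Negative entropy regularization, Omega(t) = gamma * sum(t * log(t)).
    Selected by reg_type='kl'; yields dense plans.

    Parameters:
    - gamma: float, regularization strength
    """

class SquaredL2(Regularization):
    """
    Squared L2 regularization, Omega(t) = 0.5 * gamma * ||t||^2.
    Selected by reg_type='l2'; typically yields sparse plans.

    Parameters:
    - gamma: float, regularization strength
    """

class SparsityConstrained(Regularization):
    """
    Squared L2 regularization with at most max_nz non-zero entries in each
    column of the transport plan. Selected by reg_type='sparsity_constrained'.

    Parameters:
    - max_nz: int, maximum number of non-zero entries per column
    - gamma: float, regularization strength
    """
```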
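
For reference, the block below writes out the problems these classes plug into, with closed forms for the two explicit regularizers. It uses the standard smooth-OT notation: `c_j` is column j of the cost matrix, `[x]_+ = max(x, 0)`, and `Δ_{b_j}` is the simplex scaled by `b_j`. Here `max_{Ω,j}` is written with `y` in `Δ_{b_j}`; the normalized form documented for `max_Omega` above equals this divided by `b_j`, and POT's internal scaling is not reproduced here.

$$
\begin{aligned}
\delta_\Omega(x) &= \max_{y \ge 0}\ y^\top x - \Omega(y),
\qquad \max\nolimits_{\Omega,j}(x) = \max_{y \ge 0,\ \mathbf{1}^\top y = b_j}\ y^\top x - \Omega(y) \\
\text{dual:}\ & \max_{\alpha,\beta}\ \alpha^\top a + \beta^\top b - \textstyle\sum_j \delta_\Omega(\alpha + \beta_j \mathbf{1} - c_j),
\qquad t_j = \nabla\delta_\Omega(\alpha + \beta_j \mathbf{1} - c_j) \\
\text{semi-dual:}\ & \max_{\alpha}\ \alpha^\top a - \textstyle\sum_j \max\nolimits_{\Omega,j}(\alpha - c_j),
\qquad t_j = \nabla\max\nolimits_{\Omega,j}(\alpha - c_j) \\
\Omega(y) = \gamma \textstyle\sum_i y_i \log y_i:\ & \nabla\delta_\Omega(x) = e^{x/\gamma - 1},
\qquad \nabla\max\nolimits_{\Omega,j}(x) = b_j\,\mathrm{softmax}(x/\gamma) \\
\Omega(y) = \tfrac{\gamma}{2}\lVert y\rVert_2^2:\ & \nabla\delta_\Omega(x) = [x]_+/\gamma,
\qquad \nabla\max\nolimits_{\Omega,j}(x) = P_{\Delta_{b_j}}(x/\gamma)
\end{aligned}
$$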

## Usage Example

```python
import numpy as np
import ot
import ot.smooth

# Create uniform distributions and a squared Euclidean cost matrix
rng = np.random.default_rng(0)
n, m = 100, 80
a = ot.utils.unif(n)
b = ot.utils.unif(m)
X = rng.standard_normal((n, 2))
Y = rng.standard_normal((m, 2))
C = ot.dist(X, Y)
C /= C.max()  # normalize costs so that regularization strengths are comparable

# High-level solvers take a float strength `reg` and a `reg_type` string
plan_dual = ot.smooth.smooth_ot_dual(a, b, C, reg=0.1, reg_type='kl')
plan_semi_dual = ot.smooth.smooth_ot_semi_dual(a, b, C, reg=0.1, reg_type='kl')
plan_l2 = ot.smooth.smooth_ot_dual(a, b, C, reg=0.01, reg_type='l2')

# Low-level solvers take a Regularization instance; the plan is recovered explicitly
regul = ot.smooth.NegEntropy(gamma=0.1)
alpha, beta = ot.smooth.solve_dual(a, b, C, regul)
plan_from_dual = ot.smooth.get_plan_from_dual(alpha, beta, C, regul)

# Sparsity-constrained transport: at most max_nz non-zero entries per column
regul_sparse = ot.smooth.SparsityConstrained(max_nz=2, gamma=0.1)
alpha_s, beta_s = ot.smooth.solve_dual(a, b, C, regul_sparse)
plan_sparse = ot.smooth.get_plan_from_dual(alpha_s, beta_s, C, regul_sparse)
```
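
A few numbers per plan make the plans above easier to compare: the worst marginal violation and how many entries are non-zero. The following NumPy helper is a hypothetical sketch, not part of POT. The threshold `tol` for counting an entry as non-zero is an assumption.

```python
import numpy as np


def plan_summary(T, a, b, tol=1e-12):
    """Hypothetical helper: marginal violations and sparsity of a transport plan."""
    T = np.asarray(T, dtype=float)
    nonzero = T > tol
    return {
        "row_error": float(np.abs(T.sum(axis=1) - a).max()),  # source marginal
        "col_error": float(np.abs(T.sum(axis=0) - b).max()),  # target marginal
        "max_nz_per_col": int(nonzero.sum(axis=0).max()),     # compare with max_nz
        "density": float(nonzero.mean()),                     # share of entries above tol
    }


# Toy plan: a diagonal coupling of two uniform 3-point distributions
a = b = np.full(3, 1 / 3)
T = np.eye(3) / 3
print(plan_summary(T, a, b))
```

For a sparsity-constrained plan, `max_nz_per_col` is the figure to compare against the `max_nz` passed to `SparsityConstrained`.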

## Import Statements

```python
import ot.smooth
from ot.smooth import smooth_ot_dual, smooth_ot_semi_dual, NegEntropy, SquaredL2
```