docs
# Weak Optimal Transport

Weak optimal transport provides a relaxed formulation of the classical optimal transport problem. Instead of paying for every individual displacement, it penalizes only the squared distance between each source sample and the barycenter (weighted mean) of the target samples that its mass is sent to. The spread of the transported mass around that barycenter is not penalized. This approach is particularly useful for applications where preserving local structure is more important than minimizing global transport costs.

## Capabilities

### Weak Optimal Transport Solver

Solve the weak optimal transport problem between two empirical distributions (weighted point clouds).

```python { .api }
def weak_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, G0=None, **kwargs):
    """
    Solve the weak optimal transport problem between two empirical distributions.

    For each source sample, weak OT penalizes only the squared distance between
    that sample and the barycenter of the target samples it is sent to (its
    barycentric projection):

        γ* = argmin_γ  Σ_i a_i ||X^a_i - (1/a_i) Σ_j γ_ij X^b_j||²

    subject to the standard transport constraints:

        - γ 1 = a      (source marginal constraint)
        - γ^T 1 = b    (target marginal constraint)
        - γ ≥ 0        (non-negativity)

    How the mass of X^a_i is spread around its barycenter is not penalized.
    The objective is a convex quadratic function of γ and is minimized with a
    conditional-gradient (Frank-Wolfe) solver.

    Parameters:
        - Xa: array-like, shape (n_samples_a, n_features), source samples
        - Xb: array-like, shape (n_samples_b, n_features), target samples
        - a: array-like, shape (n_samples_a,), source weights (uniform if None)
        - b: array-like, shape (n_samples_b,), target weights (uniform if None);
          a and b must have the same total mass
        - verbose: bool, print solver progress
        - log: bool, also return a dictionary with the optimization log
        - G0: array-like, shape (n_samples_a, n_samples_b), initial transport plan.
          It should satisfy the marginal constraints above. If None, the product
          coupling a b^T (independent joint distribution) is used, which is the
          uniform plan when a and b are uniform.
        - **kwargs: forwarded to the conditional-gradient solver
          (e.g. numItermax, stopThr)

    Returns:
        - G: array-like, shape (n_samples_a, n_samples_b), optimal transport plan
        - (G, log) if log=True, where log is a dict holding e.g. the objective
          value at each iteration under the key 'loss'
    """
```

## Theory and Applications

### Weak vs Classical Optimal Transport

**Classical Optimal Transport:**
- Minimizes total transport cost: `Σ_ij γ_ij C_ij`
- Optimal for minimizing global displacement
- Can create large local distortions

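
**Weak Optimal Transport:**
- Minimizes the barycentric cost `Σ_i a_i ||X^a_i - barycenter_i||²`, where `barycenter_i = (1/a_i) Σ_j γ_ij X^b_j`
- Does not penalize how each point's mass is spread around its barycenter. With `C_ij = ||X^a_i - X^b_j||²`, the classical cost of a plan equals its weak cost plus `Σ_i a_i Var_i`, where `Var_i` is the variance of the targets receiving point i's mass. The weak cost is therefore never larger than the classical one.
- Preserves local neighborhood structure
- Better for shape matching and morphing applications
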
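
### Key Properties

1. **Local Structure Preservation**: Maintains local relationships in source space
2. **Barycentric Transport**: Each source point maps to a barycenter of target points
3. **Variance-Free Cost**: Only the position of each barycenter is penalized. The spread of the transported mass around it costs nothing, which is what makes weak OT a relaxation of classical OT with squared Euclidean cost.
4. **Conditional Gradient**: The objective is a convex quadratic function of γ, so it is efficiently solved using Frank-Wolfe (conditional gradient) type algorithms
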
## Usage Examples

### Basic Weak Transport

```python
import ot
import numpy as np
import matplotlib.pyplot as plt

# Create 2D point clouds
n_source, n_target = 100, 120
rng = np.random.default_rng(42)

# Source: noisy unit circle.
# endpoint=False avoids placing two points at θ = 0 and θ = 2π (the same location).
theta_s = np.linspace(0, 2 * np.pi, n_source, endpoint=False)
Xa = np.column_stack([np.cos(theta_s), np.sin(theta_s)])
Xa += 0.1 * rng.standard_normal((n_source, 2))  # Add noise

# Target: noisy ellipse with semi-axes 2 and 0.5
theta_t = np.linspace(0, 2 * np.pi, n_target, endpoint=False)
Xb = np.column_stack([2 * np.cos(theta_t), 0.5 * np.sin(theta_t)])
Xb += 0.1 * rng.standard_normal((n_target, 2))

# Solve weak optimal transport (a and b omitted -> uniform weights)
plan_weak = ot.weak_optimal_transport(Xa, Xb, verbose=True, log=False)

# With uniform weights, rows should sum to 1/n_source and columns to 1/n_target
print(f"Transport plan shape: {plan_weak.shape}")
print(f"Plan sum: {np.sum(plan_weak):.6f}")
print(f"Source marginal error: {np.max(np.abs(np.sum(plan_weak, axis=1) - 1 / n_source)):.6f}")
print(f"Target marginal error: {np.max(np.abs(np.sum(plan_weak, axis=0) - 1 / n_target)):.6f}")
```

### Comparison with Classical Transport

The following snippet reuses `Xa`, `Xb`, `n_source` and `n_target` from the basic example.

```python
# Compare weak vs classical optimal transport
a = ot.utils.unif(n_source)
b = ot.utils.unif(n_target)

# Classical transport (ot.dist defaults to the squared Euclidean cost)
M = ot.dist(Xa, Xb)
plan_classical = ot.emd(a, b, M)

# Weak transport
plan_weak = ot.weak_optimal_transport(Xa, Xb, a, b)


def plot_plan(ax, plan, a, Xa, Xb, fmt, title, step=5, min_frac=0.05):
    """Draw the links of a transport plan for every `step`-th source point.

    Every entry satisfies plan[i, j] <= min(a[i], b[j]). With 100 source and
    120 uniform target points, no entry exceeds 1/120, so an absolute cut-off
    such as 0.01 would hide every link. Links are therefore filtered and shaded
    by the fraction of source point i's mass they carry, plan[i, j] / a[i].
    That fraction lies in [0, 1] and can be used directly as a matplotlib alpha.
    """
    frac = plan / a[:, None]
    for i in range(0, len(Xa), step):  # show a subset of source points
        for j in np.flatnonzero(frac[i] > min_frac):
            ax.plot([Xa[i, 0], Xb[j, 0]], [Xa[i, 1], Xb[j, 1]], fmt,
                    alpha=float(min(frac[i, j], 1.0)), linewidth=0.5)
    ax.scatter(Xa[:, 0], Xa[:, 1], c='blue', s=20)
    ax.scatter(Xb[:, 0], Xb[:, 1], c='red', s=20)
    ax.set_title(title)


# Visualize differences
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Source and target
axes[0].scatter(Xa[:, 0], Xa[:, 1], c='blue', alpha=0.6, label='Source')
axes[0].scatter(Xb[:, 0], Xb[:, 1], c='red', alpha=0.6, label='Target')
axes[0].set_title('Source and Target')
axes[0].legend()

# Classical vs weak transport links
plot_plan(axes[1], plan_classical, a, Xa, Xb, 'k-', 'Classical OT')
plot_plan(axes[2], plan_weak, a, Xa, Xb, 'g-', 'Weak OT')

plt.tight_layout()
plt.show()
```

### Shape Morphing Application

This snippet also reuses `Xa` and `Xb` from the basic example.

```python
# Use weak transport for shape morphing
def interpolate_shapes(Xa, Xb, t=0.5, plan=None):
    """Interpolate between two point clouds using weak optimal transport.

    Each source point i moves linearly toward its barycentric projection
    barycenter[i] = Σ_j plan[i, j] Xb[j] / Σ_j plan[i, j]:

        X_t[i] = (1 - t) * Xa[i] + t * barycenter[i]

    Parameters:
        - Xa: array, shape (n_a, d), source points
        - Xb: array, shape (n_b, d), target points
        - t: float, interpolation time (0 gives Xa, 1 gives the barycenters)
        - plan: optional precomputed transport plan, shape (n_a, n_b). Pass it
          when interpolating at several times so weak OT is solved only once.
          If None, ot.weak_optimal_transport(Xa, Xb) is solved here.

    Returns:
        - array, shape (n_a, d), interpolated points. Source points whose row of
          the plan carries no mass stay at their original position.
    """
    if plan is None:
        plan = ot.weak_optimal_transport(Xa, Xb)

    row_mass = plan.sum(axis=1)
    has_mass = row_mass > 0
    # Float copy of Xa: rows without mass keep their source position, and
    # integer inputs are not truncated.
    barycenters = np.array(Xa, dtype=float)
    barycenters[has_mass] = plan[has_mass] @ Xb / row_mass[has_mass, None]

    # Linear interpolation
    return (1 - t) * Xa + t * barycenters


# Create morphing sequence. The plan does not depend on t, so solve weak OT once.
plan_morph = ot.weak_optimal_transport(Xa, Xb)
n_steps = 10
times = np.linspace(0, 1, n_steps + 1)
morphing_sequence = [interpolate_shapes(Xa, Xb, t, plan=plan_morph) for t in times]

# Visualize morphing: one frame per subplot, grid sized from the number of frames
n_cols = 6
n_rows = -(-len(times) // n_cols)  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False)
axes = axes.flatten()

for ax, t, shape in zip(axes, times, morphing_sequence):
    ax.scatter(shape[:, 0], shape[:, 1], c='purple', alpha=0.7, s=20)
    ax.set_title(f't = {t:.1f}')
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)

# Hide grid cells left over when there are fewer frames than subplots
for ax in axes[len(morphing_sequence):]:
    ax.set_visible(False)

plt.tight_layout()
plt.show()
```

### Advanced Usage with Custom Parameters

This snippet requires scikit-learn for the nearest-neighbor search and reuses the data from the basic example.

```python
# Advanced usage with custom initialization and logging
import time

# Custom initial transport plan based on nearest neighbors (requires scikit-learn)
from sklearn.neighbors import NearestNeighbors

a = ot.utils.unif(n_source)
b = ot.utils.unif(n_target)

nn = NearestNeighbors(n_neighbors=3)
nn.fit(Xb)
_, indices = nn.kneighbors(Xa)  # indices[i]: the 3 targets closest to source i

# G0 must be a feasible transport plan: rows sum to a and columns sum to b.
# Conditional-gradient iterates are convex combinations of G0 and feasible
# plans, so an infeasible G0 (e.g. rows summing to 1) generally gives an
# infeasible result. Build a kernel that favours each source point's nearest
# targets, with a small floor elsewhere to keep it strictly positive. Then
# rescale it onto the marginals with Sinkhorn iterations.
K = np.full((n_source, n_target), 1e-2)
K[np.arange(n_source)[:, None], indices] = 1.0
u = np.ones(n_source)
v = np.ones(n_target)
for _ in range(1000):
    u = a / (K @ v)
    v = b / (K.T @ u)
G0 = u[:, None] * K * v[None, :]
print(f"G0 marginal errors: {np.abs(G0.sum(axis=1) - a).max():.2e} (rows), "
      f"{np.abs(G0.sum(axis=0) - b).max():.2e} (columns)")

# Solve with custom initialization and detailed logging
start_time = time.perf_counter()  # monotonic clock for measuring durations
plan, log = ot.weak_optimal_transport(
    Xa, Xb,
    a=a,
    b=b,
    G0=G0,
    verbose=True,
    log=True,
    numItermax=1000,
    stopThr=1e-9,
)
solve_time = time.perf_counter() - start_time

print(f"Solver completed in {solve_time:.3f} seconds")
print(f"Final objective: {log['loss'][-1]:.6f}")
# log['loss'] also records the objective of the initial plan G0, hence the -1
print(f"Number of iterations: {len(log['loss']) - 1}")

# Plot convergence
plt.figure(figsize=(10, 6))
plt.subplot(1, 2, 1)
plt.semilogy(log['loss'])
plt.xlabel('Iteration')
plt.ylabel('Objective value')
plt.title('Convergence of Weak OT')
plt.grid(True)

plt.subplot(1, 2, 2)
plt.imshow(plan, cmap='Blues', aspect='auto')
plt.colorbar()
plt.xlabel('Target samples')
plt.ylabel('Source samples')
plt.title('Transport Plan')
plt.tight_layout()
plt.show()
```

## Import Statements

The solver is the same function in each case and can be accessed in any of the following ways:

```python
238
import ot
239
from ot import weak_optimal_transport
240
from ot.weak import weak_optimal_transport
241
```