# Factored Optimal Transport
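
Factored optimal transport approximates the transport plan between two point clouds by routing all mass through an intermediate distribution supported on a small number r of points. The plan is then described by two smaller plans (source to intermediate, and intermediate to target). For large problems whose data have cluster or low-rank structure, this can substantially reduce memory and computation. In general the factored plan is not the exact optimal plan between the samples. It is a structured approximation, and it is most useful when such structure is present.

## Capabilities

### Factored Optimal Transport Solver

Solve optimal transport problems by factoring the transport plan through an intermediate distribution.
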
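```python { .api }
def factored_optimal_transport(Xa, Xb, a=None, b=None, reg=0.0, r=100, X0=None,
                               stopThr=1e-7, numItermax=100, verbose=False,
                               log=False, **kwargs):
    """
    Solve optimal transport with a factored coupling.

    The transport plan is factored through an intermediate empirical
    distribution mu with r support points X. The solver alternates between
    solving OT(mu_a, mu) and OT(mu, mu_b) and updating X, which approximately
    minimizes W_2^2(mu_a, mu) + W_2^2(mu, mu_b). Each iteration solves an
    (n_samples_a x r) problem and an (r x n_samples_b) problem instead of one
    (n_samples_a x n_samples_b) problem. The two factors hold
    (n_samples_a + n_samples_b) * r entries instead of
    n_samples_a * n_samples_b, which is a large saving when r is much
    smaller than both sample counts.

    Parameters:
    - Xa: array-like, shape (n_samples_a, n_features), source samples
    - Xb: array-like, shape (n_samples_b, n_features), target samples
    - a: array-like, shape (n_samples_a,), source distribution (uniform if None)
    - b: array-like, shape (n_samples_b,), target distribution (uniform if None)
    - reg: float, entropic regularization of the two sub-problems
      (0 solves them exactly with EMD, > 0 uses Sinkhorn)
    - r: int, number of support points of the intermediate distribution,
      i.e. the factorization rank
    - X0: array-like, shape (r, n_features), initial intermediate support
      (random if None)
    - stopThr: float, stop when the change in X between iterations falls
      below this threshold
    - numItermax: int, maximum number of iterations
    - verbose: bool, print optimization information
    - log: bool, also return a log of the optimization
    - **kwargs: passed to the underlying EMD or Sinkhorn solver

    Returns:
    - Ga: array, shape (n_samples_a, r), plan between the source and the
      intermediate distribution
    - Gb: array, shape (r, n_samples_b), plan between the intermediate
      distribution and the target
    - X: array, shape (r, n_features), support of the intermediate distribution
    - log: dict, only if log=True, with the costs and dual variables of the
      two sub-problems

    The intermediate points carry uniform mass 1/r. The full
    (n_samples_a, n_samples_b) plan can therefore be recovered as r * Ga @ Gb.
    """
```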
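
The alternating scheme behind factored couplings can be sketched in plain NumPy. This is a minimal illustration and not POT's implementation. It assumes entropic (Sinkhorn) sub-solvers, squared Euclidean costs, uniform mass 1/r on the intermediate points, and a simple deterministic initialization. The names `factored_ot_sketch`, `sinkhorn_plan` and `sq_dists` are hypothetical. Each outer step first solves the source-to-intermediate and intermediate-to-target problems. It then moves every intermediate point to the minimizer of its transport cost. Because each point sends and receives mass 1/r, that minimizer is `r/2 * (Ga.T @ Xa + Gb @ Xb)`.

```python
import numpy as np


def sinkhorn_plan(a, b, M, reg, n_iter=500):
    """Entropic OT plan with marginals a and b for cost matrix M (hypothetical helper)."""
    K = np.exp(-M / reg)
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)
        u = a / (K @ v)
    return u[:, None] * K * v[None, :]


def sq_dists(X, Y):
    """Pairwise squared Euclidean distances, shape (len(X), len(Y))."""
    return ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)


def factored_ot_sketch(Xa, Xb, r=2, reg=1.0, n_outer=20):
    """Hypothetical sketch: factor the plan through r intermediate points.

    Returns Ga (n_a, r), Gb (r, n_b) and the intermediate support X (r, d).
    """
    a = np.full(len(Xa), 1.0 / len(Xa))
    b = np.full(len(Xb), 1.0 / len(Xb))
    w = np.full(r, 1.0 / r)  # assumed uniform mass on the intermediate points
    # Deterministic init (an assumption): r source points spread over the index range
    X = Xa[np.linspace(0, len(Xa) - 1, r).astype(int)].copy()
    for _ in range(n_outer):
        Ga = sinkhorn_plan(a, w, sq_dists(Xa, X), reg)
        Gb = sinkhorn_plan(w, b, sq_dists(X, Xb), reg)
        # Closed-form minimizer of each point's cost given Ga and Gb,
        # using that each intermediate point sends and receives mass 1/r
        X = 0.5 * r * (Ga.T @ Xa + Gb @ Xb)
    return Ga, Gb, X


# Usage: two well-separated clusters, target shifted by a constant offset
rng = np.random.default_rng(1)
Xa = np.vstack([rng.normal(0.0, 0.3, (30, 2)), rng.normal(3.0, 0.3, (30, 2))])
Xb = Xa + np.array([1.0, 0.5])
Ga, Gb, X = factored_ot_sketch(Xa, Xb, r=2)
plan = 2 * Ga @ Gb  # full 60 x 60 coupling, rank at most 2
print(X.round(2))  # intermediate points sit between matching source/target clusters
```

The composed plan keeps the source marginal exactly, and it keeps the target marginal up to Sinkhorn convergence. Its rank is bounded by r.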

## Factorization Approaches

### Low-Rank Transport Plans

When the data are clustered or concentrated near a low-dimensional structure, a good transport plan can often be approximated by a low-rank one:

- **Standard Transport Plan**: `γ ∈ R^(n×m)`, storing O(nm) entries
- **Factored Transport Plan**: `γ ≈ UV^T` where `U ∈ R^(n×k)`, `V ∈ R^(m×k)`, storing O((n+m)k) entries
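
The factored form `γ ≈ UV^T` can be made concrete with a small NumPy sketch. As an illustrative assumption, take `U = k * Ga` and `V^T = Gb`. Here Ga couples the n sources to k intermediate points of mass 1/k, and Gb couples those points to the m targets. The helper names `compose_factored_plan` and `plan_storage` are hypothetical. The composed plan keeps both marginals and has rank at most k. For n = 1000, m = 1200 and k = 100, it needs 220,000 stored entries instead of 1,200,000.

```python
import numpy as np


def compose_factored_plan(Ga, Gb):
    """Full coupling k * Ga @ Gb, assuming each of the k intermediate points has mass 1/k."""
    k = Ga.shape[1]
    return k * Ga @ Gb


def plan_storage(n, m, k):
    """Stored entries for a dense n x m plan vs. its n x k and k x m factors."""
    return n * m, n * k + k * m


# Toy factors: 4 source points and 6 target points routed through k = 2 points.
# Ga has row sums 1/4 and column sums 1/2; Gb has row sums 1/2 and column sums 1/6.
Ga = np.array([[0.25, 0.0], [0.25, 0.0], [0.0, 0.25], [0.0, 0.25]])
Gb = np.array([[1/6, 1/6, 1/6, 0, 0, 0], [0, 0, 0, 1/6, 1/6, 1/6]])
gamma = compose_factored_plan(Ga, Gb)

print(gamma.sum(axis=1))  # source marginal: four entries of 0.25
print(gamma.sum(axis=0))  # target marginal: six entries of 1/6
print(np.linalg.matrix_rank(gamma))  # 2
print(plan_storage(1000, 1200, 100))  # (1200000, 220000)
```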

### Structured Data Scenarios

Factored transport is most likely to help when the data have structure such as:

1. **Gaussian Mixtures**: Mass concentrated in a few components can be routed through a few intermediate points
2. **Time Series**: Temporal structure enables efficient factorization
3. **Images**: Spatial correlation allows patch-based factorization
4. **Graph Data**: Community structure supports block-wise transport
5. **High-dimensional Data**: Manifold structure enables dimensionality reduction

## Usage Examples

### Basic Factored Transport

```python
import ot
import numpy as np
import matplotlib.pyplot as plt

# Create high-dimensional data with low-rank structure
n_source, n_target = 1000, 1200
n_features = 50
rank = 5

# Generate low-rank source data
U_source = np.random.randn(n_source, rank)
V_source = np.random.randn(rank, n_features)
Xa = U_source @ V_source + 0.1 * np.random.randn(n_source, n_features)

# Generate related target data
U_target = U_source + 0.5 * np.random.randn(n_source, rank)
U_target = np.vstack([U_target, np.random.randn(n_target - n_source, rank)])
V_target = V_source + 0.3 * np.random.randn(rank, n_features)
Xb = U_target @ V_target + 0.1 * np.random.randn(n_target, n_features)

# Solve using factored transport: Ga (n_source, r) and Gb (r, n_target) couple
# each side to the r support points X_mid of an intermediate distribution
Ga, Gb, X_mid = ot.factored_optimal_transport(Xa, Xb, verbose=True, log=False)

# Recompose the full plan: the intermediate points carry uniform mass 1/r,
# so r * Ga @ Gb has the source and target marginals
r = Ga.shape[1]
plan_factored = r * Ga @ Gb

print(f"Factor shapes: Ga {Ga.shape}, Gb {Gb.shape}, X_mid {X_mid.shape}")
print(f"Transport plan shape: {plan_factored.shape}")
print(f"Plan density (fraction of entries > 1e-8): {np.sum(plan_factored > 1e-8) / plan_factored.size:.4f}")
```

### Comparison with Standard Methods

```python
# Compare computational efficiency
import time

# Standard optimal transport (for smaller problem)
n_small = 200
Xa_small = Xa[:n_small]
Xb_small = Xb[:n_small]

a_small = ot.utils.unif(n_small)
b_small = ot.utils.unif(n_small)

# Squared Euclidean cost, normalized so that reg=0.1 is on a sensible scale
# (the raw costs of this data are in the hundreds, and exp(-M / reg) can underflow)
M_small = ot.dist(Xa_small, Xb_small)
M_small /= M_small.max()

# Standard EMD
start_time = time.perf_counter()
plan_emd = ot.emd(a_small, b_small, M_small)
time_emd = time.perf_counter() - start_time

# Sinkhorn
start_time = time.perf_counter()
plan_sinkhorn = ot.sinkhorn(a_small, b_small, M_small, reg=0.1)
time_sinkhorn = time.perf_counter() - start_time

# Factored transport (same samples; it builds its own cost from Xa_small, Xb_small)
start_time = time.perf_counter()
Ga_small, Gb_small, _ = ot.factored_optimal_transport(Xa_small, Xb_small)
time_factored = time.perf_counter() - start_time

print(f"Timing comparison (n={n_small}):")
print(f"  EMD: {time_emd:.4f}s")
print(f"  Sinkhorn: {time_sinkhorn:.4f}s")
print(f"  Factored: {time_factored:.4f}s")

# Full-size problem: dense EMD is still possible at this size, but the
# factored plan stores only (n_source + n_target) * r entries
print(f"\nFull-size problem (n_source={n_source}, n_target={n_target}):")
start_time = time.perf_counter()
Ga_large, Gb_large, _ = ot.factored_optimal_transport(Xa, Xb, verbose=False)
time_large = time.perf_counter() - start_time
print(f"  Factored transport: {time_large:.4f}s")
```

### Gaussian Mixture Example

```python
# Example with Gaussian mixtures (natural factorization)
n_components = 3
n_samples_per_comp = 300

# Source mixture
Xa_gmm = np.vstack([
    np.random.multivariate_normal([0, 0], [[1, 0.5], [0.5, 1]], n_samples_per_comp),
    np.random.multivariate_normal([3, 0], [[1, -0.3], [-0.3, 1]], n_samples_per_comp),
    np.random.multivariate_normal([1.5, 3], [[0.8, 0.2], [0.2, 0.8]], n_samples_per_comp),
])

# Target mixture (shifted and rotated)
theta = np.pi / 6
R = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
Xb_gmm = np.vstack([
    np.random.multivariate_normal([1, 1], [[1.2, 0.4], [0.4, 1.2]], n_samples_per_comp),
    np.random.multivariate_normal([4, 1], [[1, -0.4], [-0.4, 1]], n_samples_per_comp),
    np.random.multivariate_normal([2.5, 4], [[0.9, 0.3], [0.3, 0.9]], n_samples_per_comp),
]) @ R.T

# Solve with factored transport, one intermediate point per mixture component
Ga_gmm, Gb_gmm, X_gmm, log_gmm = ot.factored_optimal_transport(
    Xa_gmm, Xb_gmm,
    r=n_components,
    verbose=True,
    log=True
)
plan_gmm = n_components * Ga_gmm @ Gb_gmm  # full (900, 900) plan

# Visualize results
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Source data with the intermediate support (black crosses)
axes[0].scatter(Xa_gmm[:, 0], Xa_gmm[:, 1], alpha=0.6, c='blue')
axes[0].scatter(X_gmm[:, 0], X_gmm[:, 1], c='black', marker='x', s=100)
axes[0].set_title('Source Distribution')
axes[0].set_aspect('equal')

# Target data with the intermediate support (black crosses)
axes[1].scatter(Xb_gmm[:, 0], Xb_gmm[:, 1], alpha=0.6, c='red')
axes[1].scatter(X_gmm[:, 0], X_gmm[:, 1], c='black', marker='x', s=100)
axes[1].set_title('Target Distribution')
axes[1].set_aspect('equal')

# Transport plan visualization
im = axes[2].imshow(plan_gmm, cmap='Blues', aspect='auto')
axes[2].set_xlabel('Target samples')
axes[2].set_ylabel('Source samples')
axes[2].set_title('Factored Transport Plan')
plt.colorbar(im, ax=axes[2])

plt.tight_layout()
plt.show()

print(f"Factorization rank r: {Ga_gmm.shape[1]}")
print(f"Log entries: {sorted(log_gmm)}")
```

### Time Series Transport

```python
# Example with time series data
from sklearn.decomposition import PCA

# Generate time series with shared temporal patterns
t = np.linspace(0, 10, 100)
n_series_source = 200
n_series_target = 250

# Base temporal patterns, shape (len(t), 4)
patterns = np.array([
    np.sin(t),
    np.cos(t),
    np.sin(2*t),
    np.exp(-t/5) * np.sin(t)
]).T

# Source time series (linear combinations of patterns)
weights_source = np.random.exponential(1, (n_series_source, 4))
Xa_ts = weights_source @ patterns.T + 0.1 * np.random.randn(n_series_source, len(t))

# Target time series (shifted patterns)
weights_target = np.random.exponential(1.2, (n_series_target, 4))
patterns_shifted = np.roll(patterns, 5, axis=0)  # Temporal shift
Xb_ts = weights_target @ patterns_shifted.T + 0.1 * np.random.randn(n_series_target, len(t))

# Apply PCA preprocessing (fitted on the source) to enhance structure
pca = PCA(n_components=10)
Xa_ts_pca = pca.fit_transform(Xa_ts)
Xb_ts_pca = pca.transform(Xb_ts)

# Factored transport on time series, then recompose the full plan
Ga_ts, Gb_ts, _ = ot.factored_optimal_transport(Xa_ts_pca, Xb_ts_pca, verbose=True)
plan_ts = Ga_ts.shape[1] * Ga_ts @ Gb_ts

# Visualize sample time series and their transport
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Sample source time series
for i in range(5):
    axes[0, 0].plot(t, Xa_ts[i], alpha=0.7)
axes[0, 0].set_title('Sample Source Time Series')
axes[0, 0].set_xlabel('Time')

# Sample target time series
for i in range(5):
    axes[0, 1].plot(t, Xb_ts[i], alpha=0.7)
axes[0, 1].set_title('Sample Target Time Series')
axes[0, 1].set_xlabel('Time')

# PCA representation
axes[1, 0].scatter(Xa_ts_pca[:, 0], Xa_ts_pca[:, 1], alpha=0.6, label='Source')
axes[1, 0].scatter(Xb_ts_pca[:, 0], Xb_ts_pca[:, 1], alpha=0.6, label='Target')
axes[1, 0].set_xlabel('PC1')
axes[1, 0].set_ylabel('PC2')
axes[1, 0].set_title('PCA Representation')
axes[1, 0].legend()

# Transport plan sparsity pattern
axes[1, 1].spy(plan_ts > 1e-6, markersize=0.1)
axes[1, 1].set_title('Transport Plan Sparsity')
axes[1, 1].set_xlabel('Target series')
axes[1, 1].set_ylabel('Source series')

plt.tight_layout()
plt.show()
```
255
256
## Import Statements
257
258
```python
259
import ot
260
from ot import factored_optimal_transport
261
from ot.factored import factored_optimal_transport
262
```