# Unified Solvers

High-level unified interface to optimal transport (OT) solvers, providing a consistent API across problem types and algorithms. The solvers choose an algorithm from the problem's characteristics (such as problem size, whether regularization is requested, and whether a feature cost matrix is supplied) and from the user's parameters.

## Capabilities

### General Optimal Transport Solver

Unified solver for standard optimal transport problems, with automatic algorithm selection.

```python { .api }
def solve(a, b, M, reg=None, reg_type='entropy', method='auto', numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs):
    """
    General optimal transport solver with automatic method selection.

    This function provides a unified interface to various OT solvers, automatically
    selecting the most appropriate algorithm based on problem size, regularization,
    and other parameters.

    Parameters:
    - a: array-like, shape (n,), source distribution
    - b: array-like, shape (m,), target distribution
    - M: array-like, shape (n, m), cost matrix
    - reg: float, regularization parameter (None for exact transport)
    - reg_type: str, regularization type ('entropy', 'l2', 'kl', 'tv')
    - method: str, solver method ('auto', 'emd', 'sinkhorn', 'sinkhorn_log',
      'sinkhorn_stabilized', 'sinkhorn_epsilon_scaling', 'smooth')
    - numItermax: int, maximum number of iterations
    - stopThr: float, convergence threshold
    - verbose: bool, print solver information
    - log: bool, return optimization log

    Returns:
    - transport plan matrix, or (plan, log) if log=True
    """
```
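
With `reg=None`, `solve` computes exact (unregularized) transport, which is a linear program. For intuition only, here is a minimal sketch assuming both distributions are uniform over the same number of points n: some optimal plan is then a permutation matrix scaled by 1/n (Birkhoff-von Neumann theorem), so for tiny n it can be found by enumerating permutations. `exact_ot_uniform` is a hypothetical helper, not part of the library.

```python
import itertools

import numpy as np


def exact_ot_uniform(M):
    """Exact OT between two uniform distributions of equal size n, by brute force.

    Hypothetical helper for illustration: it enumerates all n! permutations,
    so it is only usable for tiny n.
    """
    M = np.asarray(M, dtype=float)
    n, m = M.shape
    if n != m:
        raise ValueError("this sketch assumes two uniform distributions of equal size")
    rows = np.arange(n)
    best_cost, best_cols = np.inf, None
    for perm in itertools.permutations(range(n)):
        cols = np.array(perm)
        # Each source point carries mass 1/n and sends all of it to one target point
        cost = M[rows, cols].sum() / n
        if cost < best_cost:
            best_cost, best_cols = cost, cols
    plan = np.zeros((n, n))
    plan[rows, best_cols] = 1.0 / n
    return plan, best_cost


# 1-D example with squared Euclidean cost: the optimal matching is monotone
x = np.array([0.0, 1.0, 2.0])
y = np.array([2.5, 0.5, 1.5])
M = (x[:, None] - y[None, :]) ** 2
plan, cost = exact_ot_uniform(M)
print(cost)  # 0.25
```

Enumeration costs O(n! * n), which is why practical exact solvers rely on network-flow (linear programming) algorithms instead.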

### Gromov-Wasserstein Solver

Unified solver for Gromov-Wasserstein (GW) problems and their variants, including Fused Gromov-Wasserstein (FGW).

```python { .api }
def solve_gromov(C1, C2, p=None, q=None, M=None, alpha=0.0, reg=None, reg_type='entropy', method='auto', loss_fun='square_loss', armijo=False, numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs):
    """
    General Gromov-Wasserstein solver with automatic method selection.

    Solves Gromov-Wasserstein and Fused Gromov-Wasserstein problems using
    appropriate algorithms based on regularization and problem characteristics.

    Parameters:
    - C1: array-like, shape (ns, ns), cost matrix for source space
    - C2: array-like, shape (nt, nt), cost matrix for target space
    - p: array-like, shape (ns,), source distribution (uniform if None)
    - q: array-like, shape (nt,), target distribution (uniform if None)
    - M: array-like, shape (ns, nt), feature cost matrix (for Fused GW; None for pure GW)
    - alpha: float, trade-off parameter between structure and features (0=pure GW, 1=pure Wasserstein)
    - reg: float, regularization parameter (None for the unregularized problem)
    - reg_type: str, regularization type ('entropy', 'l2')
    - method: str, solver method ('auto', 'conditional_gradient', 'proximal_point', 'frank_wolfe')
    - loss_fun: str or callable, loss function ('square_loss', 'kl_loss')
    - armijo: bool, use Armijo line search
    - numItermax: int, maximum iterations
    - stopThr: float, convergence threshold
    - verbose: bool, print information
    - log: bool, return optimization log

    Returns:
    - transport plan matrix, or (plan, log) if log=True
    """
```
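
With `loss_fun='square_loss'`, the GW objective of a coupling T is the sum over i, j, k, l of `(C1[i, k] - C2[j, l])**2 * T[i, j] * T[k, l]`. Expanding the square splits it into two terms that depend only on the marginals of T plus a cross term built from `C1 @ T @ C2.T`, which avoids forming the 4-D tensor. The sketch below evaluates the objective both ways for a fixed plan; `gw_square_loss` and `gw_square_loss_bruteforce` are hypothetical helpers for illustration and do not optimize anything.

```python
import numpy as np


def gw_square_loss(C1, C2, T):
    """Square-loss GW objective of coupling T without forming the 4-D tensor."""
    p, q = T.sum(axis=1), T.sum(axis=0)  # marginals of T
    # Terms that depend only on the marginals of T
    const = p @ (C1 ** 2) @ p + q @ (C2 ** 2) @ q
    # Cross term: sum_ij T[i, j] * (C1 @ T @ C2.T)[i, j]
    cross = np.sum(T * (C1 @ T @ C2.T))
    return const - 2.0 * cross


def gw_square_loss_bruteforce(C1, C2, T):
    """Same objective, summing over all (i, j, k, l) explicitly."""
    # L[i, j, k, l] = (C1[i, k] - C2[j, l]) ** 2, built by broadcasting
    L = (C1[:, None, :, None] - C2[None, :, None, :]) ** 2
    return np.einsum("ijkl,ij,kl->", L, T, T)


rng = np.random.default_rng(0)
xs, xt = rng.normal(size=(5, 2)), rng.normal(size=(4, 2))
C1 = ((xs[:, None, :] - xs[None, :, :]) ** 2).sum(axis=-1)  # source structure
C2 = ((xt[:, None, :] - xt[None, :, :]) ** 2).sum(axis=-1)  # target structure
T = np.outer(np.full(5, 1 / 5), np.full(4, 1 / 4))  # independent coupling
print(np.isclose(gw_square_loss(C1, C2, T), gw_square_loss_bruteforce(C1, C2, T)))  # True
```

The factored form costs O(ns^2 nt + ns nt^2) per evaluation instead of O(ns^2 nt^2), which is what makes iterative GW solvers practical.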

### Sampling-based Solver

Solver for large-scale problems that takes raw samples instead of a precomputed cost matrix and relies on sampling-based approximations.

```python { .api }
def solve_sample(X_s, X_t, a=None, b=None, method='gromov_wasserstein_samples', reg=None, numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs):
    """
    Solve optimal transport using sampling-based methods.

    Efficient solver for large-scale problems using sampling techniques
    to approximate optimal transport distances and plans.

    Parameters:
    - X_s: array-like, source samples (n_samples_s, n_features)
    - X_t: array-like, target samples (n_samples_t, n_features)
    - a: array-like, shape (n_samples_s,), source weights (uniform if None)
    - b: array-like, shape (n_samples_t,), target weights (uniform if None)
    - method: str, sampling method ('gromov_wasserstein_samples', 'sliced_wasserstein',
      'max_sliced_wasserstein')
    - reg: float, regularization parameter
    - numItermax: int, maximum iterations
    - stopThr: float, convergence threshold
    - verbose: bool, print information
    - log: bool, return optimization log

    Returns:
    - transport plan or distance depending on method, or (result, log) if log=True
    """
```
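
The `'sliced_wasserstein'` method rests on the fact that 1-D optimal transport reduces to sorting. A Monte Carlo sketch of the sliced 2-Wasserstein distance, assuming equal sample counts with uniform weights: project both point clouds onto random unit directions, sort each projection, and average the squared gaps. `sliced_wasserstein_sketch` and its `n_projections` and `seed` parameters are hypothetical, not the `solve_sample` API.

```python
import numpy as np


def sliced_wasserstein_sketch(X_s, X_t, n_projections=50, seed=0):
    """Monte Carlo sliced 2-Wasserstein distance between equal-size uniform samples."""
    X_s, X_t = np.asarray(X_s, dtype=float), np.asarray(X_t, dtype=float)
    if X_s.shape != X_t.shape:
        raise ValueError("this sketch assumes equal sample counts and dimensions")
    rng = np.random.default_rng(seed)
    # Random directions, uniform on the unit sphere
    theta = rng.normal(size=(n_projections, X_s.shape[1]))
    theta /= np.linalg.norm(theta, axis=1, keepdims=True)
    # Project onto every direction (one column each) and sort each 1-D projection
    proj_s = np.sort(X_s @ theta.T, axis=0)
    proj_t = np.sort(X_t @ theta.T, axis=0)
    # 1-D W2^2 with uniform weights = mean squared gap between sorted values;
    # averaging over all entries also averages over the projections
    return np.sqrt(np.mean((proj_s - proj_t) ** 2))


rng = np.random.default_rng(1)
X = rng.normal(size=(200, 1))
print(sliced_wasserstein_sketch(X, X + 3.0))  # close to 3.0: in 1-D, slicing is exact
```

Each projection costs one sort, O(n log n), which is why slicing scales to large sample counts. Unequal sample sizes or non-uniform weights would require comparing quantile functions rather than paired sorted values.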

## Solver Configuration

### Automatic Method Selection

The unified solvers select a method based on problem characteristics, as follows:

**Standard OT (`solve`):**
- **Small problems** (< 1000 samples): Exact EMD solver
- **Medium problems** with regularization: Sinkhorn variants
- **Large problems**: Stabilized Sinkhorn or epsilon-scaling
- **Sparse problems**: Screenkhorn or greedy Sinkhorn

**Gromov-Wasserstein (`solve_gromov`):**
- **Small problems**: Exact conditional gradient
- **Regularized problems**: Entropic Gromov-Wasserstein
- **Large structured problems**: Proximal point methods
- **Mixed structure-feature problems**: Fused GW, selected automatically when a feature cost matrix `M` is passed

**Sampling-based (`solve_sample`):**
- **High-dimensional data**: Sliced Wasserstein approaches
- **Large-scale structured data**: Sampled Gromov-Wasserstein
- **GPU acceleration**: Backend-optimized sampling
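
The rule for large standard-OT problems mentions stabilized Sinkhorn and epsilon-scaling. A minimal sketch of both ideas for entropic regularization, assuming a decreasing schedule of regularization values: one common way to stabilize Sinkhorn is to iterate on the dual potentials in the log domain, and epsilon-scaling warm-starts each smaller regularization from the previous potentials. `sinkhorn_log_eps_scaling` and `reg_schedule` are hypothetical names; this is not the library's `'sinkhorn_stabilized'` or `'sinkhorn_epsilon_scaling'` implementation.

```python
import numpy as np


def _logsumexp(Z, axis):
    """Numerically stable log(sum(exp(Z))) along one axis."""
    zmax = np.max(Z, axis=axis, keepdims=True)
    return np.squeeze(zmax, axis=axis) + np.log(np.sum(np.exp(Z - zmax), axis=axis))


def sinkhorn_log_eps_scaling(a, b, M, reg, reg_schedule=(1.0, 0.1), n_iter=200):
    """Log-domain entropic Sinkhorn with epsilon-scaling (hypothetical sketch).

    `reg_schedule` is assumed to be decreasing; its values above `reg` are used
    as warm-up stages, and the final stage always uses `reg` itself.
    """
    a, b, M = (np.asarray(v, dtype=float) for v in (a, b, M))
    log_a, log_b = np.log(a), np.log(b)
    f, g = np.zeros(len(a)), np.zeros(len(b))  # dual potentials, reused across stages
    for eps in [r for r in reg_schedule if r > reg] + [reg]:
        for _ in range(n_iter):
            # Plan is exp((f_i + g_j - M_ij) / eps); each update matches one marginal
            f = eps * (log_a - _logsumexp((g[None, :] - M) / eps, axis=1))
            g = eps * (log_b - _logsumexp((f[:, None] - M) / eps, axis=0))
    return np.exp((f[:, None] + g[None, :] - M) / reg)


# Costs normalized to [0, 1]; the target reg is reached through a decreasing schedule
rng = np.random.default_rng(0)
x, y = rng.normal(size=(30, 2)), rng.normal(size=(40, 2))
M = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
M /= M.max()
a, b = np.full(30, 1 / 30), np.full(40, 1 / 40)
P = sinkhorn_log_eps_scaling(a, b, M, reg=0.05, reg_schedule=(1.0, 0.5, 0.1), n_iter=500)
print("max source-marginal error:", np.abs(P.sum(axis=1) - a).max())
```

Working with potentials instead of the kernel `exp(-M / reg)` avoids underflow when `reg` is small relative to the costs, and the warm start means the expensive small-`reg` stage begins close to its solution.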
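
### Common Parameters

The unified solvers share the following configuration parameters. Not every solver accepts every value: for example, `solve_gromov` documents only `'entropy'` and `'l2'` regularization, and `solve_sample` has no explicit `reg_type` parameter.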

```python { .api }
# Regularization types
reg_type = 'entropy'  # Entropic regularization (Sinkhorn-type)
reg_type = 'l2'       # L2 regularization (smooth OT)
reg_type = 'kl'       # KL divergence regularization
reg_type = 'tv'       # Total variation regularization

# Method selection
method = 'auto'         # Automatic method selection
method = 'exact'        # Force exact methods when possible
method = 'regularized'  # Force regularized methods
method = 'fast'         # Prioritize speed over accuracy

# Convergence control
stopThr = 1e-6     # Convergence threshold
numItermax = 1000  # Maximum iterations
verbose = True     # Print solver progress
log = True         # Return detailed optimization log
```
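
The convergence-control parameters follow the usual pattern for iterative solvers: iterate until an error measure drops below `stopThr` or `numItermax` iterations are reached, optionally printing progress and returning a log. A minimal sketch for plain entropic Sinkhorn, assuming the error measure is the L1 violation of the source marginal (criteria differ between solvers). `sinkhorn_sketch` is hypothetical and only reuses the documented parameter names for readability; its log keys are illustrative.

```python
import numpy as np


def sinkhorn_sketch(a, b, M, reg, numItermax=1000, stopThr=1e-6, verbose=False, log=False):
    """Plain entropic Sinkhorn; returns the plan, or (plan, log) if log=True."""
    a, b, M = (np.asarray(v, dtype=float) for v in (a, b, M))
    K = np.exp(-M / reg)  # Gibbs kernel
    u, v = np.ones_like(a), np.ones_like(b)
    errors, niter = [], 0
    for it in range(numItermax):
        u = a / (K @ v)    # scale rows to match the source marginal
        v = b / (K.T @ u)  # scale columns to match the target marginal
        niter = it + 1
        # After the v-update the columns are exact; measure the row (source) violation
        err = np.abs(u * (K @ v) - a).sum()
        errors.append(err)
        if verbose:
            print(f"iter {niter:4d}  marginal error {err:.2e}")
        if err < stopThr:
            break
    plan = u[:, None] * K * v[None, :]
    if log:
        return plan, {"niter": niter, "err": errors, "u": u, "v": v}
    return plan


rng = np.random.default_rng(0)
x, y = rng.normal(size=(50, 2)), rng.normal(size=(60, 2))
M = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
M /= M.max()  # costs in [0, 1] keep exp(-M / reg) well away from underflow
a, b = np.full(50, 1 / 50), np.full(60, 1 / 60)
plan, info = sinkhorn_sketch(a, b, M, reg=0.1, log=True)
print(info["niter"], info["err"][-1])
```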

## Usage Examples

### Basic Optimal Transport

```python
import ot
import numpy as np

# Create distributions
n, m = 100, 120
a = ot.utils.unif(n)
b = ot.utils.unif(m)
X = np.random.randn(n, 2)
Y = np.random.randn(m, 2)
M = ot.dist(X, Y)

# Solve with automatic method selection
plan = ot.solve(a, b, M, reg=0.1, method='auto', verbose=True)

# Solve exact transport (automatically uses EMD)
plan_exact = ot.solve(a, b, M, method='exact')

# Solve with specific regularization
plan_l2 = ot.solve(a, b, M, reg=0.01, reg_type='l2', method='smooth')
```

### Gromov-Wasserstein Problems

```python
import numpy as np
import ot

# Create structured data
n_s, n_t = 50, 60
C1 = ot.dist(np.random.randn(n_s, 2))  # Source structure (intra-domain distances)
C2 = ot.dist(np.random.randn(n_t, 2))  # Target structure

# Pure Gromov-Wasserstein
plan_gw = ot.solve_gromov(C1, C2, reg=0.1, method='auto')

# Fused Gromov-Wasserstein with features
X_s = np.random.randn(n_s, 3)
X_t = np.random.randn(n_t, 3)
M_features = ot.dist(X_s, X_t)

plan_fgw = ot.solve_gromov(
    C1, C2, M=M_features, alpha=0.5,
    reg=0.1, method='auto', verbose=True
)
```

### Large-Scale Sampling

```python
import numpy as np
import ot

# Large-scale problem with sampling
n_large = 10000
X_s_large = np.random.randn(n_large, 100)
X_t_large = np.random.randn(n_large, 100)

# Use sampling-based solver
result = ot.solve_sample(
    X_s_large, X_t_large,
    method='sliced_wasserstein',
    numItermax=50,
    verbose=True,
    log=True
)

distance, log_dict = result
print(f"Sliced Wasserstein distance: {distance}")
```

### Backend Integration
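
```python
# Automatic backend detection: the backend is inferred from the input array type
import numpy as np
import ot
import torch

# PyTorch tensors (automatically detected); use the GPU when one is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
a_torch = torch.ones(100, device=device) / 100
b_torch = torch.ones(120, device=device) / 120
M_torch = torch.rand(100, 120, device=device)  # non-negative costs

# Solver automatically uses the PyTorch backend
plan_torch = ot.solve(a_torch, b_torch, M_torch, reg=0.1, method='auto')

# To use another backend, pass arrays of that type (here, JAX arrays)
import jax.numpy as jnp

a_np, b_np = ot.utils.unif(100), ot.utils.unif(120)
M_np = np.random.rand(100, 120)
plan_jax = ot.solve(jnp.asarray(a_np), jnp.asarray(b_np), jnp.asarray(M_np), reg=0.1, method='sinkhorn')
```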

## Import Statements

```python
import ot
from ot import solve, solve_gromov, solve_sample
# Equivalent: import from the module that defines them
from ot.solvers import solve, solve_gromov, solve_sample
```