# Unified Solvers

High-level unified interface for optimal transport solvers, providing a consistent API across different problem types and algorithms. With `method='auto'`, the solvers select an algorithm from problem characteristics such as size and whether regularization is requested; passing an explicit `method` overrides the automatic choice.

## Capabilities

### General Optimal Transport Solver

Unified solver for standard optimal transport problems with automatic algorithm selection.

```python { .api }
def solve(a, b, M, reg=None, reg_type='entropy', method='auto', numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs):
    """
    General optimal transport solver with automatic method selection.

    This function provides a unified interface to various OT solvers, automatically
    selecting the most appropriate algorithm based on problem size, regularization,
    and other parameters.

    Parameters:
    - a: array-like, source distribution
    - b: array-like, target distribution
    - M: array-like, cost matrix
    - reg: float, regularization parameter (None for exact transport)
    - reg_type: str, regularization type ('entropy', 'l2', 'kl', 'tv')
    - method: str, solver method ('auto', 'emd', 'sinkhorn', 'sinkhorn_log',
      'sinkhorn_stabilized', 'sinkhorn_epsilon_scaling', 'smooth')
    - numItermax: int, maximum number of iterations
    - stopThr: float, convergence threshold
    - verbose: bool, print solver information
    - log: bool, return optimization log

    Returns:
    - transport plan matrix or (plan, log) if log=True
    """
```
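
To build intuition for the regularized branch, here is a minimal pure-Python sketch of entropic OT computed by Sinkhorn matrix scaling, the algorithm family behind `method='sinkhorn'` with `reg_type='entropy'`. It works under simplifying assumptions (dense Python lists, no log-domain stabilization) and is not the library's implementation. The helper name `sinkhorn_sketch` is hypothetical. Its `reg`, `numItermax` and `stopThr` arguments only mirror the parameters documented above.

```python
import math


def sinkhorn_sketch(a, b, M, reg, numItermax=1000, stopThr=1e-9):
    """Entropic OT via Sinkhorn matrix scaling on plain Python lists (hypothetical helper)."""
    n, m = len(a), len(b)
    # Gibbs kernel K = exp(-M / reg)
    K = [[math.exp(-M[i][j] / reg) for j in range(m)] for i in range(n)]
    u, v = [1.0] * n, [1.0] * m
    for _ in range(numItermax):
        # Rescale rows so that the row sums of diag(u) K diag(v) equal a
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        # Rescale columns so that the column sums equal b
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
        # Columns are now exact; stop once the row marginals are close enough
        row_err = sum(
            abs(u[i] * sum(K[i][j] * v[j] for j in range(m)) - a[i])
            for i in range(n)
        )
        if row_err < stopThr:
            break
    # Transport plan T = diag(u) K diag(v)
    return [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]


a = [0.5, 0.5]
b = [0.25, 0.75]
M = [[0.0, 1.0], [1.0, 0.0]]
T = sinkhorn_sketch(a, b, M, reg=0.5)
cost = sum(M[i][j] * T[i][j] for i in range(2) for j in range(2))
print(T)
print(cost)  # above the exact optimum of 0.25, because entropy spreads mass
```

As `reg` shrinks, T approaches the exact plan [[0.25, 0.25], [0, 0.5]]. The cost of that is slower convergence and, eventually, underflow in `exp(-M / reg)`.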

### Gromov-Wasserstein Solver

Unified solver for Gromov-Wasserstein (GW) problems and variants such as Fused Gromov-Wasserstein (FGW).

```python { .api }
def solve_gromov(C1, C2, p=None, q=None, M=None, alpha=0.0, reg=None, reg_type='entropy', method='auto', loss_fun='square_loss', armijo=False, numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs):
    """
    General Gromov-Wasserstein solver with automatic method selection.

    Solves Gromov-Wasserstein and Fused Gromov-Wasserstein problems using
    appropriate algorithms based on regularization and problem characteristics.

    Parameters:
    - C1: array-like, cost matrix for source space
    - C2: array-like, cost matrix for target space
    - p: array-like, source distribution (uniform if None)
    - q: array-like, target distribution (uniform if None)
    - M: array-like, feature cost matrix (for Fused GW, None for pure GW)
    - alpha: float, Fused GW trade-off; the structure (GW) term is weighted by alpha
      and the feature term by 1 - alpha (1 = pure GW, 0 = pure Wasserstein)
    - reg: float, regularization parameter (None for exact)
    - reg_type: str, regularization type ('entropy', 'l2')
    - method: str, solver method ('auto', 'conditional_gradient', 'proximal_point', 'frank_wolfe');
      conditional gradient and Frank-Wolfe are two names for the same algorithm
    - loss_fun: str or callable, loss function ('square_loss', 'kl_loss')
    - armijo: bool, use Armijo line search
    - numItermax: int, maximum iterations
    - stopThr: float, convergence threshold
    - verbose: bool, print information
    - log: bool, return optimization log

    Returns:
    - transport plan matrix or (plan, log) if log=True
    """
```
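
To make the objective behind these solvers concrete, here is a minimal pure-Python sketch. For a fixed coupling T, it evaluates the square-loss GW objective and an FGW objective that blends it with the feature term <M, T>. It evaluates the objective only and does not optimize it. The helper names are hypothetical. The FGW weighting assumes the convention of POT's `ot.gromov.fused_gromov_wasserstein`, where alpha multiplies the structure term and 1 - alpha the feature term.

```python
def gw_loss(C1, C2, T):
    """Square-loss GW objective: sum over i, j, k, l of (C1[i][k] - C2[j][l])**2 * T[i][j] * T[k][l]."""
    n, m = len(C1), len(C2)
    total = 0.0
    for i in range(n):
        for j in range(m):
            if T[i][j] == 0:
                continue  # terms with a zero coupling entry cannot contribute
            for k in range(n):
                for l in range(m):
                    total += (C1[i][k] - C2[j][l]) ** 2 * T[i][j] * T[k][l]
    return total


def fgw_loss(C1, C2, M, T, alpha):
    """Fused GW objective: alpha * structure term + (1 - alpha) * feature term <M, T>."""
    linear = sum(M[i][j] * T[i][j] for i in range(len(M)) for j in range(len(M[0])))
    return alpha * gw_loss(C1, C2, T) + (1 - alpha) * linear


# Points 0, 1, 2 on a line, and the same points listed in the order 1, 0, 2
C1 = [[0, 1, 2], [1, 0, 1], [2, 1, 0]]
C2 = [[0, 1, 1], [1, 0, 2], [1, 2, 0]]
# Feature costs that are zero exactly on the matching below
M = [[1, 0, 1], [0, 1, 1], [1, 1, 0]]

third = 1.0 / 3
T_match = [[0, third, 0], [third, 0, 0], [0, 0, third]]  # each point sent to its copy in C2
T_id = [[third, 0, 0], [0, third, 0], [0, 0, third]]     # naive index-to-index coupling

print(gw_loss(C1, C2, T_match))  # 0.0: the matching preserves all pairwise distances
print(gw_loss(C1, C2, T_id))     # positive: the naive coupling distorts distances
print(fgw_loss(C1, C2, M, T_id, 0.5))
```

Brute-force evaluation costs O(n^2 m^2) operations. For the square loss, the objective can be rewritten with matrix products, which is how practical solvers evaluate it.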

### Sampling-based Solver

Solver for large-scale problems using sampling approaches.

```python { .api }
def solve_sample(X_s, X_t, a=None, b=None, method='gromov_wasserstein_samples', reg=None, numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs):
    """
    Solve optimal transport using sampling-based methods.

    Efficient solver for large-scale problems using sampling techniques
    to approximate optimal transport distances and plans.

    Parameters:
    - X_s: array-like, source samples (n_samples_s, n_features)
    - X_t: array-like, target samples (n_samples_t, n_features)
    - a: array-like, source weights (uniform if None)
    - b: array-like, target weights (uniform if None)
    - method: str, sampling method ('gromov_wasserstein_samples', 'sliced_wasserstein',
      'max_sliced_wasserstein')
    - reg: float, regularization parameter
    - numItermax: int, maximum iterations
    - stopThr: float, convergence threshold
    - verbose: bool, print information
    - log: bool, return optimization log

    Returns:
    - transport plan or distance depending on method, or (result, log) if log=True
    """
```
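
The idea behind `method='sliced_wasserstein'` fits in a few lines. Project both sample sets onto random unit directions, where 1D optimal transport reduces to sorting, then average the results. This pure-Python sketch makes two assumptions: equal sample counts with uniform weights, so sorted samples pair one-to-one, and a squared Euclidean ground cost. `sliced_w2_sketch` and `n_projections` are hypothetical names, not the documented signature.

```python
import math
import random


def random_direction(dim, rng):
    """Uniform random unit vector (normalized Gaussian sample)."""
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def w2_squared_1d(xs, ys):
    """Exact squared 2-Wasserstein distance between equal-size uniform 1D samples."""
    xs, ys = sorted(xs), sorted(ys)
    return sum((x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def sliced_w2_sketch(X_s, X_t, n_projections=200, seed=0):
    """Monte Carlo estimate of the sliced 2-Wasserstein distance (hypothetical helper)."""
    assert len(X_s) == len(X_t), "sketch assumes equal sample counts"
    rng = random.Random(seed)
    dim = len(X_s[0])
    total = 0.0
    for _ in range(n_projections):
        theta = random_direction(dim, rng)
        # Project every sample onto the direction theta
        proj_s = [sum(t * x for t, x in zip(theta, p)) for p in X_s]
        proj_t = [sum(t * x for t, x in zip(theta, p)) for p in X_t]
        total += w2_squared_1d(proj_s, proj_t)
    return math.sqrt(total / n_projections)


rng = random.Random(42)
X_s = [[rng.gauss(0, 1), rng.gauss(0, 1)] for _ in range(300)]
X_t = [[x + 3.0, y] for x, y in X_s]  # the same cloud shifted by 3 along the first axis

print(sliced_w2_sketch(X_s, X_s))  # 0.0: identical clouds
print(sliced_w2_sketch(X_s, X_t))
```

For a pure translation by a vector d, each projection contributes (theta . d)^2. Its average over the unit circle is |d|^2 / 2, so the second call should land near sqrt(4.5), about 2.12.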

## Solver Configuration

### Automatic Method Selection

With `method='auto'`, the unified solvers choose an algorithm heuristically from problem characteristics:

**Standard OT (`solve`):**
- **Small problems** (< 1000 samples): Exact EMD solver
- **Medium problems** with regularization: Sinkhorn variants
- **Large problems**: Stabilized Sinkhorn or epsilon-scaling
- **Sparse problems**: Screenkhorn or greedy Sinkhorn

**Gromov-Wasserstein (`solve_gromov`):**
- **Small problems**: Exact conditional gradient
- **Regularized problems**: Entropic Gromov-Wasserstein
- **Large structured problems**: Proximal point methods
- **Mixed structure-feature problems**: Fused GW, selected automatically when a feature cost matrix `M` is passed

**Sampling-based (`solve_sample`):**
- **High-dimensional data**: Sliced Wasserstein approaches
- **Large-scale structured data**: Sampled Gromov-Wasserstein
- **GPU acceleration**: Backend-optimized sampling
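
The stabilized and epsilon-scaling variants listed above target the same numerical issue: for small `reg`, the Gibbs kernel `exp(-M / reg)` underflows in double precision. Below is a minimal pure-Python sketch, assuming dense lists and balanced marginals. It combines two ideas. Log-domain Sinkhorn updates dual potentials through a stable log-sum-exp. Epsilon scaling solves a sequence of problems with geometrically decreasing regularization, each warm-started from the previous potentials. The function name and the `reg_init`, `factor` and `n_inner` parameters are hypothetical. The sketch illustrates the technique and is not the library's `sinkhorn_epsilon_scaling` implementation.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    top = max(values)
    return top + math.log(sum(math.exp(v - top) for v in values))


def sinkhorn_log_eps_scaling(a, b, M, reg, reg_init=10.0, factor=0.5, n_inner=50):
    """Log-domain Sinkhorn with epsilon scaling (illustrative sketch)."""
    n, m = len(a), len(b)
    f, g = [0.0] * n, [0.0] * m  # dual potentials, warm-started across stages
    eps = max(reg_init, reg)
    while True:
        for _ in range(n_inner):
            # Row update: make the row sums of exp((f + g - M) / eps) equal a
            f = [eps * (math.log(a[i]) - logsumexp([(g[j] - M[i][j]) / eps for j in range(m)]))
                 for i in range(n)]
            # Column update: make the column sums equal b
            g = [eps * (math.log(b[j]) - logsumexp([(f[i] - M[i][j]) / eps for i in range(n)]))
                 for j in range(m)]
        if eps == reg:
            break
        eps = max(eps * factor, reg)  # shrink toward the target regularization
    # Primal plan: T_ij = exp((f_i + g_j - M_ij) / reg)
    return [[math.exp((f[i] + g[j] - M[i][j]) / reg) for j in range(m)] for i in range(n)]


a = [0.5, 0.5]
b = [0.25, 0.75]
M = [[0.0, 1.0], [1.0, 0.0]]
# With reg = 1e-3 the plain kernel would contain exp(-1000), which underflows to 0.0
T = sinkhorn_log_eps_scaling(a, b, M, reg=1e-3)
print(T)  # close to the exact plan [[0.25, 0.25], [0.0, 0.5]]
```

Warm starting keeps the schedule cheap. Each stage begins close to its own solution, so a small fixed number of inner iterations per stage is enough in this example.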

### Common Parameters

The unified solvers share the following configuration parameters. `reg_type` is accepted by `solve` and `solve_gromov` (the latter supports only `'entropy'` and `'l2'`); the method and convergence options apply to all three solvers:

```python { .api }
# Regularization types
reg_type = 'entropy'  # Entropic regularization (Sinkhorn-type)
reg_type = 'l2'       # L2 regularization (smooth OT)
reg_type = 'kl'       # KL divergence regularization
reg_type = 'tv'       # Total variation regularization

# Method selection
method = 'auto'         # Automatic method selection
method = 'exact'        # Force exact methods when possible
method = 'regularized'  # Force regularized methods
method = 'fast'         # Prioritize speed over accuracy

# Convergence control
stopThr = 1e-6     # Convergence threshold
numItermax = 1000  # Maximum iterations
verbose = True     # Print solver progress
log = True         # Return detailed optimization log
```
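
To see what `reg_type` changes, the sketch below evaluates the regularized objective <M, T> + reg * R(T) for a fixed plan T. The definitions of R are common textbook choices, assumed here for illustration: negative entropy for 'entropy', half the squared Frobenius norm for 'l2', and KL divergence to the product coupling a b^T for 'kl'. The library may use different constants or scalings, and 'tv' is not covered. `regularized_objective` is a hypothetical helper.

```python
import math


def regularized_objective(T, M, a, b, reg, reg_type):
    """<M, T> + reg * R(T) for a few common regularizers (illustrative definitions)."""
    n, m = len(a), len(b)
    linear = sum(M[i][j] * T[i][j] for i in range(n) for j in range(m))
    if reg_type == "entropy":
        # Negative entropy sum T log T, with 0 log 0 = 0
        R = sum(t * math.log(t) for row in T for t in row if t > 0)
    elif reg_type == "l2":
        # Half the squared Frobenius norm
        R = 0.5 * sum(t * t for row in T for t in row)
    elif reg_type == "kl":
        # KL divergence from T to the independent coupling a b^T
        R = sum(
            T[i][j] * math.log(T[i][j] / (a[i] * b[j]))
            for i in range(n) for j in range(m) if T[i][j] > 0
        )
    else:
        raise ValueError(f"unsupported reg_type in this sketch: {reg_type!r}")
    return linear + reg * R


a = b = [0.5, 0.5]
M = [[0.0, 1.0], [1.0, 0.0]]
P = [[0.25, 0.25], [0.25, 0.25]]  # spread-out (independent) coupling
D = [[0.5, 0.0], [0.0, 0.5]]      # sparse coupling, optimal without regularization

for reg in (0.1, 1.0):
    print(reg,
          regularized_objective(P, M, a, b, reg, "entropy"),
          regularized_objective(D, M, a, b, reg, "entropy"))
```

The loop shows that a small `reg` favors the sparse plan D, while a large `reg` favors the spread-out plan P. On couplings with marginals a and b, the 'kl' and 'entropy' penalties differ only by a constant, so they share the same minimizer.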

## Usage Examples

### Basic Optimal Transport

```python
import ot
import numpy as np

# Uniform weights on two random 2D point clouds
n, m = 100, 120
a = ot.utils.unif(n)
b = ot.utils.unif(m)
X = np.random.randn(n, 2)
Y = np.random.randn(m, 2)
M = ot.dist(X, Y)  # squared Euclidean cost matrix (ot.dist default)

# Solve with automatic method selection
plan = ot.solve(a, b, M, reg=0.1, method='auto', verbose=True)

# Solve exact transport (automatically uses EMD)
plan_exact = ot.solve(a, b, M, method='exact')

# Solve with specific regularization
plan_l2 = ot.solve(a, b, M, reg=0.01, reg_type='l2', method='smooth')
```
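
The exact solve in this example can be checked by brute force on tiny inputs. With uniform weights and equal sizes, some optimal plan is a permutation matrix scaled by 1/n, a consequence of the Birkhoff-von Neumann theorem. Enumerating permutations therefore yields the exact optimum. This pure-Python sketch costs O(n!), so it only serves as a reference for very small n. The helper name is hypothetical, and POT's `ot.emd` uses a network simplex solver instead.

```python
import itertools
import math


def exact_ot_uniform_bruteforce(M):
    """Exact OT between uniform n-point distributions by enumerating permutations."""
    n = len(M)
    best_cost, best_perm = math.inf, None
    for perm in itertools.permutations(range(n)):
        cost = sum(M[i][perm[i]] for i in range(n)) / n
        if cost < best_cost:
            best_cost, best_perm = cost, perm
    # Plan: mass 1/n moves from source i to target best_perm[i]
    plan = [[(1.0 / n if best_perm[i] == j else 0.0) for j in range(n)] for i in range(n)]
    return best_cost, plan


# Squared Euclidean costs between points on a line
xs, ys = [0.0, 1.0, 2.0], [2.5, 0.5, 1.5]
M = [[(x - y) ** 2 for y in ys] for x in xs]
cost, plan = exact_ot_uniform_bruteforce(M)
print(cost)  # 0.25: in 1D the sorted matching 0->0.5, 1->1.5, 2->2.5 is optimal
print(plan)
```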

### Gromov-Wasserstein Problems

```python
import numpy as np
import ot

# Create structured data: intra-domain distance matrices
n_s, n_t = 50, 60
C1 = ot.dist(np.random.randn(n_s, 2))  # Source structure
C2 = ot.dist(np.random.randn(n_t, 2))  # Target structure

# Pure Gromov-Wasserstein
plan_gw = ot.solve_gromov(C1, C2, reg=0.1, method='auto')

# Fused Gromov-Wasserstein with features
X_s = np.random.randn(n_s, 3)
X_t = np.random.randn(n_t, 3)
M_features = ot.dist(X_s, X_t)

plan_fgw = ot.solve_gromov(
    C1, C2, M=M_features, alpha=0.5,
    reg=0.1, method='auto', verbose=True
)
```

### Large-Scale Sampling

```python
import numpy as np
import ot

# Large-scale problem: 10,000 samples in 100 dimensions on each side
n_large = 10000
X_s_large = np.random.randn(n_large, 100)
X_t_large = np.random.randn(n_large, 100)

# Use sampling-based solver
result = ot.solve_sample(
    X_s_large, X_t_large,
    method='sliced_wasserstein',
    numItermax=50,
    verbose=True,
    log=True
)

# With log=True the call returns a (result, log) pair
distance, log_dict = result
print(f"Sliced Wasserstein distance: {distance}")
```

### Backend Integration

```python
import jax.numpy as jnp
import ot
import torch

# The backend is inferred from the input array type; put tensors on the GPU to run there
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# PyTorch tensors (automatically detected)
a_torch = torch.ones(100, device=device) / 100
b_torch = torch.ones(120, device=device) / 120
M_torch = torch.rand(100, 120, device=device)  # nonnegative costs

# Solver automatically uses PyTorch backend
plan_torch = ot.solve(a_torch, b_torch, M_torch, reg=0.1, method='auto')

# Select the JAX backend by passing JAX arrays (a, b, M from the basic example)
plan_jax = ot.solve(jnp.asarray(a), jnp.asarray(b), jnp.asarray(M), reg=0.1, method='sinkhorn')
```

235

236

## Import Statements

237

238

```python

239

import ot

240

from ot import solve, solve_gromov, solve_sample

241

from ot.solvers import solve, solve_gromov, solve_sample

242

```