# Stochastic Solvers

Stochastic algorithms for large-scale regularized optimal transport, based on Stochastic Average Gradient (SAG), averaged Stochastic Gradient Descent (ASGD), and mini-batch Stochastic Gradient Descent (SGD). Each iteration processes a single source sample or a small mini-batch instead of the full problem. This makes these solvers well suited to problems with many samples, where exact methods become computationally prohibitive.

## Capabilities

### SAG-based Entropic Transport

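Stochastic Average Gradient (SAG) and averaged SGD (ASGD) methods for the semi-dual formulation of entropic regularized optimal transport. In this formulation, the source dual variable is eliminated through its c-transform.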

```python { .api }
def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None, random_state=None):
    """
    Solve the semi-dual of entropic regularized OT with the Stochastic Average Gradient (SAG) algorithm.

    Parameters:
    - a: array-like, shape (ns,), source distribution
    - b: array-like, shape (nt,), target distribution
    - M: array-like, shape (ns, nt), cost matrix
    - reg: float, entropic regularization parameter
    - numItermax: int, maximum number of iterations
    - lr: float, learning rate (computed automatically if None)
    - random_state: int, random seed for reproducibility

    Returns:
    - v: array, shape (nt,), optimal dual variable for the target distribution (not the transport plan)
    """

def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=1e-4, random_state=None):
    """
    Solve the semi-dual of entropic regularized OT with averaged stochastic gradient descent (ASGD).

    Parameters:
    - a: array-like, shape (ns,), source distribution
    - b: array-like, shape (nt,), target distribution
    - M: array-like, shape (ns, nt), cost matrix
    - reg: float, entropic regularization parameter
    - numItermax: int, maximum number of iterations
    - lr: float, learning rate
    - random_state: int, random seed for reproducibility

    Returns:
    - v: array, shape (nt,), averaged dual variable for the target distribution (not the transport plan)
    """
```
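
The SAG and ASGD updates can be illustrated with a small NumPy sketch. It is an independent illustration, not the library's implementation. The helper names (`semi_dual_grad`, `sag_semi_dual`, `asgd_semi_dual`) and the default step size `lr = reg` are assumptions of this sketch. Both solvers maximize the semi-dual objective, whose gradient with respect to `beta` is `b - sum_i a_i * chi_i`, where `chi_i` is the row of weights `b_j * exp((beta_j - M_ij) / reg)` normalized over `j`. SAG keeps the last gradient seen for every source sample and steps along their running sum. ASGD takes decaying steps `lr / sqrt(k)` and returns the average of its iterates.

```python
import numpy as np


def semi_dual_grad(b, M, reg, beta, i):
    """Gradient of the i-th semi-dual term with respect to beta: b - chi_i."""
    r = M[i] - beta
    w = b * np.exp(-(r - r.min()) / reg)  # shift by the row minimum to avoid overflow
    return b - w / w.sum()


def sag_semi_dual(a, b, M, reg, n_iter=20000, lr=None, seed=0):
    """SAG ascent on the entropic semi-dual; returns the target dual variable."""
    rng = np.random.default_rng(seed)
    ns, nt = M.shape
    lr = reg if lr is None else lr        # assumed default step size
    beta = np.zeros(nt)
    table = np.zeros((ns, nt))            # last weighted gradient stored per source sample
    total = np.zeros(nt)                  # running sum of the stored gradients
    for i in rng.integers(ns, size=n_iter):
        g = a[i] * semi_dual_grad(b, M, reg, beta, i)
        total += g - table[i]
        table[i] = g
        beta += lr * total                # total approximates the full gradient
    return beta


def asgd_semi_dual(a, b, M, reg, n_iter=50000, lr=None, seed=0):
    """Averaged SGD on the entropic semi-dual; returns the averaged iterate."""
    rng = np.random.default_rng(seed)
    nt = M.shape[1]
    lr = reg if lr is None else lr        # assumed default step size
    beta = np.zeros(nt)
    beta_avg = np.zeros(nt)
    # Sample source points with probability a_i so each gradient estimate is unbiased
    for k, i in enumerate(rng.choice(len(a), size=n_iter, p=a), start=1):
        beta += lr / np.sqrt(k) * semi_dual_grad(b, M, reg, beta, i)
        beta_avg += (beta - beta_avg) / k  # running mean of the iterates
    return beta_avg
```

At the optimum, the full gradient vanishes. This is equivalent to requiring that the column sums of the recovered plan equal `b`, so measuring that residual gives a cheap convergence check.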

### SGD-based Dual Solvers

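Stochastic gradient methods for the full dual of entropic regularized optimal transport. They update both dual variables (one entry per source sample and one per target sample) using random mini-batches.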

```python { .api }
def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
    """
    Solve the dual of entropic regularized OT with mini-batch stochastic gradient updates.

    Parameters:
    - a: array-like, shape (ns,), source distribution
    - b: array-like, shape (nt,), target distribution
    - M: array-like, shape (ns, nt), cost matrix
    - reg: float, regularization parameter
    - batch_size: int, number of source and target samples drawn per iteration
    - numItermax: int, number of iterations
    - lr: float, learning rate

    Returns:
    - (u, v): dual variables, shapes (ns,) and (nt,)
    """

def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1, log=False):
    """
    Solve dual entropic regularized OT with mini-batch SGD and return the transport plan.

    Parameters:
    - a: array-like, shape (ns,), source distribution
    - b: array-like, shape (nt,), target distribution
    - M: array-like, shape (ns, nt), cost matrix
    - reg: float, regularization parameter
    - batch_size: int, batch size for gradient computation
    - numItermax: int, maximum iterations
    - lr: float, learning rate
    - log: bool, also return a log dictionary

    Returns:
    - transport plan matrix, shape (ns, nt), or (plan, log) if log=True, where log holds the dual variables
    """
```
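
Here is a minimal NumPy sketch of mini-batch stochastic gradient ascent on the full entropic dual, illustrating what these solvers optimize. It assumes the plan convention `pi_ij = a_i * b_j * exp((u_i + v_j - M_ij) / reg)`. The function names (`sgd_dual_entropic`, `dual_plan`), the `1 / sqrt(k)` step decay, and the default `lr = 0.5` are choices of this sketch, not documented library behaviour. Each step samples `batch_size` source and target indices and forms the corresponding block of the plan. It then moves the sampled dual entries along unbiased estimates of the marginal residuals `a_i - sum_j pi_ij` and `b_j - sum_i pi_ij`.

```python
import numpy as np


def sgd_dual_entropic(a, b, M, reg, batch_size, n_iter=20000, lr=0.5, seed=0):
    """Mini-batch stochastic gradient ascent on the entropic dual; returns (u, v)."""
    rng = np.random.default_rng(seed)
    ns, nt = M.shape
    u, v = np.zeros(ns), np.zeros(nt)
    for k in range(1, n_iter + 1):
        I = rng.choice(ns, size=batch_size, replace=False)
        J = rng.choice(nt, size=batch_size, replace=False)
        # Plan entries restricted to the sampled (I, J) block
        P = a[I, None] * b[None, J] * np.exp((u[I, None] + v[None, J] - M[np.ix_(I, J)]) / reg)
        # Rescaled block sums estimate the full row and column sums of the plan
        g_u = a[I] - (nt / batch_size) * P.sum(axis=1)
        g_v = b[J] - (ns / batch_size) * P.sum(axis=0)
        step = lr / np.sqrt(k)  # assumed decaying step size
        u[I] += step * g_u
        v[J] += step * g_v
    return u, v


def dual_plan(a, b, M, reg, u, v):
    """Rebuild the full plan from the dual variables."""
    return a[:, None] * b[None, :] * np.exp((u[:, None] + v[None, :] - M) / reg)
```

Both marginal residuals of `dual_plan` vanish at the optimum of the dual, so they give a direct convergence check.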

### Semi-Dual Entropic Solvers

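A wrapper for the semi-dual formulation. It optimizes only the target dual variable (a vector of size `nt`) with SAG or ASGD, then recovers the source dual variable through the entropic c-transform and returns the transport plan.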

```python { .api }
def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=1, log=False):
    """
    Solve semi-dual entropic regularized optimal transport and return the transport plan.

    Parameters:
    - a: array-like, shape (ns,), source distribution
    - b: array-like, shape (nt,), target distribution
    - M: array-like, shape (ns, nt), cost matrix
    - reg: float, regularization parameter
    - method: str, stochastic solver to use ('SAG' or 'ASGD')
    - numItermax: int, maximum iterations
    - lr: float, learning rate
    - log: bool, also return a log dictionary

    Returns:
    - transport plan matrix, shape (ns, nt), or (plan, log) if log=True, where log holds the dual variables
    """
```
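
The last step of a semi-dual solve, which turns the target dual variable into a plan, can be sketched in NumPy. This illustration rests on two assumptions. First, the helper name `plan_from_semi_dual` is hypothetical. Second, it uses the convention `pi_ij = a_i * b_j * exp((alpha_i + beta_j - M_ij) / reg)`, with `alpha` obtained from `beta` by the entropic c-transform. Under that convention, the row sums equal `a` for any `beta`, so only the column sums reflect how well the stochastic solver converged.

```python
import numpy as np


def plan_from_semi_dual(a, b, M, reg, beta):
    """Turn a target dual variable beta into an entropic transport plan."""
    # Entropic c-transform with a per-row shift for stability:
    # alpha_i = -reg * log(sum_j b_j * exp((beta_j - M_ij) / reg))
    R = M - beta[None, :]
    r_min = R.min(axis=1, keepdims=True)
    alpha = r_min[:, 0] - reg * np.log(np.sum(b[None, :] * np.exp(-(R - r_min) / reg), axis=1))
    # pi_ij = a_i * b_j * exp((alpha_i + beta_j - M_ij) / reg)
    return a[:, None] * b[None, :] * np.exp((alpha[:, None] + beta[None, :] - M) / reg)
```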

### Empirical Loss Functions

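These functions compute the regularized dual objective and the corresponding transport plan directly from samples `xs`, `xt` and dual variables `u`, `v`, using either entropic or quadratic regularization. The cost matrix is built internally from `metric`.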

```python { .api }
def loss_dual_entropic(u, v, xs, xt, reg=1, ws=None, wt=None, metric="sqeuclidean"):
    """
    Compute dual entropic loss for empirical OT.

    Parameters:
    - u: array-like, dual variable for source samples
    - v: array-like, dual variable for target samples
    - xs: array-like, source samples
    - xt: array-like, target samples
    - reg: float, regularization parameter
    - ws: array-like, source weights (uniform if None)
    - wt: array-like, target weights (uniform if None)
    - metric: str, distance metric ('sqeuclidean', 'euclidean', etc.)

    Returns:
    - dual loss value
    """

def plan_dual_entropic(u, v, xs, xt, reg=1, ws=None, wt=None, metric="sqeuclidean"):
    """
    Compute transport plan from dual variables for empirical OT.

    Parameters:
    - u: array-like, dual variable for source samples
    - v: array-like, dual variable for target samples
    - xs: array-like, source samples
    - xt: array-like, target samples
    - reg: float, regularization parameter
    - ws: array-like, source weights
    - wt: array-like, target weights
    - metric: str, distance metric

    Returns:
    - transport plan matrix
    """

def loss_dual_quadratic(u, v, xs, xt, reg=1, ws=None, wt=None, metric="sqeuclidean"):
    """
    Compute quadratic dual loss for empirical OT.

    Parameters:
    - u: array-like, dual variable for source samples
    - v: array-like, dual variable for target samples
    - xs: array-like, source samples
    - xt: array-like, target samples
    - reg: float, regularization parameter
    - ws: array-like, source weights
    - wt: array-like, target weights
    - metric: str, distance metric

    Returns:
    - quadratic dual loss value
    """

def plan_dual_quadratic(u, v, xs, xt, reg=1, ws=None, wt=None, metric="sqeuclidean"):
    """
    Compute transport plan from dual variables with quadratic regularization.

    Parameters:
    - Same as loss_dual_quadratic

    Returns:
    - transport plan matrix
    """
```
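
The following NumPy sketch makes the four formulas concrete by implementing the standard entropic and quadratic dual objectives and their plans. It is not the library code. The `_np` function names are hypothetical, only the squared Euclidean metric is supported, and weights default to uniform. The entropic dual is `sum(u * ws) + sum(v * wt) - reg * sum_ij ws_i wt_j exp((u_i + v_j - M_ij) / reg)`. The quadratic dual replaces the last term with `sum_ij ws_i wt_j max(u_i + v_j - M_ij, 0)**2 / (4 * reg)`.

```python
import numpy as np


def _prepare(xs, xt, ws, wt):
    """Uniform weights when none are given, plus the squared Euclidean cost matrix."""
    ws = np.full(len(xs), 1.0 / len(xs)) if ws is None else np.asarray(ws)
    wt = np.full(len(xt), 1.0 / len(xt)) if wt is None else np.asarray(wt)
    M = ((xs[:, None, :] - xt[None, :, :]) ** 2).sum(axis=-1)
    return ws, wt, M


def loss_dual_entropic_np(u, v, xs, xt, reg=1.0, ws=None, wt=None):
    ws, wt, M = _prepare(xs, xt, ws, wt)
    F = reg * np.exp((u[:, None] + v[None, :] - M) / reg)
    return u @ ws + v @ wt - ws @ F @ wt


def plan_dual_entropic_np(u, v, xs, xt, reg=1.0, ws=None, wt=None):
    ws, wt, M = _prepare(xs, xt, ws, wt)
    return ws[:, None] * wt[None, :] * np.exp((u[:, None] + v[None, :] - M) / reg)


def loss_dual_quadratic_np(u, v, xs, xt, reg=1.0, ws=None, wt=None):
    ws, wt, M = _prepare(xs, xt, ws, wt)
    F = np.maximum(u[:, None] + v[None, :] - M, 0.0) ** 2 / (4 * reg)
    return u @ ws + v @ wt - ws @ F @ wt


def plan_dual_quadratic_np(u, v, xs, xt, reg=1.0, ws=None, wt=None):
    ws, wt, M = _prepare(xs, xt, ws, wt)
    return ws[:, None] * wt[None, :] * np.maximum(u[:, None] + v[None, :] - M, 0.0) / (2 * reg)
```

In both cases, the gradient of the loss with respect to `u_i` is `ws_i` minus the i-th row sum of the matching plan. This is why each plan function pairs with its loss: maximizing the loss drives the plan's marginals toward `ws` and `wt`.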

### Transform Functions

A helper that converts a target dual variable into the matching source dual variable via the entropic c-transform. The semi-dual solvers use it to recover the transport plan.

```python { .api }
def c_transform_entropic(b, M, reg, beta):
    """
    Compute the entropic c-transform of the target dual variable beta.

    Parameters:
    - b: array-like, shape (nt,), target distribution
    - M: array-like, shape (ns, nt), cost matrix
    - reg: float, regularization parameter
    - beta: array-like, shape (nt,), dual variable for the target distribution

    Returns:
    - alpha: array, shape (ns,), dual variable for the source distribution
    """
```
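
The entropic c-transform is usually computed with a log-sum-exp shift, because the direct formula underflows when `reg` is small. The following NumPy sketch assumes the convention `alpha_i = -reg * log(sum_j b_j * exp((beta_j - M_ij) / reg))`. Both function names are hypothetical.

```python
import numpy as np


def c_transform_entropic_np(b, M, reg, beta):
    """alpha_i = -reg * log(sum_j b_j * exp((beta_j - M_ij) / reg)), computed stably."""
    R = M - beta[None, :]                  # R_ij = M_ij - beta_j
    r_min = R.min(axis=1)                  # per-row shift (log-sum-exp trick)
    S = np.sum(b[None, :] * np.exp(-(R - r_min[:, None]) / reg), axis=1)
    return r_min - reg * np.log(S)


def c_transform_naive(b, M, reg, beta):
    """Direct formula; underflows to log(0) when (beta_j - M_ij) / reg is very negative."""
    return -reg * np.log(np.sum(b[None, :] * np.exp((beta[None, :] - M) / reg), axis=1))
```

As `reg` shrinks, `alpha_i` approaches the unregularized c-transform `min_j (M_ij - beta_j)` and stays within `reg * log(1 / min_j b_j)` above it.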

## Usage Example

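```python
import numpy as np
import ot
import ot.stochastic

rng = np.random.RandomState(0)

# Generate data: 1000 samples keep the example quick; dense cost matrices
# and plans still grow as ns * nt, so memory limits larger runs
n_samples = 1000
n_features = 50
X_s = rng.randn(n_samples, n_features)
X_t = rng.randn(n_samples, n_features)

# Rescale the samples so the squared Euclidean cost lies in [0, 1];
# otherwise exp(-M / reg) underflows for small reg
scale = np.sqrt(ot.dist(X_s, X_t).max())
X_s, X_t = X_s / scale, X_t / scale

# Create distributions and cost matrix (ot.dist defaults to squared Euclidean)
a = ot.utils.unif(n_samples)
b = ot.utils.unif(n_samples)
M = ot.dist(X_s, X_t)

reg = 0.1

# SAG on the semi-dual returns the target dual variable, not a plan
v_sag = ot.stochastic.sag_entropic_transport(a, b, M, reg, numItermax=1000)

# solve_semi_dual_entropic runs SAG (or 'ASGD') and returns the transport plan
plan_sag = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, 'SAG', numItermax=1000)

# Solve with SGD on the dual formulation
batch_size = 100
u, v = ot.stochastic.sgd_entropic_regularization(
    a, b, M, reg, batch_size, numItermax=5000, lr=0.01
)

# Compute loss and plan from the dual variables; the default metric
# 'sqeuclidean' matches the cost matrix M used above
loss = ot.stochastic.loss_dual_entropic(u, v, X_s, X_t, reg)
plan = ot.stochastic.plan_dual_entropic(u, v, X_s, X_t, reg)
```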

## Import Statements

```python
import ot.stochastic
from ot.stochastic import sag_entropic_transport, sgd_entropic_regularization
from ot.stochastic import loss_dual_entropic, plan_dual_entropic
```