or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

advanced-methods.md backend-system.md domain-adaptation.md entropic-transport.md factored-transport.md gromov-wasserstein.md index.md linear-programming.md partial-transport.md regularization-path.md sliced-wasserstein.md smooth-transport.md stochastic-solvers.md unbalanced-transport.md unified-solvers.md utilities.md weak-transport.md

docs/weak-transport.md

0

# Weak Optimal Transport

1

2

Weak optimal transport provides a relaxed formulation of the classical optimal transport problem. Instead of minimizing the total pairwise transport cost, the transport plan minimizes the weighted squared distance between each source sample and the barycenter of the target mass it is sent to (called the displacement variance in this document). This approach is particularly useful for applications where preserving local structure is more important than minimizing global transport costs.

3

4

## Capabilities

5

6

### Weak Optimal Transport Solver

7

8

Solve the weak optimal transport problem between empirical distributions.

9

10

```python { .api }

11

def weak_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, G0=None, **kwargs):
    """
    Solve weak optimal transport problem between two empirical distributions.

    The weak OT problem minimizes the displacement variance, i.e. the weighted
    squared distance between each source sample and the barycenter of the
    target samples it is coupled with:

        γ = argmin_γ Σ_i a_i ||X^a_i - (1/a_i) Σ_j γ_ij X^b_j||²

    subject to standard transport constraints:
    - γ 1 = a (source marginal constraint)
    - γ^T 1 = b (target marginal constraint)
    - γ ≥ 0 (non-negativity)

    Parameters:
    - Xa: array-like, shape (n_samples_a, n_features), source samples
    - Xb: array-like, shape (n_samples_b, n_features), target samples
    - a: array-like, shape (n_samples_a,), source distribution (uniform if None)
    - b: array-like, shape (n_samples_b,), target distribution (uniform if None)
    - verbose: bool, print optimization information
    - log: bool, return optimization log
    - G0: array-like, shape (n_samples_a, n_samples_b), initial transport plan
      (None lets the solver pick its default starting plan)
    - **kwargs: additional keyword options for the underlying solver; the
      usage examples pass numItermax and stopThr this way

    Returns:
    - transport plan matrix of shape (n_samples_a, n_samples_b),
      or (plan, log) if log=True
    """

35

```

36

37

## Theory and Applications

38

39

### Weak vs Classical Optimal Transport

40

41

**Classical Optimal Transport:**

42

- Minimizes total transport cost: `Σ_ij γ_ij C_ij`

43

- Optimal for minimizing global displacement

44

- Can create large local distortions

45

46

**Weak Optimal Transport:**

47

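- Minimizes displacement variance: `Σ_i a_i ||X^a_i - barycenter_i||²`, where `barycenter_i = (1/a_i) Σ_j γ_ij X^b_j` is the weighted mean of the target points that receive mass from source point `i`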

48

- Preserves local neighborhood structure

49

- Better for shape matching and morphing applications

50

51

### Key Properties

52

53

1. **Local Structure Preservation**: Maintains local relationships in source space

54

2. **Barycentric Transport**: Each source point maps to a barycenter of target points

55

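3. **Variance Minimization**: Penalizes only the squared offset between each source point and its barycenter; how widely the transported mass is spread around that barycenter does not enter the objective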

56

4. **Conditional Gradient**: Efficiently solved using Frank-Wolfe type algorithms

57

58

## Usage Examples

59

60

### Basic Weak Transport

61

62

```python

63

import ot

64

import numpy as np

65

import matplotlib.pyplot as plt

66

67

# Create 2D point clouds
n_source, n_target = 100, 120
np.random.seed(42)

# Source: noisy unit circle.
# endpoint=False keeps the angles 0 and 2*pi from both being sampled; on a
# closed curve they are the same location and would be counted twice.
theta_s = np.linspace(0, 2 * np.pi, n_source, endpoint=False)
Xa = np.column_stack([np.cos(theta_s), np.sin(theta_s)])
Xa += 0.1 * np.random.randn(n_source, 2)  # Add noise

# Target: noisy ellipse (semi-axes 2 and 0.5), sampled the same way
theta_t = np.linspace(0, 2 * np.pi, n_target, endpoint=False)
Xb = np.column_stack([2 * np.cos(theta_t), 0.5 * np.sin(theta_t)])
Xb += 0.1 * np.random.randn(n_target, 2)

# Solve weak optimal transport (a and b omitted -> uniform weights)
plan_weak = ot.weak_optimal_transport(Xa, Xb, verbose=True, log=False)

# Sanity checks: the plan is (n_source, n_target), carries unit total mass,
# and its row/column sums reproduce the uniform marginals.
print(f"Transport plan shape: {plan_weak.shape}")
print(f"Plan sum: {np.sum(plan_weak):.6f}")
print(f"Source marginal error: {np.max(np.abs(np.sum(plan_weak, axis=1) - 1/n_source)):.6f}")
print(f"Target marginal error: {np.max(np.abs(np.sum(plan_weak, axis=0) - 1/n_target)):.6f}")

87

```

88

89

### Comparison with Classical Transport

90

91

```python

92

# Compare weak vs classical optimal transport
a = ot.utils.unif(n_source)
b = ot.utils.unif(n_target)

# Classical transport
M = ot.dist(Xa, Xb)
plan_classical = ot.emd(a, b, M)

# Weak transport
plan_weak = ot.weak_optimal_transport(Xa, Xb, a, b)


def _draw_plan(ax, source, target, plan, line_color, title,
               step=5, rel_threshold=0.05):
    """Draw a subset of the couplings in ``plan`` on ``ax``.

    Parameters:
    - ax: matplotlib axes to draw on
    - source, target: 2D point clouds, shapes (n, 2) and (m, 2)
    - plan: transport plan, shape (n, m)
    - line_color: matplotlib color used for the connection lines
    - title: axes title
    - step: only every ``step``-th source point has its connections drawn
    - rel_threshold: connections below this fraction of the largest plan
      entry are skipped
    """
    # No entry of a plan can exceed min(a_i, b_j), which shrinks like 1/n.
    # A fixed absolute cut-off therefore depends on the sample size, so the
    # threshold and the line opacity are both taken relative to the peak.
    peak = plan.max()
    if peak > 0:
        for i in range(0, plan.shape[0], step):
            for j in np.flatnonzero(plan[i] > rel_threshold * peak):
                ax.plot([source[i, 0], target[j, 0]],
                        [source[i, 1], target[j, 1]],
                        color=line_color, linestyle='-',
                        alpha=plan[i, j] / peak, linewidth=0.5)
    ax.scatter(source[:, 0], source[:, 1], c='blue', s=20)
    ax.scatter(target[:, 0], target[:, 1], c='red', s=20)
    ax.set_title(title)


# Visualize differences
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Source and target
axes[0].scatter(Xa[:, 0], Xa[:, 1], c='blue', alpha=0.6, label='Source')
axes[0].scatter(Xb[:, 0], Xb[:, 1], c='red', alpha=0.6, label='Target')
axes[0].set_title('Source and Target')
axes[0].legend()

# Classical transport visualization
_draw_plan(axes[1], Xa, Xb, plan_classical, 'k', 'Classical OT')

# Weak transport visualization
_draw_plan(axes[2], Xa, Xb, plan_weak, 'g', 'Weak OT')

plt.tight_layout()
plt.show()

134

```

135

136

### Shape Morphing Application

137

138

```python

139

# Use weak transport for shape morphing

140

def interpolate_shapes(Xa, Xb, t=0.5, plan=None):
    """Interpolate between shapes using weak transport.

    Each source point moves in a straight line towards the barycenter of the
    target points it is coupled with.

    Parameters:
    - Xa: array-like, shape (n_samples_a, n_features), source shape
    - Xb: array-like, shape (n_samples_b, n_features), target shape
    - t: float, interpolation parameter (0 returns the source points,
      1 returns their barycentric images)
    - plan: array-like, shape (n_samples_a, n_samples_b), optional transport
      plan to reuse. When None the weak OT problem is solved here; pass a
      precomputed plan to evaluate several values of t without re-solving.

    Returns:
    - array, shape (n_samples_a, n_features), interpolated points
    """
    if plan is None:
        plan = ot.weak_optimal_transport(Xa, Xb)

    # Work in floating point so integer coordinates are not truncated.
    Xa = np.asarray(Xa, dtype=float)
    Xb = np.asarray(Xb, dtype=float)
    plan = np.asarray(plan, dtype=float)

    # Barycenter of row i is sum_j plan[i, j] * Xb[j] / sum_j plan[i, j].
    row_mass = plan.sum(axis=1, keepdims=True)
    has_mass = row_mass > 0
    # Rows without mass divide by 1 (avoiding 0/0) and then fall back to the
    # source point itself, i.e. that point is not transported.
    safe_mass = np.where(has_mass, row_mass, 1.0)
    barycenters = np.where(has_mass, (plan @ Xb) / safe_mass, Xa)

    # Linear interpolation
    return (1 - t) * Xa + t * barycenters

156

157

# Create morphing sequence
n_steps = 10
time_points = np.linspace(0.0, 1.0, n_steps + 1)

# interpolate_shapes returns (1 - t) * Xa + t * barycenters, so at t = 1 it
# yields the barycentric image of every source point. The path is linear in
# t, which means a single transport solve is enough for the whole sequence.
barycentric_image = interpolate_shapes(Xa, Xb, t=1.0)
morphing_sequence = [(1 - t) * Xa + t * barycentric_image for t in time_points]

# Visualize morphing
fig, axes = plt.subplots(2, 6, figsize=(18, 6))
axes = axes.flatten()

for ax, t, shape in zip(axes, time_points, morphing_sequence):
    ax.scatter(shape[:, 0], shape[:, 1], c='purple', alpha=0.7, s=20)
    ax.set_title(f't = {t:.1f}')
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)

# The 2x6 grid has one more panel than there are shapes; hide the leftover.
for ax in axes[len(morphing_sequence):]:
    ax.set_visible(False)

plt.tight_layout()
plt.show()

178

```

179

180

### Advanced Usage with Custom Parameters

181

182

```python

183

# Advanced usage with custom initialization and logging
import time

a = ot.utils.unif(n_source)
b = ot.utils.unif(n_target)

# Custom initial transport plan.
# Conditional-gradient iterates are convex combinations of the starting plan
# and feasible plans, so the start must itself have row sums a and column
# sums b. The classical EMD plan satisfies both marginals and already couples
# nearby points, which makes it a valid locality-based warm start.
G0 = ot.emd(a, b, ot.dist(Xa, Xb))

# Solve with custom initialization and detailed logging
start_time = time.perf_counter()  # monotonic clock, suited to timing
plan, log = ot.weak_optimal_transport(
    Xa, Xb,
    a=a,
    b=b,
    G0=G0,
    verbose=True,
    log=True,
    numItermax=1000,
    stopThr=1e-9,
)
solve_time = time.perf_counter() - start_time

print(f"Solver completed in {solve_time:.3f} seconds")
print(f"Final objective: {log['loss'][-1]:.6f}")
# NOTE(review): POT's conditional-gradient solver records the objective of the
# starting plan before iterating, so the log holds one more entry than the
# number of iterations - confirm against the installed POT version.
print(f"Number of iterations: {len(log['loss']) - 1}")

# Plot convergence
plt.figure(figsize=(10, 6))
plt.subplot(1, 2, 1)
plt.semilogy(log['loss'])
plt.xlabel('Iteration')
plt.ylabel('Objective value')
plt.title('Convergence of Weak OT')
plt.grid(True)

plt.subplot(1, 2, 2)
plt.imshow(plan, cmap='Blues', aspect='auto')
plt.colorbar()
plt.xlabel('Target samples')
plt.ylabel('Source samples')
plt.title('Transport Plan')
plt.tight_layout()
plt.show()

233

```

234

235

## Import Statements

236

237

```python

238

import ot

239

from ot import weak_optimal_transport

240

from ot.weak import weak_optimal_transport

241

```