
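# Factored Optimal Transport

Factored optimal transport routes all mass through a small intermediate distribution with r support points, where r is much smaller than the number of samples. The transport plan therefore factors into two smaller plans (source to intermediate, intermediate to target) and has rank at most r. For n source and m target samples, storing the plan in factored form takes O((n + m)·r) entries instead of O(n·m), which helps on large problems whose data have cluster or low-dimensional structure.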

## Capabilities

### Factored Optimal Transport Solver

Solve an optimal transport problem through an intermediate distribution with few support points, which yields a low-rank (factored) coupling.

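```python { .api }
def factored_optimal_transport(Xa, Xb, a=None, b=None, reg=0.0, r=100, X0=None,
                               stopThr=1e-7, numItermax=100, verbose=False,
                               log=False, **kwargs):
    """
    Solve optimal transport using a factored coupling.

    Finds an intermediate empirical distribution mu supported on r points that
    minimizes W_2^2(mu_a, mu) + W_2^2(mu, mu_b), alternating between solving the
    two OT sub-problems and moving the r support points. The coupling between
    source and target passes through these r points, so it has rank at most r
    and can be stored with O((n_a + n_b)·r) entries instead of O(n_a·n_b).

    Parameters:
    - Xa: array-like, shape (n_samples_a, n_features), source samples
    - Xb: array-like, shape (n_samples_b, n_features), target samples
    - a: array-like, shape (n_samples_a,), source distribution (uniform if None)
    - b: array-like, shape (n_samples_b,), target distribution (uniform if None)
    - reg: float, entropic regularization of the sub-problems
      (0 solves them exactly with EMD, > 0 uses Sinkhorn)
    - r: int, number of support points of the intermediate distribution
    - X0: array-like, shape (r, n_features), initial support (random if None)
    - stopThr: float, stopping threshold on the change of the support
    - numItermax: int, maximum number of iterations
    - verbose: bool, print optimization information
    - log: bool, also return an optimization log
    - **kwargs: passed to the sub-problem solver

    Returns:
    - Ga: array-like, shape (n_samples_a, r), plan between source and intermediate distribution
    - Gb: array-like, shape (r, n_samples_b), plan between intermediate distribution and target
    - X: array-like, shape (r, n_features), support of the intermediate distribution
    - log: dict, returned as a fourth value only if log=True
    """
```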
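The alternating scheme behind a factored coupling (find r intermediate points minimizing the transport cost from the source to them plus the cost from them to the target) can be sketched in plain NumPy. This is a minimal illustration, not the library's code: it assumes uniform weights on the r intermediate points, a squared Euclidean cost, and entropic (Sinkhorn) sub-problems instead of exact linear programs. The names `sqdist`, `sinkhorn_plan` and `factored_ot_sketch` are hypothetical.

```python
import numpy as np


def sqdist(X1, X2):
    """Pairwise squared Euclidean distances, shape (len(X1), len(X2))."""
    return ((X1[:, None, :] - X2[None, :, :]) ** 2).sum(axis=-1)


def sinkhorn_plan(p, q, M, reg, n_iter=500):
    """Entropic OT plan between histograms p and q for cost matrix M."""
    K = np.exp(-M / reg)
    u = np.ones_like(p)
    for _ in range(n_iter):
        v = q / (K.T @ u)
        u = p / (K @ v)
    return u[:, None] * K * v[None, :]


def factored_ot_sketch(Xa, Xb, r=5, reg=0.1, n_outer=50, seed=0):
    """Hypothetical sketch: alternate between the two OT sub-problems and the support update."""
    rng = np.random.default_rng(seed)
    a = np.full(len(Xa), 1 / len(Xa))
    b = np.full(len(Xb), 1 / len(Xb))
    w = np.full(r, 1 / r)  # assumption: uniform weights on the intermediate points
    X = rng.standard_normal((r, Xa.shape[1]))  # random initial support
    for _ in range(n_outer):
        Ga = sinkhorn_plan(a, w, sqdist(Xa, X), reg)  # source -> intermediate
        Gb = sinkhorn_plan(w, b, sqdist(X, Xb), reg)  # intermediate -> target
        # Weighted mean of each point's source and target matches; each
        # intermediate point carries mass about 1/r, hence the factor r.
        X = 0.5 * r * (Ga.T @ Xa + Gb @ Xb)
    return Ga, Gb, X


rng = np.random.default_rng(1)
Xa = rng.random((40, 2))
Xb = rng.random((30, 2)) + 0.5
Ga, Gb, X = factored_ot_sketch(Xa, Xb, r=4)
plan = 4 * (Ga @ Gb)  # full 40 x 30 coupling of rank at most 4
```

With the plans held fixed, the support update is the closed-form minimizer of their transport cost with respect to each support point. Replacing `sinkhorn_plan` with an exact (unregularized) OT solver gives the unregularized variant.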

## Factorization Approaches

### Low-Rank Transport Plans

Many optimal transport problems have low-rank structure that can be exploited:

**Standard Transport Plan**: `γ ∈ R^(n×m)` with O(nm) storage

**Factored Transport Plan**: `γ ≈ UV^T` where `U ∈ R^(n×k)`, `V ∈ R^(m×k)` with O((n+m)k) storage; the factored plan has rank at most k
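To make the factored form concrete, the sketch below builds two small couplings that share a middle marginal `w`: `Ga` (source to k intermediate points) and `Gb` (intermediate points to target). Composing them gives the full plan `γ = Ga·diag(1/w)·Gb`, i.e. `U = Ga·diag(1/w)` and `V^T = Gb` in the notation above. The helper names (`compose_factored_plan`, `scale_to_marginals`, `storage`) are hypothetical, and the factors are random positive matrices rescaled to the right marginals rather than solutions of an OT problem.

```python
import numpy as np


def compose_factored_plan(Ga, Gb):
    """Full n x m plan Ga @ diag(1/w) @ Gb, where w is the shared middle marginal."""
    w = Ga.sum(axis=0)  # mass carried by each intermediate point
    return (Ga / w) @ Gb  # dividing the columns of Ga by w equals Ga @ diag(1/w)


def scale_to_marginals(K, p, q, n_iter=1000):
    """Rescale a positive matrix so its row sums are p and its column sums are q."""
    u = np.ones_like(p)
    for _ in range(n_iter):
        v = q / (K.T @ u)
        u = p / (K @ v)
    return u[:, None] * K * v[None, :]


def storage(n, m, k):
    """Number of stored entries: full plan vs. the two factors."""
    return n * m, (n + m) * k


rng = np.random.default_rng(0)
n, m, k = 6, 5, 2
a, b, w = np.full(n, 1 / n), np.full(m, 1 / m), np.full(k, 1 / k)

Ga = scale_to_marginals(rng.random((n, k)) + 0.1, a, w)  # source -> intermediate
Gb = scale_to_marginals(rng.random((k, m)) + 0.1, w, b)  # intermediate -> target
gamma = compose_factored_plan(Ga, Gb)

print(np.linalg.matrix_rank(gamma))  # never more than k
print(gamma.sum(axis=1), gamma.sum(axis=0))  # marginals a and b
print(storage(1000, 1200, 5))  # (1200000, 11000)
```

Because `Gb` has row sums `w`, the composed plan keeps the source marginal `a`; because the columns of `Ga·diag(1/w)` sum to one, it keeps the target marginal `b`.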

### Structured Data Scenarios

Factored transport works best when a few intermediate points can summarize where the mass sits, for example:

1. **Gaussian Mixtures**: each mixture component can be represented by one or a few intermediate points
2. **Time Series**: series built from a few shared temporal patterns form clusters in feature space
3. **Images**: spatial correlation between pixels allows patch-based factorization
4. **Graph Data**: community structure supports block-wise transport
5. **High-dimensional Data**: samples near a low-dimensional manifold can be covered by few support points

## Usage Examples

### Basic Factored Transport

```python
import ot
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(0)  # reproducible example data

# Create high-dimensional data with low-rank structure
n_source, n_target = 1000, 1200
n_features = 50
rank = 5

# Generate low-rank source data
U_source = np.random.randn(n_source, rank)
V_source = np.random.randn(rank, n_features)
Xa = U_source @ V_source + 0.1 * np.random.randn(n_source, n_features)

# Generate related target data
U_target = U_source + 0.5 * np.random.randn(n_source, rank)
U_target = np.vstack([U_target, np.random.randn(n_target - n_source, rank)])
V_target = V_source + 0.3 * np.random.randn(rank, n_features)
Xb = U_target @ V_target + 0.1 * np.random.randn(n_target, n_features)

# Solve using factored transport with r intermediate support points
r = 10
Ga, Gb, X = ot.factored_optimal_transport(Xa, Xb, r=r)

# Full coupling, assuming uniform weights 1/r on the intermediate points
plan_factored = r * (Ga @ Gb)

print(f"Factors: Ga {Ga.shape}, Gb {Gb.shape}, support {X.shape}")
print(f"Transport plan shape: {plan_factored.shape}")
print(f"Plan density (entries > 1e-8): {np.sum(plan_factored > 1e-8) / plan_factored.size:.4f}")
```

### Comparison with Standard Methods

```python
# Compare running times (continues from the example above)
import time

# Smaller problem so that all solvers run quickly
n_small = 200
Xa_small = Xa[:n_small]
Xb_small = Xb[:n_small]

a_small = ot.utils.unif(n_small)
b_small = ot.utils.unif(n_small)

# Cost matrix, normalized so that Sinkhorn with reg=0.1 does not underflow
M_small = ot.dist(Xa_small, Xb_small)
M_small /= M_small.max()

# Exact OT (EMD)
start_time = time.perf_counter()
plan_emd = ot.emd(a_small, b_small, M_small)
time_emd = time.perf_counter() - start_time

# Sinkhorn
start_time = time.perf_counter()
plan_sinkhorn = ot.sinkhorn(a_small, b_small, M_small, reg=0.1)
time_sinkhorn = time.perf_counter() - start_time

# Factored transport (same problem)
start_time = time.perf_counter()
Ga_small, Gb_small, _ = ot.factored_optimal_transport(Xa_small, Xb_small, a_small, b_small, r=r)
time_factored = time.perf_counter() - start_time

print(f"Timing comparison (n={n_small}):")
print(f"  EMD:      {time_emd:.4f}s")
print(f"  Sinkhorn: {time_sinkhorn:.4f}s")
print(f"  Factored: {time_factored:.4f}s")

# Full-size problem: the factored solver only solves n x r and r x m sub-problems
print(f"\nFull-size problem (n_source={n_source}, n_target={n_target}):")
start_time = time.perf_counter()
Ga_large, Gb_large, _ = ot.factored_optimal_transport(Xa, Xb, r=r)
time_large = time.perf_counter() - start_time
print(f"  Factored transport: {time_large:.4f}s")
```

### Gaussian Mixture Example

```python
# Example with Gaussian mixtures: r is set to the number of components
np.random.seed(0)

# Create Gaussian mixture data
n_components = 3
n_samples_per_comp = 300

# Source mixture
Xa_gmm = np.vstack([
    np.random.multivariate_normal([0, 0], [[1, 0.5], [0.5, 1]], n_samples_per_comp),
    np.random.multivariate_normal([3, 0], [[1, -0.3], [-0.3, 1]], n_samples_per_comp),
    np.random.multivariate_normal([1.5, 3], [[0.8, 0.2], [0.2, 0.8]], n_samples_per_comp),
])

# Target mixture (shifted and rotated)
theta = np.pi / 6
R = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
Xb_gmm = np.vstack([
    np.random.multivariate_normal([1, 1], [[1.2, 0.4], [0.4, 1.2]], n_samples_per_comp),
    np.random.multivariate_normal([4, 1], [[1, -0.4], [-0.4, 1]], n_samples_per_comp),
    np.random.multivariate_normal([2.5, 4], [[0.9, 0.3], [0.3, 0.9]], n_samples_per_comp),
]) @ R.T  # rotate every sample by theta

# Solve with factored transport
Ga_gmm, Gb_gmm, X_gmm, log_gmm = ot.factored_optimal_transport(
    Xa_gmm, Xb_gmm,
    r=n_components,
    log=True,
)
# Full coupling, assuming uniform weights 1/r on the intermediate points
plan_gmm = n_components * (Ga_gmm @ Gb_gmm)

# Visualize results
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Source data with the intermediate support points (black crosses)
axes[0].scatter(Xa_gmm[:, 0], Xa_gmm[:, 1], alpha=0.6, c='blue')
axes[0].scatter(X_gmm[:, 0], X_gmm[:, 1], c='black', marker='x', s=100)
axes[0].set_title('Source Distribution')
axes[0].set_aspect('equal')

# Target data with the intermediate support points (black crosses)
axes[1].scatter(Xb_gmm[:, 0], Xb_gmm[:, 1], alpha=0.6, c='red')
axes[1].scatter(X_gmm[:, 0], X_gmm[:, 1], c='black', marker='x', s=100)
axes[1].set_title('Target Distribution')
axes[1].set_aspect('equal')

# Transport plan visualization
im = axes[2].imshow(plan_gmm, cmap='Blues', aspect='auto')
axes[2].set_xlabel('Target samples')
axes[2].set_ylabel('Source samples')
axes[2].set_title('Factored Transport Plan')
plt.colorbar(im, ax=axes[2])

plt.tight_layout()
plt.show()

# Inspect what the solver recorded
print(f"Log entries: {sorted(log_gmm)}")
```

### Time Series Transport

```python
# Example with time series data
from sklearn.decomposition import PCA

np.random.seed(0)

# Generate time series with shared temporal patterns
t = np.linspace(0, 10, 100)
n_series_source = 200
n_series_target = 250

# Base temporal patterns, shape (len(t), 4)
patterns = np.array([
    np.sin(t),
    np.cos(t),
    np.sin(2 * t),
    np.exp(-t / 5) * np.sin(t),
]).T

# Source time series (nonnegative combinations of the patterns)
weights_source = np.random.exponential(1, (n_series_source, 4))
Xa_ts = weights_source @ patterns.T + 0.1 * np.random.randn(n_series_source, len(t))

# Target time series (patterns shifted by 5 time steps)
weights_target = np.random.exponential(1.2, (n_series_target, 4))
patterns_shifted = np.roll(patterns, 5, axis=0)
Xb_ts = weights_target @ patterns_shifted.T + 0.1 * np.random.randn(n_series_target, len(t))

# Reduce dimension with PCA fitted on the source series
pca = PCA(n_components=10)
Xa_ts_pca = pca.fit_transform(Xa_ts)
Xb_ts_pca = pca.transform(Xb_ts)

# Factored transport on the PCA features
r_ts = 10
Ga_ts, Gb_ts, X_ts = ot.factored_optimal_transport(Xa_ts_pca, Xb_ts_pca, r=r_ts)
plan_ts = r_ts * (Ga_ts @ Gb_ts)  # full coupling, assuming uniform intermediate weights

# Visualize sample time series and the transport plan
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Sample source time series
for i in range(5):
    axes[0, 0].plot(t, Xa_ts[i], alpha=0.7)
axes[0, 0].set_title('Sample Source Time Series')
axes[0, 0].set_xlabel('Time')

# Sample target time series
for i in range(5):
    axes[0, 1].plot(t, Xb_ts[i], alpha=0.7)
axes[0, 1].set_title('Sample Target Time Series')
axes[0, 1].set_xlabel('Time')

# PCA representation
axes[1, 0].scatter(Xa_ts_pca[:, 0], Xa_ts_pca[:, 1], alpha=0.6, label='Source')
axes[1, 0].scatter(Xb_ts_pca[:, 0], Xb_ts_pca[:, 1], alpha=0.6, label='Target')
axes[1, 0].set_xlabel('PC1')
axes[1, 0].set_ylabel('PC2')
axes[1, 0].set_title('PCA Representation')
axes[1, 0].legend()

# Support of the transport plan (entries above 1e-6)
axes[1, 1].spy(plan_ts > 1e-6, markersize=1)
axes[1, 1].set_title('Transport Plan Support')
axes[1, 1].set_xlabel('Target series')
axes[1, 1].set_ylabel('Source series')

plt.tight_layout()
plt.show()
```

## Import Statements

```python
import ot

# The solver is available at the package top level and in its submodule
from ot import factored_optimal_transport
from ot.factored import factored_optimal_transport
```