
# Domain Adaptation

The `ot.da` module provides transport-based methods for domain adaptation in machine learning. These algorithms learn mappings between different domains (e.g., training and test distributions) by leveraging optimal transport theory, enabling knowledge transfer when source and target domains differ.

## Core Domain Adaptation Functions

### Label-Regularized Transport

```python { .api }
def ot.da.sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerItermax=200, stopInnerThr=1e-9, verbose=False, log=False):
    """
    Solve optimal transport with L1 label regularization using the MM algorithm.

    Incorporates label information from the source domain to guide the transport
    by penalizing transport between samples with different labels. Uses the
    Majorization-Minimization (MM) algorithmic framework.

    Parameters:
    - a: array-like, shape (n_samples_source,)
      Source domain distribution (weights of source samples).
    - labels_a: array-like, shape (n_samples_source,)
      Labels of source domain samples (integer class labels).
    - b: array-like, shape (n_samples_target,)
      Target domain distribution (weights of target samples).
    - M: array-like, shape (n_samples_source, n_samples_target)
      Ground cost matrix between source and target samples.
    - reg: float
      Entropic regularization parameter for the Sinkhorn algorithm.
    - eta: float, default=0.1
      Label regularization parameter. Higher values enforce stronger
      alignment between samples of the same class.
    - numItermax: int, default=10
      Maximum number of outer MM iterations.
    - numInnerItermax: int, default=200
      Maximum iterations for the inner Sinkhorn algorithm.
    - stopInnerThr: float, default=1e-9
      Convergence threshold for inner Sinkhorn iterations.
    - verbose: bool, default=False
      Print iteration information.
    - log: bool, default=False
      Return an optimization log with convergence details.

    Returns:
    - transport_plan: ndarray, shape (n_samples_source, n_samples_target)
      Optimal transport plan with label regularization.
    - log: dict (if log=True)
      Contains 'err': convergence errors, 'all_err': full error history.
    """

def ot.da.sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, alpha=0.98, numItermax=10, numInnerItermax=200, stopInnerThr=1e-9, verbose=False, log=False):
    """
    Solve optimal transport with L1-L2 group lasso regularization.

    Combines L1 sparsity regularization with an L2 group lasso penalty to encourage
    both sparsity and grouping in the transport plan according to class labels.

    Parameters:
    - a: array-like, shape (n_samples_source,)
      Source distribution.
    - labels_a: array-like, shape (n_samples_source,)
      Source labels.
    - b: array-like, shape (n_samples_target,)
      Target distribution.
    - M: array-like, shape (n_samples_source, n_samples_target)
      Cost matrix.
    - reg: float
      Entropic regularization parameter.
    - eta: float, default=0.1
      L1 regularization parameter.
    - alpha: float, default=0.98
      Trade-off between L1 and L2 regularization (elastic net parameter).
    - numItermax: int, default=10
      Maximum outer iterations.
    - numInnerItermax: int, default=200
      Maximum inner iterations.
    - stopInnerThr: float, default=1e-9
      Inner convergence threshold.
    - verbose: bool, default=False
    - log: bool, default=False

    Returns:
    - transport_plan: ndarray
      L1-L2 regularized transport plan.
    - log: dict (if log=True)
    """

def ot.da.emd_laplace(a, labels_a, b, M, eta=0.1, numItermax=10, verbose=False, log=False):
    """
    Solve optimal transport with Laplacian regularization.

    Uses Laplacian regularization to enforce smooth transport plans that
    respect the local structure of the data manifold.

    Parameters:
    - a: array-like, shape (n_samples_source,)
      Source distribution.
    - labels_a: array-like, shape (n_samples_source,)
      Source labels used to construct the Laplacian.
    - b: array-like, shape (n_samples_target,)
      Target distribution.
    - M: array-like, shape (n_samples_source, n_samples_target)
      Cost matrix.
    - eta: float, default=0.1
      Laplacian regularization parameter.
    - numItermax: int, default=10
      Maximum iterations.
    - verbose: bool, default=False
    - log: bool, default=False

    Returns:
    - transport_plan: ndarray
      Laplacian-regularized transport plan.
    - log: dict (if log=True)
    """

def ot.da.distribution_estimation_uniform(X):
    """
    Estimate a uniform distribution over samples.

    Simple utility to create uniform weights for samples when no
    prior distribution information is available.

    Parameters:
    - X: array-like, shape (n_samples, n_features)
      Input samples.

    Returns:
    - distribution: ndarray, shape (n_samples,)
      Uniform distribution (each entry equals 1/n_samples).
    """
```
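
A minimal sketch of the functional API, assuming small random point clouds and illustrative parameter values; the cost matrix is built with `ot.dist`:

```python
import numpy as np
import ot

rng = np.random.RandomState(0)
Xs = rng.randn(20, 2)              # source samples
Xt = rng.randn(30, 2) + 1.0        # shifted target samples
labels_a = rng.randint(0, 2, 20)   # integer source class labels

# Uniform weights when no prior distribution information is available
a = ot.da.distribution_estimation_uniform(Xs)
b = ot.da.distribution_estimation_uniform(Xt)

# Ground cost between the two domains
M = ot.dist(Xs, Xt, metric='sqeuclidean')

# Label-regularized transport plan (reg and eta values are illustrative)
G = ot.da.sinkhorn_lpl1_mm(a, labels_a, b, M, reg=0.1, eta=0.1)
print(G.shape)  # (20, 30)
```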

## Transport Classes for Domain Adaptation

### Base Transport Class

```python { .api }
class ot.da.BaseTransport:
    """
    Base class for optimal transport-based domain adaptation.

    Provides the common interface and functionality for all transport-based
    domain adaptation methods.

    Parameters:
    - log: bool, default=False
      Whether to store optimization logs.
    - verbose: bool, default=False
      Print information during fitting.
    - out_of_sample_map: str, default='ferradans'
      Out-of-sample mapping method for new data points.
    """

    def fit(self, Xs=None, Xt=None, ys=None, yt=None):
        """
        Build a coupling matrix from source and target sets.

        Parameters:
        - Xs: array-like, shape (n_source_samples, n_features)
          Source domain samples.
        - Xt: array-like, shape (n_target_samples, n_features)
          Target domain samples.
        - ys: array-like, shape (n_source_samples,), optional
          Source domain labels.
        - yt: array-like, shape (n_target_samples,), optional
          Target domain labels.

        Returns:
        - self: BaseTransport instance
        """

    def transform(self, Xs=None, Xt=None, ys=None, yt=None, batch_size=128):
        """
        Transform source samples to the target domain.

        Parameters:
        - Xs: array-like, shape (n_samples, n_features), optional
          Source samples to transform.
        - Xt: array-like, shape (n_samples, n_features), optional
          Target samples to inverse transform.
        - ys: array-like, optional
          Source labels.
        - yt: array-like, optional
          Target labels.
        - batch_size: int, default=128
          Batch size for large-scale transformations.

        Returns:
        - transformed_samples: ndarray
          Samples transformed to the target domain.
        """

    def transform_labels(self, ys=None):
        """
        Propagate source labels to the target domain.

        Parameters:
        - ys: array-like, shape (n_source_samples,)
          Source labels to propagate.

        Returns:
        - target_labels: ndarray, shape (n_target_samples,)
          Labels assigned to target samples.
        """

    def inverse_transform(self, Xs=None, Xt=None, ys=None, yt=None, batch_size=128):
        """
        Transform target samples to the source domain.

        Parameters: similar to transform().

        Returns:
        - inverse_transformed: ndarray
          Target samples transformed to the source domain.
        """
```
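
All transport classes share this scikit-learn-style interface. A short sketch of the fit/transform/inverse_transform round trip, using `EMDTransport` on illustrative toy data:

```python
import numpy as np
import ot

rng = np.random.RandomState(0)
Xs = rng.randn(40, 2)               # source samples
Xt = rng.randn(50, 2) + [1.0, 0.5]  # shifted target samples

adapter = ot.da.EMDTransport()
adapter.fit(Xs=Xs, Xt=Xt)

Xs_to_target = adapter.transform(Xs=Xs)          # map source onto target
Xt_to_source = adapter.inverse_transform(Xt=Xt)  # map target back onto source
print(Xs_to_target.shape, Xt_to_source.shape)    # (40, 2) (50, 2)
```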

### Linear Transport Methods

```python { .api }
class ot.da.LinearTransport(BaseTransport):
    """
    Linear optimal transport for domain adaptation.

    Learns a linear transformation matrix for mapping between domains
    based on optimal transport theory.

    Parameters:
    - reg: float, default=1e-8
      Regularization parameter for matrix inversion.
    - bias: bool, default=False
      Whether to estimate a bias term.
    - log: bool, default=False
    - verbose: bool, default=False
    """

class ot.da.LinearGWTransport(BaseTransport):
    """
    Linear Gromov-Wasserstein transport for domain adaptation.

    Uses the Gromov-Wasserstein distance to handle domains with different
    feature spaces by comparing internal structure rather than features directly.

    Parameters:
    - reg: float, default=1e-8
      Regularization parameter.
    - alpha: float, default=0.5
      GW optimization step size.
    - max_iter: int, default=100
      Maximum GW iterations.
    - tol: float, default=1e-6
      GW convergence tolerance.
    """
```
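
A brief sketch with `LinearTransport`, which fits an affine map between Gaussian approximations of the two domains; the data and the explicit `bias=True` setting here are illustrative:

```python
import numpy as np
import ot

rng = np.random.RandomState(0)
Xs = rng.randn(100, 2)
# Target: sheared and shifted copy of the source distribution
Xt = rng.randn(100, 2) @ np.array([[1.0, 0.3], [0.0, 1.0]]) + [2.0, -1.0]

linear_adapter = ot.da.LinearTransport(bias=True)
linear_adapter.fit(Xs=Xs, Xt=Xt)

# The learned affine map applies directly to any source samples
Xs_mapped = linear_adapter.transform(Xs=Xs)
print(np.mean(Xs_mapped, axis=0))  # should roughly match np.mean(Xt, axis=0)
```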

### Sinkhorn-based Transport

```python { .api }
class ot.da.SinkhornTransport(BaseTransport):
    """
    Sinkhorn transport for domain adaptation.

    Uses entropic regularization and the Sinkhorn algorithm for efficient
    computation of transport plans between domains.

    Parameters:
    - reg_e: float, default=1.0
      Entropic regularization parameter.
    - max_iter: int, default=1000
      Maximum Sinkhorn iterations.
    - tol: float, default=1e-9
      Sinkhorn convergence tolerance.
    - verbose: bool, default=False
    - log: bool, default=False
    - metric: str, default='sqeuclidean'
      Ground metric for cost matrix computation.
    - norm: str, optional
      Cost matrix normalization method.
    - distribution_estimation: callable, default=distribution_estimation_uniform
      Method for estimating sample distributions.
    - out_of_sample_map: str, default='ferradans'
      Out-of-sample mapping technique.
    - limit_max: float, default=np.inf
      Maximum value for cost matrix entries.
    """

class ot.da.EMDTransport(BaseTransport):
    """
    Exact EMD transport for domain adaptation.

    Uses exact optimal transport (Earth Mover's Distance) without
    regularization for precise domain adaptation.

    Parameters:
    - metric: str, default='sqeuclidean'
      Ground metric for cost computation.
    - norm: str, optional
      Cost normalization method.
    - log: bool, default=False
    - verbose: bool, default=False
    - distribution_estimation: callable, default=distribution_estimation_uniform
    - out_of_sample_map: str, default='ferradans'
    - limit_max: float, default=np.inf
      Cost matrix entry limit.
    """
```

### Label-Regularized Transport Classes

```python { .api }
class ot.da.SinkhornLpl1Transport(BaseTransport):
    """
    Sinkhorn transport with L1 label regularization.

    Incorporates source domain labels to guide transport using an L1 penalty
    on cross-class transport.

    Parameters:
    - reg_e: float, default=1.0
      Entropic regularization.
    - reg_cl: float, default=0.1
      Label regularization parameter.
    - max_iter: int, default=10
      Maximum outer iterations.
    - max_inner_iter: int, default=200
      Maximum inner Sinkhorn iterations.
    - log: bool, default=False
    - verbose: bool, default=False
    - metric: str, default='sqeuclidean'
    """

class ot.da.SinkhornL1l2Transport(BaseTransport):
    """
    Sinkhorn transport with L1-L2 group lasso regularization.

    Combines L1 sparsity with L2 group penalties for structured
    domain adaptation.

    Parameters:
    - reg_e: float, default=1.0
      Entropic regularization.
    - reg_cl: float, default=0.1
      L1 regularization.
    - reg_l: float, default=0.1
      L2 group regularization.
    - max_iter: int, default=10
    - max_inner_iter: int, default=200
    - tol: float, default=1e-9
    """

class ot.da.EMDLaplaceTransport(BaseTransport):
    """
    EMD transport with Laplacian regularization.

    Uses a Laplacian penalty to ensure smooth transport that respects
    the data manifold structure.

    Parameters:
    - reg_lap: float, default=1.0
      Laplacian regularization parameter.
    - reg_src: float, default=0.5
      Source regularization.
    - metric: str, default='sqeuclidean'
    - norm: str, optional
    - similarity: str, default='knn'
      Method for similarity matrix construction.
    - similarity_param: int, default=7
      Parameter for similarity computation (e.g., k for knn).
    - max_iter: int, default=10
    """
```
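
A sketch using `EMDLaplaceTransport`; `SinkhornL1l2Transport` follows the same fit/transform pattern, and the data and parameter values below are illustrative:

```python
import ot
from sklearn.datasets import make_classification

Xs, ys = make_classification(n_samples=80, n_features=2, n_redundant=0,
                             n_informative=2, random_state=0,
                             n_clusters_per_class=1)
Xt, _ = make_classification(n_samples=80, n_features=2, n_redundant=0,
                            n_informative=2, random_state=7,
                            n_clusters_per_class=1)

# Laplacian-regularized transport built from a 5-nearest-neighbor graph
laplace_adapter = ot.da.EMDLaplaceTransport(reg_lap=1.0, similarity_param=5)
laplace_adapter.fit(Xs=Xs, ys=ys, Xt=Xt)
Xs_smooth = laplace_adapter.transform(Xs=Xs)
print(Xs_smooth.shape)  # (80, 2)
```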

### Advanced Transport Methods

```python { .api }
class ot.da.MappingTransport(BaseTransport):
    """
    Optimal transport with learned mappings.

    Learns parametric mappings (linear or kernel-based) that approximate
    the optimal transport map.

    Parameters:
    - mu: float, default=1e0
      Regularization parameter for mapping learning.
    - eta: float, default=1e-8
      Numerical regularization.
    - bias: bool, default=True
      Whether to learn bias terms.
    - metric: str, default='sqeuclidean'
    - norm: str, optional
    - kernel: str, default='linear'
      Kernel type ('linear', 'gaussian', 'rbf').
    - sigma: float, default=1.0
      Kernel bandwidth (for Gaussian/RBF kernels).
    - max_iter: int, default=100
    - tol: float, default=1e-5
    - max_inner_iter: int, default=10
    - inner_tol: float, default=1e-6
    - log: bool, default=False
    - verbose: bool, default=False
    - verbose2: bool, default=False
    """

class ot.da.UnbalancedSinkhornTransport(BaseTransport):
    """
    Unbalanced Sinkhorn transport for domain adaptation.

    Handles domain adaptation with different marginal distributions
    using unbalanced optimal transport.

    Parameters:
    - reg_e: float, default=1.0
      Entropic regularization.
    - reg_m: float, default=1.0
      Marginal relaxation parameter.
    - method: str, default='sinkhorn'
      Unbalanced algorithm variant.
    - max_iter: int, default=1000
    - tol: float, default=1e-9
    - verbose: bool, default=False
    - log: bool, default=False
    """

class ot.da.JCPOTTransport(BaseTransport):
    """
    Joint Class Proportion and Optimal Transport (JCPOT) for multi-source adaptation.

    Handles multiple source domains simultaneously using a joint optimal
    transport formulation that also estimates target class proportions.

    Parameters:
    - reg_e: float, default=1.0
      Entropic regularization.
    - max_iter: int, default=10
    - tol: float, default=1e-6
    - verbose: bool, default=False
    - log: bool, default=False
    - metric: str, default='sqeuclidean'
    """

class ot.da.NearestBrenierPotential(BaseTransport):
    """
    Transport using nearest Brenier potential approximation.

    Learns optimal transport maps through Brenier potential estimation
    for smooth and invertible domain adaptation.

    Parameters:
    - reg: float, default=1e-3
      Regularization for potential learning.
    - max_iter: int, default=100
    - tol: float, default=1e-6
    """
```
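
A minimal sketch with `NearestBrenierPotential`, relying only on the shared fit/transform interface and default constructor settings; the toy data are illustrative, and note that Brenier potential fitting may require an additional convex solver backend in some installations:

```python
import numpy as np
import ot

rng = np.random.RandomState(0)
Xs = rng.randn(60, 2)              # source samples
Xt = rng.randn(60, 2) * 0.5 + 1.0  # contracted, shifted target samples

nbp = ot.da.NearestBrenierPotential()
nbp.fit(Xs=Xs, Xt=Xt)

# The smooth, invertible map applies to the fitted source samples
Xs_mapped = nbp.transform(Xs=Xs)
print(Xs_mapped.shape)  # (60, 2)
```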

## Usage Examples

### Basic Domain Adaptation

```python
import ot
import numpy as np
from sklearn.datasets import make_classification

# Generate source and target domains
n_source, n_target = 150, 100
n_features = 2

# Source domain
Xs, ys = make_classification(n_samples=n_source, n_features=n_features,
                             n_redundant=0, n_informative=2,
                             random_state=1, n_clusters_per_class=1)

# Target domain (shifted and rotated)
Xt, yt = make_classification(n_samples=n_target, n_features=n_features,
                             n_redundant=0, n_informative=2,
                             random_state=42, n_clusters_per_class=1)

# Apply domain shift
angle = np.pi / 6
rotation = np.array([[np.cos(angle), -np.sin(angle)],
                     [np.sin(angle), np.cos(angle)]])
Xt = Xt @ rotation + [1, 1]

print(f"Source domain shape: {Xs.shape}")
print(f"Target domain shape: {Xt.shape}")
```

### Sinkhorn Transport Adaptation

```python
# Initialize Sinkhorn transport
sinkhorn_adapter = ot.da.SinkhornTransport(reg_e=0.1, verbose=True)

# Fit the transport
sinkhorn_adapter.fit(Xs=Xs, Xt=Xt)

# Transform source samples to the target domain
Xs_adapted = sinkhorn_adapter.transform(Xs=Xs)

print("Adaptation completed")
print(f"Adapted source shape: {Xs_adapted.shape}")
# Transport cost = <coupling, cost matrix>; the coupling itself sums to 1
cost = np.sum(sinkhorn_adapter.coupling_ * sinkhorn_adapter.cost_)
print(f"Transport cost: {cost:.4f}")
```

### Label-Regularized Adaptation

```python
# Use source labels for better adaptation
label_adapter = ot.da.SinkhornLpl1Transport(
    reg_e=0.1, reg_cl=0.1, verbose=True
)

# Fit with source labels
label_adapter.fit(Xs=Xs, ys=ys, Xt=Xt)

# Transform and propagate labels
Xs_label_adapted = label_adapter.transform(Xs=Xs)
yt_predicted = label_adapter.transform_labels(ys=ys)

print(f"Label-adapted source shape: {Xs_label_adapted.shape}")
print(f"Predicted target labels shape: {yt_predicted.shape}")
```

### Multi-Method Comparison

```python
# Compare different adaptation methods
methods = {
    'EMD': ot.da.EMDTransport(),
    'Sinkhorn': ot.da.SinkhornTransport(reg_e=0.1),
    'Linear': ot.da.LinearTransport(),
    'Unbalanced': ot.da.UnbalancedSinkhornTransport(reg_e=0.1, reg_m=1.0)
}

adapted_sources = {}

for name, method in methods.items():
    print(f"\nFitting {name} transport...")
    method.fit(Xs=Xs, Xt=Xt)
    adapted_sources[name] = method.transform(Xs=Xs)

    # Compute adaptation quality (distance to target centroid)
    target_center = np.mean(Xt, axis=0)
    adapted_center = np.mean(adapted_sources[name], axis=0)
    distance = np.linalg.norm(target_center - adapted_center)
    print(f"{name} - Distance to target center: {distance:.4f}")
```

### Out-of-Sample Adaptation

```python
# Generate new source samples for out-of-sample testing
Xs_new = np.random.multivariate_normal([0, 0], [[1, 0], [0, 1]], 50)

# Adapt new samples using a trained transport
sinkhorn_adapter = ot.da.SinkhornTransport(reg_e=0.1)
sinkhorn_adapter.fit(Xs=Xs, Xt=Xt)

# Transform new samples (uses the 'ferradans' out-of-sample mapping by default)
Xs_new_adapted = sinkhorn_adapter.transform(Xs=Xs_new)

print(f"New source samples: {Xs_new.shape}")
print(f"Adapted new samples: {Xs_new_adapted.shape}")
```

### JCPOT Multi-Source Adaptation

```python
# Create multiple source domains
n_sources = 3
source_domains = []
source_labels = []

for i in range(n_sources):
    Xs_i, ys_i = make_classification(n_samples=100, n_features=2,
                                     n_redundant=0, n_informative=2,
                                     random_state=i, n_clusters_per_class=1)
    # Apply a different shift to each source
    Xs_i = Xs_i + [i * 0.5, i * 0.3]
    source_domains.append(Xs_i)
    source_labels.append(ys_i)

# JCPOT adaptation
jcpot_adapter = ot.da.JCPOTTransport(reg_e=0.1, verbose=True)

# Fit multiple sources to a single target
jcpot_adapter.fit(Xs=source_domains, ys=source_labels, Xt=Xt, yt=yt)

print("JCPOT multi-source adaptation completed")
```

### Advanced Mapping Transport

```python
# Use mapping transport with an RBF kernel
mapping_adapter = ot.da.MappingTransport(
    kernel='rbf', sigma=1.0, mu=1e-1, verbose=True
)

mapping_adapter.fit(Xs=Xs, Xt=Xt)
Xs_mapped = mapping_adapter.transform(Xs=Xs)

print("Mapping transport with RBF kernel completed")

# The learned mapping can be applied to new data
Xs_new_mapped = mapping_adapter.transform(Xs=Xs_new)
print(f"New samples mapped: {Xs_new_mapped.shape}")
```

602

603

### Performance Evaluation

604

```python
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

# Train a classifier on adapted source data
knn = KNeighborsClassifier(n_neighbors=3)

# Test the different adaptations
results = {}

for name, Xs_adapted in adapted_sources.items():
    # Train on adapted source
    knn.fit(Xs_adapted, ys)

    # Predict on target (when labels are available)
    if len(np.unique(yt)) > 1:  # check that the target has multiple classes
        yt_pred = knn.predict(Xt)
        accuracy = accuracy_score(yt, yt_pred)
        results[name] = accuracy
        print(f"{name} adaptation accuracy: {accuracy:.3f}")

# Baseline: no adaptation
knn.fit(Xs, ys)
if len(np.unique(yt)) > 1:
    yt_pred_baseline = knn.predict(Xt)
    baseline_acc = accuracy_score(yt, yt_pred_baseline)
    print(f"No adaptation accuracy: {baseline_acc:.3f}")
```

632

633

## Applications

634

635

### Computer Vision

636

- **Cross-dataset adaptation**: Adapting models trained on one image dataset to another

637

- **Domain shift**: Handling changes in lighting, camera, or image style

638

- **Synthetic-to-real**: Adapting from synthetic training data to real images

639

640

### Natural Language Processing

641

- **Cross-lingual adaptation**: Transferring models between languages

642

- **Domain-specific text**: Adapting from general to domain-specific corpora

643

- **Temporal adaptation**: Handling language evolution over time

644

645

### Biomedical Applications

646

- **Cross-study adaptation**: Adapting between different clinical studies

647

- **Multi-site data**: Handling batch effects across research sites

648

- **Cross-species**: Transferring knowledge between related organisms

649

650

The `ot.da` module provides comprehensive tools for transport-based domain adaptation, offering both theoretical rigor and practical effectiveness for bridging distribution gaps in machine learning applications.