# Domain Adaptation

The `ot.da` module provides transport-based methods for domain adaptation in machine learning. These algorithms learn mappings between different domains (e.g., training and test distributions) by leveraging optimal transport theory, enabling knowledge transfer when source and target domains differ.

## Core Domain Adaptation Functions

### Label-Regularized Transport

```python { .api }
def ot.da.sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerItermax=200, stopInnerThr=1e-9, verbose=False, log=False):
    """
    Solve optimal transport with L1 label regularization using MM algorithm.

    Incorporates label information from the source domain to guide the transport
    by penalizing transport between samples with different labels. Uses the
    Majorization-Minimization (MM) algorithmic framework.

    Parameters:
    - a: array-like, shape (n_samples_source,)
      Source domain distribution (weights of source samples).
    - labels_a: array-like, shape (n_samples_source,)
      Labels of source domain samples (integer class labels).
    - b: array-like, shape (n_samples_target,)
      Target domain distribution (weights of target samples).
    - M: array-like, shape (n_samples_source, n_samples_target)
      Ground cost matrix between source and target samples.
    - reg: float
      Entropic regularization parameter for Sinkhorn algorithm.
    - eta: float, default=0.1
      Label regularization parameter. Higher values enforce stronger
      alignment between samples of the same class.
    - numItermax: int, default=10
      Maximum number of outer MM iterations.
    - numInnerItermax: int, default=200
      Maximum iterations for inner Sinkhorn algorithm.
    - stopInnerThr: float, default=1e-9
      Convergence threshold for inner Sinkhorn iterations.
    - verbose: bool, default=False
      Print iteration information.
    - log: bool, default=False
      Return optimization log with convergence details.

    Returns:
    - transport_plan: ndarray, shape (n_samples_source, n_samples_target)
      Optimal transport plan with label regularization.
    - log: dict (if log=True)
      Contains 'err': convergence errors, 'all_err': all errors history.
    """

def ot.da.sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, alpha=0.98, numItermax=10, numInnerItermax=200, stopInnerThr=1e-9, verbose=False, log=False):
    """
    Solve optimal transport with L1-L2 group lasso regularization.

    Combines L1 sparsity regularization with L2 group lasso to encourage
    both sparsity and grouping in the transport plan according to class labels.

    Parameters:
    - a: array-like, shape (n_samples_source,)
      Source distribution.
    - labels_a: array-like, shape (n_samples_source,)
      Source labels.
    - b: array-like, shape (n_samples_target,)
      Target distribution.
    - M: array-like, shape (n_samples_source, n_samples_target)
      Cost matrix.
    - reg: float
      Entropic regularization parameter.
    - eta: float, default=0.1
      L1 regularization parameter.
    - alpha: float, default=0.98
      Trade-off between L1 and L2 regularization (elastic net parameter).
    - numItermax: int, default=10
      Maximum outer iterations.
    - numInnerItermax: int, default=200
      Maximum inner iterations.
    - stopInnerThr: float, default=1e-9
      Inner convergence threshold.
    - verbose: bool, default=False
    - log: bool, default=False

    Returns:
    - transport_plan: ndarray
      L1-L2 regularized transport plan.
    - log: dict (if log=True)
    """

def ot.da.emd_laplace(a, labels_a, b, M, eta=0.1, numItermax=10, verbose=False, log=False):
    """
    Solve optimal transport with Laplacian regularization.

    Uses Laplacian regularization to enforce smooth transport plans that
    respect the local structure of the data manifold.

    Parameters:
    - a: array-like, shape (n_samples_source,)
      Source distribution.
    - labels_a: array-like, shape (n_samples_source,)
      Source labels for constructing Laplacian.
    - b: array-like, shape (n_samples_target,)
      Target distribution.
    - M: array-like, shape (n_samples_source, n_samples_target)
      Cost matrix.
    - eta: float, default=0.1
      Laplacian regularization parameter.
    - numItermax: int, default=10
      Maximum iterations.
    - verbose: bool, default=False
    - log: bool, default=False

    Returns:
    - transport_plan: ndarray
      Laplacian-regularized transport plan.
    - log: dict (if log=True)
    """

def ot.da.distribution_estimation_uniform(X):
    """
    Estimate uniform distribution over samples.

    Simple utility to create uniform weights for samples when no
    prior distribution information is available.

    Parameters:
    - X: array-like, shape (n_samples, n_features)
      Input samples.

    Returns:
    - distribution: ndarray, shape (n_samples,)
      Uniform distribution (each entry equals 1/n_samples).
    """
```
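
To make the label penalties described above more concrete, the sketch below evaluates two group-sparsity terms for a given transport plan in plain NumPy. It assumes the forms commonly used in the optimal transport domain adaptation literature: for each target sample and each source class, either the class mass raised to the power 1/2 (an lp-l1 penalty) or the l2 norm of the class block (a group-lasso penalty). The helper names `lpl1_penalty` and `l1l2_penalty` are illustrative and are not part of `ot.da`.

```python
import numpy as np

def lpl1_penalty(G, labels, p=0.5):
    """Sum over target columns and source classes of (class mass) ** p."""
    total = 0.0
    for c in np.unique(labels):
        class_mass = G[labels == c].sum(axis=0)  # mass each target gets from class c
        total += np.sum(class_mass ** p)
    return float(total)

def l1l2_penalty(G, labels):
    """Sum over target columns and source classes of the l2 norm of the class block."""
    total = 0.0
    for c in np.unique(labels):
        total += np.sum(np.linalg.norm(G[labels == c], axis=0))
    return float(total)

labels = np.array([0, 0, 1, 1])
# Each target sample receives mass from a single class
pure = np.array([[0.25, 0.0], [0.25, 0.0], [0.0, 0.25], [0.0, 0.25]])
# Each target sample receives mass from both classes
mixed = np.array([[0.25, 0.0], [0.0, 0.25], [0.25, 0.0], [0.0, 0.25]])

print(lpl1_penalty(pure, labels), lpl1_penalty(mixed, labels))
print(l1l2_penalty(pure, labels), l1l2_penalty(mixed, labels))
```

Both penalties come out smaller for the plan in which every target sample receives mass from one class only, which is the behavior a larger `eta` is meant to favor.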

## Transport Classes for Domain Adaptation

### Base Transport Class

```python { .api }
class ot.da.BaseTransport:
    """
    Base class for optimal transport-based domain adaptation.

    Provides common interface and functionality for all transport-based
    domain adaptation methods.

    Parameters:
    - log: bool, default=False
      Whether to store optimization logs.
    - verbose: bool, default=False
      Print information during fitting.
    - out_of_sample_map: str, default='ferradans'
      Out-of-sample mapping method for new data points.
    """

    def fit(self, Xs=None, Xt=None, ys=None, yt=None):
        """
        Build a coupling matrix from source and target sets.

        Parameters:
        - Xs: array-like, shape (n_source_samples, n_features)
          Source domain samples.
        - Xt: array-like, shape (n_target_samples, n_features)
          Target domain samples.
        - ys: array-like, shape (n_source_samples,), optional
          Source domain labels.
        - yt: array-like, shape (n_target_samples,), optional
          Target domain labels.

        Returns:
        - self: BaseTransport instance
        """

    def transform(self, Xs=None, Xt=None, ys=None, yt=None, batch_size=128):
        """
        Transform source samples to target domain.

        Parameters:
        - Xs: array-like, shape (n_samples, n_features), optional
          Source samples to transform.
        - Xt: array-like, shape (n_samples, n_features), optional
          Target samples to inverse transform.
        - ys: array-like, optional
          Source labels.
        - yt: array-like, optional
          Target labels.
        - batch_size: int, default=128
          Batch size for large-scale transformations.

        Returns:
        - transformed_samples: ndarray
          Samples transformed to target domain.
        """

    def transform_labels(self, ys=None):
        """
        Propagate source labels to target domain.

        Parameters:
        - ys: array-like, shape (n_source_samples,)
          Source labels to propagate.

        Returns:
        - target_labels: ndarray, shape (n_target_samples, n_classes)
          Soft label scores for each target sample, one column per class
          (take the argmax over columns for hard labels).
        """

    def inverse_transform(self, Xs=None, Xt=None, ys=None, yt=None, batch_size=128):
        """
        Transform target samples to source domain.

        Parameters: Similar to transform()

        Returns:
        - inverse_transformed: ndarray
          Target samples transformed to source domain.
        """
```
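
As a rough illustration of what `transform` and `transform_labels` compute for the samples seen during `fit`, the sketch below applies a barycentric mapping and coupling-based label propagation to a hand-written coupling matrix. It is a NumPy sketch that assumes the coupling is already known; `barycentric_map` and `propagate_labels` are illustrative names, not library API.

```python
import numpy as np

def barycentric_map(G, Xt):
    """Map each source sample to the coupling-weighted mean of the target samples."""
    row_mass = G.sum(axis=1, keepdims=True)
    return (G / row_mass) @ Xt

def propagate_labels(G, ys):
    """Return soft target labels of shape (n_target, n_classes)."""
    classes = np.unique(ys)
    onehot = (ys[:, None] == classes[None, :]).astype(float)  # (n_source, n_classes)
    col_mass = G.sum(axis=0, keepdims=True)
    return (G / col_mass).T @ onehot

# Toy coupling between 3 source and 2 target samples
# (each row sums to 1/3, each column sums to 1/2)
G = np.array([[1 / 3, 0.0],
              [1 / 6, 1 / 6],
              [0.0, 1 / 3]])
Xt = np.array([[0.0, 0.0], [2.0, 2.0]])
ys = np.array([0, 0, 1])

Xs_mapped = barycentric_map(G, Xt)   # one mapped point per source sample
yt_soft = propagate_labels(G, ys)    # one row of class scores per target sample
print(Xs_mapped)
print(yt_soft)
```

The second source sample splits its mass evenly between both targets, so it lands halfway between them; the second target receives one third of its mass from class 0 and two thirds from class 1.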

### Linear Transport Methods

```python { .api }
class ot.da.LinearTransport(BaseTransport):
    """
    Linear optimal transport for domain adaptation.

    Learns a linear transformation matrix for mapping between domains
    based on optimal transport theory.

    Parameters:
    - reg: float, default=1e-8
      Regularization parameter for matrix inversion.
    - bias: bool, default=False
      Whether to estimate bias term.
    - log: bool, default=False
    - verbose: bool, default=False
    """

class ot.da.LinearGWTransport(BaseTransport):
    """
    Linear Gromov-Wasserstein transport for domain adaptation.

    Uses Gromov-Wasserstein distance to handle domains with different
    feature spaces by comparing internal structure rather than features directly.

    Parameters:
    - reg: float, default=1e-8
      Regularization parameter.
    - alpha: float, default=0.5
      GW optimization step size.
    - max_iter: int, default=100
      Maximum GW iterations.
    - tol: float, default=1e-6
      GW convergence tolerance.
    """
```
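
The linear mapping idea can be sketched with the closed-form optimal transport map between two Gaussians fitted to the source and target samples. The code below is a NumPy illustration of that formula under this Gaussian assumption, not the library's implementation; `spd_sqrt` and `fit_linear_map` are hypothetical helper names, and `reg` stands in for the small ridge added to the covariances before inversion.

```python
import numpy as np

def spd_sqrt(C):
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""
    w, V = np.linalg.eigh(C)
    return (V * np.sqrt(np.clip(w, 0.0, None))) @ V.T

def fit_linear_map(Xs, Xt, reg=1e-8):
    """Return (A, b) such that Xs @ A + b matches the target mean and covariance."""
    d = Xs.shape[1]
    ms, mt = Xs.mean(axis=0), Xt.mean(axis=0)
    Cs = np.cov(Xs, rowvar=False, bias=True) + reg * np.eye(d)
    Ct = np.cov(Xt, rowvar=False, bias=True) + reg * np.eye(d)
    Cs_half = spd_sqrt(Cs)
    Cs_half_inv = np.linalg.inv(Cs_half)
    # A = Cs^{-1/2} (Cs^{1/2} Ct Cs^{1/2})^{1/2} Cs^{-1/2}, a symmetric matrix
    A = Cs_half_inv @ spd_sqrt(Cs_half @ Ct @ Cs_half) @ Cs_half_inv
    b = mt - ms @ A
    return A, b

rng = np.random.default_rng(0)
Xs = rng.normal(size=(200, 2))
Xt = rng.normal(size=(150, 2)) @ np.array([[2.0, 0.5], [0.0, 1.0]]) + [3.0, -1.0]

A, b = fit_linear_map(Xs, Xt)
Xs_mapped = Xs @ A + b
print(Xs_mapped.mean(axis=0), Xt.mean(axis=0))
```

After mapping, the source samples share the target's empirical mean and (up to the ridge term) its covariance, which is all a purely linear map can align.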

### Sinkhorn-based Transport

```python { .api }
class ot.da.SinkhornTransport(BaseTransport):
    """
    Sinkhorn transport for domain adaptation.

    Uses entropic regularization and Sinkhorn algorithm for efficient
    computation of transport plans between domains.

    Parameters:
    - reg_e: float, default=1.0
      Entropic regularization parameter.
    - max_iter: int, default=1000
      Maximum Sinkhorn iterations.
    - tol: float, default=1e-9
      Sinkhorn convergence tolerance.
    - verbose: bool, default=False
    - log: bool, default=False
    - metric: str, default='sqeuclidean'
      Ground metric for cost matrix computation.
    - norm: str, optional
      Cost matrix normalization method.
    - distribution_estimation: callable, default=distribution_estimation_uniform
      Method for estimating sample distributions.
    - out_of_sample_map: str, default='ferradans'
      Out-of-sample mapping technique.
    - limit_max: float, default=np.inf
      Maximum value for cost matrix entries.
    """

class ot.da.EMDTransport(BaseTransport):
    """
    Exact EMD transport for domain adaptation.

    Uses exact optimal transport (Earth Mover's Distance) without
    regularization for precise domain adaptation.

    Parameters:
    - metric: str, default='sqeuclidean'
      Ground metric for cost computation.
    - norm: str, optional
      Cost normalization method.
    - log: bool, default=False
    - verbose: bool, default=False
    - distribution_estimation: callable, default=distribution_estimation_uniform
    - out_of_sample_map: str, default='ferradans'
    - limit_max: float, default=np.inf
      Cost matrix entry limit.
    """
```
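
For intuition about the `reg_e`, `max_iter`, and `tol` parameters, here is a minimal NumPy sketch of the Sinkhorn scaling iterations on a small problem. It is a textbook version written for readability (no log-domain stabilization), not the routine the library calls, and the function name `sinkhorn_plan` is illustrative.

```python
import numpy as np

def sinkhorn_plan(a, b, M, reg, max_iter=1000, tol=1e-9):
    """Entropic OT plan via alternating scaling of the Gibbs kernel exp(-M / reg)."""
    K = np.exp(-M / reg)
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(max_iter):
        u = a / (K @ v)
        v = b / (K.T @ u)
        G = u[:, None] * K * v[None, :]
        # Column marginals are exact right after the v update, so check the rows
        if np.abs(G.sum(axis=1) - a).max() < tol:
            break
    return G

rng = np.random.default_rng(0)
Xs = rng.normal(size=(5, 2))
Xt = rng.normal(size=(4, 2)) + 1.0

# Squared Euclidean ground cost, scaled to [0, 1] so exp(-M / reg) stays well conditioned
M = ((Xs[:, None, :] - Xt[None, :, :]) ** 2).sum(axis=-1)
M = M / M.max()

a = np.full(5, 1 / 5)  # uniform source weights
b = np.full(4, 1 / 4)  # uniform target weights
G = sinkhorn_plan(a, b, M, reg=0.5)
print(G.sum(axis=1), G.sum(axis=0))
```

A smaller `reg` gives a sparser plan closer to exact EMD but needs more iterations and is more prone to underflow in this naive form, which is one reason production solvers offer stabilized variants.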

### Label-Regularized Transport Classes

```python { .api }
class ot.da.SinkhornLpl1Transport(BaseTransport):
    """
    Sinkhorn transport with L1 label regularization.

    Incorporates source domain labels to guide transport using L1 penalty
    on cross-class transport.

    Parameters:
    - reg_e: float, default=1.0
      Entropic regularization.
    - reg_cl: float, default=0.1
      Label regularization parameter.
    - max_iter: int, default=10
      Maximum outer iterations.
    - max_inner_iter: int, default=200
      Maximum inner Sinkhorn iterations.
    - log: bool, default=False
    - verbose: bool, default=False
    - metric: str, default='sqeuclidean'
    """

class ot.da.SinkhornL1l2Transport(BaseTransport):
    """
    Sinkhorn transport with L1-L2 group lasso regularization.

    Combines L1 sparsity with L2 group penalties for structured
    domain adaptation.

    Parameters:
    - reg_e: float, default=1.0
      Entropic regularization.
    - reg_cl: float, default=0.1
      L1 regularization.
    - reg_l: float, default=0.1
      L2 group regularization.
    - max_iter: int, default=10
    - max_inner_iter: int, default=200
    - tol: float, default=1e-9
    """

class ot.da.EMDLaplaceTransport(BaseTransport):
    """
    EMD transport with Laplacian regularization.

    Uses Laplacian penalty to ensure smooth transport respecting
    data manifold structure.

    Parameters:
    - reg_lap: float, default=1.0
      Laplacian regularization parameter.
    - reg_src: float, default=0.5
      Source regularization.
    - metric: str, default='sqeuclidean'
    - norm: str, optional
    - similarity: str, default='knn'
      Method for similarity matrix construction.
    - similarity_param: int, default=7
      Parameter for similarity computation (e.g., k for knn).
    - max_iter: int, default=10
    """
```

### Advanced Transport Methods

```python { .api }
class ot.da.MappingTransport(BaseTransport):
    """
    Optimal transport with learned mappings.

    Learns parametric mappings (linear or kernel-based) that approximate
    the optimal transport map.

    Parameters:
    - mu: float, default=1e0
      Regularization parameter for mapping learning.
    - eta: float, default=1e-8
      Numerical regularization.
    - bias: bool, default=True
      Whether to learn bias terms.
    - metric: str, default='sqeuclidean'
    - norm: str, optional
    - kernel: str, default='linear'
      Kernel type ('linear' or 'gaussian').
    - sigma: float, default=1.0
      Kernel bandwidth (for the Gaussian kernel).
    - max_iter: int, default=100
    - tol: float, default=1e-5
    - max_inner_iter: int, default=10
    - inner_tol: float, default=1e-6
    - log: bool, default=False
    - verbose: bool, default=False
    - verbose2: bool, default=False
    """

class ot.da.UnbalancedSinkhornTransport(BaseTransport):
    """
    Unbalanced Sinkhorn transport for domain adaptation.

    Handles domain adaptation with different marginal distributions
    using unbalanced optimal transport.

    Parameters:
    - reg_e: float, default=1.0
      Entropic regularization.
    - reg_m: float, default=1.0
      Marginal relaxation parameter.
    - method: str, default='sinkhorn'
      Unbalanced algorithm variant.
    - max_iter: int, default=1000
    - tol: float, default=1e-9
    - verbose: bool, default=False
    - log: bool, default=False
    """

class ot.da.JCPOTTransport(BaseTransport):
    """
    Joint Class Proportion and Optimal Transport (JCPOT) for multi-source adaptation.

    Handles multiple source domains simultaneously using joint optimal
    transport formulation.

    Parameters:
    - reg_e: float, default=1.0
      Entropic regularization.
    - max_iter: int, default=10
    - tol: float, default=1e-6
    - verbose: bool, default=False
    - log: bool, default=False
    - metric: str, default='sqeuclidean'
    """

class ot.da.NearestBrenierPotential(BaseTransport):
    """
    Transport using nearest Brenier potential approximation.

    Learns optimal transport maps through Brenier potential estimation
    for smooth and invertible domain adaptation.

    Parameters:
    - reg: float, default=1e-3
      Regularization for potential learning.
    - max_iter: int, default=100
    - tol: float, default=1e-6
    """
```

## Usage Examples

### Basic Domain Adaptation
```python
import ot
import numpy as np
from sklearn.datasets import make_classification

# Generate source and target domains
n_source, n_target = 150, 100
n_features = 2

# Source domain
Xs, ys = make_classification(n_samples=n_source, n_features=n_features,
                             n_redundant=0, n_informative=2,
                             random_state=1, n_clusters_per_class=1)

# Target domain (shifted and rotated)
Xt, yt = make_classification(n_samples=n_target, n_features=n_features,
                             n_redundant=0, n_informative=2,
                             random_state=42, n_clusters_per_class=1)

# Apply domain shift
angle = np.pi / 6
rotation = np.array([[np.cos(angle), -np.sin(angle)],
                     [np.sin(angle), np.cos(angle)]])
Xt = Xt @ rotation + [1, 1]

print(f"Source domain shape: {Xs.shape}")
print(f"Target domain shape: {Xt.shape}")
```

### Sinkhorn Transport Adaptation
```python
# Initialize Sinkhorn transport
sinkhorn_adapter = ot.da.SinkhornTransport(reg_e=0.1, verbose=True)

# Fit the transport
sinkhorn_adapter.fit(Xs=Xs, Xt=Xt)

# Transform source samples to target domain
Xs_adapted = sinkhorn_adapter.transform(Xs=Xs)

print("Adaptation completed")
print(f"Adapted source shape: {Xs_adapted.shape}")
# The transport cost weights the coupling by the ground cost;
# coupling_.sum() alone is only the total transported mass (about 1)
transport_cost = np.sum(sinkhorn_adapter.coupling_ * sinkhorn_adapter.cost_)
print(f"Transport cost: {transport_cost}")
```

### Label-Regularized Adaptation
```python
# Use source labels for better adaptation
label_adapter = ot.da.SinkhornLpl1Transport(
    reg_e=0.1, reg_cl=0.1, verbose=True
)

# Fit with source labels
label_adapter.fit(Xs=Xs, ys=ys, Xt=Xt)

# Transform and propagate labels
Xs_label_adapted = label_adapter.transform(Xs=Xs)
yt_predicted = label_adapter.transform_labels(ys=ys)

print(f"Label-adapted source shape: {Xs_label_adapted.shape}")
print(f"Predicted target labels shape: {yt_predicted.shape}")
```

### Multi-Method Comparison
```python
# Compare different adaptation methods
methods = {
    'EMD': ot.da.EMDTransport(),
    'Sinkhorn': ot.da.SinkhornTransport(reg_e=0.1),
    'Linear': ot.da.LinearTransport(),
    'Unbalanced': ot.da.UnbalancedSinkhornTransport(reg_e=0.1, reg_m=1.0)
}

adapted_sources = {}

for name, method in methods.items():
    print(f"\nFitting {name} transport...")
    method.fit(Xs=Xs, Xt=Xt)
    adapted_sources[name] = method.transform(Xs=Xs)

    # Compute adaptation quality (distance to target centroid)
    target_center = np.mean(Xt, axis=0)
    adapted_center = np.mean(adapted_sources[name], axis=0)
    distance = np.linalg.norm(target_center - adapted_center)
    print(f"{name} - Distance to target center: {distance:.4f}")
```

### Out-of-Sample Adaptation
```python
# Generate new source samples for out-of-sample testing
Xs_new = np.random.multivariate_normal([0, 0], [[1, 0], [0, 1]], 50)

# Adapt new samples using trained transport
sinkhorn_adapter = ot.da.SinkhornTransport(reg_e=0.1)
sinkhorn_adapter.fit(Xs=Xs, Xt=Xt)

# Transform new samples
Xs_new_adapted = sinkhorn_adapter.transform(Xs=Xs_new)

print(f"New source samples: {Xs_new.shape}")
print(f"Adapted new samples: {Xs_new_adapted.shape}")
```
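
The out-of-sample step above relies on the `out_of_sample_map='ferradans'` option. One common reading of that strategy is: find the nearest source sample seen during fitting, then apply that sample's displacement to the new point. The NumPy sketch below illustrates the idea under that assumption, with the fitted samples and their mapped positions supplied by hand; `map_new_samples` is a hypothetical helper, not library API.

```python
import numpy as np

def map_new_samples(Xs_new, Xs_fit, Xs_fit_mapped):
    """Shift each new sample by the displacement of its nearest fitted source sample."""
    d2 = ((Xs_new[:, None, :] - Xs_fit[None, :, :]) ** 2).sum(axis=-1)
    nearest = d2.argmin(axis=1)
    displacement = Xs_fit_mapped[nearest] - Xs_fit[nearest]
    return Xs_new + displacement

# Two fitted source samples and their (assumed) images under the learned transport
Xs_fit = np.array([[0.0, 0.0], [10.0, 0.0]])
Xs_fit_mapped = np.array([[1.0, 1.0], [10.0, 5.0]])
Xs_new = np.array([[0.5, 0.0], [9.0, 1.0], [0.0, 0.0]])

Xs_new_mapped = map_new_samples(Xs_new, Xs_fit, Xs_fit_mapped)
print(Xs_new_mapped)
```

A new point that coincides with a fitted sample lands exactly on that sample's image, while nearby points keep their offset from it.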

### JCPOT Multi-Source Adaptation
```python
# Create multiple source domains
n_sources = 3
source_domains = []
source_labels = []

for i in range(n_sources):
    # n_informative + n_redundant must not exceed n_features,
    # so the defaults (2 + 2) fail when n_features=2
    Xs_i, ys_i = make_classification(n_samples=100, n_features=2,
                                     n_redundant=0, n_informative=2,
                                     random_state=i, n_clusters_per_class=1)
    # Apply different shifts to each source
    Xs_i = Xs_i + [i*0.5, i*0.3]
    source_domains.append(Xs_i)
    source_labels.append(ys_i)

# JCPOT adaptation
jcpot_adapter = ot.da.JCPOTTransport(reg_e=0.1, verbose=True)

# Fit multiple sources to single target
jcpot_adapter.fit(Xs=source_domains, ys=source_labels, Xt=Xt, yt=yt)

print("JCPOT multi-source adaptation completed")
```

### Advanced Mapping Transport
```python
# Use mapping transport with a Gaussian (RBF) kernel
mapping_adapter = ot.da.MappingTransport(
    kernel='gaussian', sigma=1.0, mu=1e-1, verbose=True
)

mapping_adapter.fit(Xs=Xs, Xt=Xt)
Xs_mapped = mapping_adapter.transform(Xs=Xs)

print("Mapping transport with Gaussian kernel completed")

# The learned mapping can be applied to new data
Xs_new_mapped = mapping_adapter.transform(Xs=Xs_new)
print(f"New samples mapped: {Xs_new_mapped.shape}")
```

### Performance Evaluation
```python
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

# Train classifier on adapted source data
knn = KNeighborsClassifier(n_neighbors=3)

# Test different adaptations
results = {}

for name, Xs_adapted in adapted_sources.items():
    # Train on adapted source
    knn.fit(Xs_adapted, ys)

    # Predict on target (when labels available)
    if len(np.unique(yt)) > 1:  # Check if target has multiple classes
        yt_pred = knn.predict(Xt)
        accuracy = accuracy_score(yt, yt_pred)
        results[name] = accuracy
        print(f"{name} adaptation accuracy: {accuracy:.3f}")

# Baseline: no adaptation
knn.fit(Xs, ys)
if len(np.unique(yt)) > 1:
    yt_pred_baseline = knn.predict(Xt)
    baseline_acc = accuracy_score(yt, yt_pred_baseline)
    print(f"No adaptation accuracy: {baseline_acc:.3f}")
```

## Applications

### Computer Vision
- **Cross-dataset adaptation**: Adapting models trained on one image dataset to another
- **Domain shift**: Handling changes in lighting, camera, or image style
- **Synthetic-to-real**: Adapting from synthetic training data to real images

### Natural Language Processing
- **Cross-lingual adaptation**: Transferring models between languages
- **Domain-specific text**: Adapting from general to domain-specific corpora
- **Temporal adaptation**: Handling language evolution over time

### Biomedical Applications
- **Cross-study adaptation**: Adapting between different clinical studies
- **Multi-site data**: Handling batch effects across research sites
- **Cross-species**: Transferring knowledge between related organisms

The `ot.da` module brings exact, entropic, label-regularized, and mapping-based transport methods together behind a common `fit`/`transform` interface for bridging distribution gaps between source and target domains.