tessl/pypi-pot

Python Optimal Transport Library providing solvers for optimization problems related to Optimal Transport for signal, image processing and machine learning

Domain Adaptation

The ot.da module provides transport-based methods for domain adaptation in machine learning. These algorithms learn mappings between different domains (e.g., training and test distributions) by leveraging optimal transport theory, enabling knowledge transfer when source and target domains differ.

Core Domain Adaptation Functions

Label-Regularized Transport

def ot.da.sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerItermax=200, stopInnerThr=1e-9, verbose=False, log=False):
    """
    Solve optimal transport with L1 label regularization using MM algorithm.
    
    Incorporates label information from the source domain to guide the transport
    by penalizing transport between samples with different labels. Uses the
    Majorization-Minimization (MM) algorithmic framework.
    
    Parameters:
    - a: array-like, shape (n_samples_source,)
         Source domain distribution (weights of source samples).
    - labels_a: array-like, shape (n_samples_source,)
         Labels of source domain samples (integer class labels).
    - b: array-like, shape (n_samples_target,)
         Target domain distribution (weights of target samples).
    - M: array-like, shape (n_samples_source, n_samples_target)
         Ground cost matrix between source and target samples.
    - reg: float
         Entropic regularization parameter for Sinkhorn algorithm.
    - eta: float, default=0.1
         Label regularization parameter. Higher values enforce stronger
         alignment between samples of the same class.
    - numItermax: int, default=10
         Maximum number of outer MM iterations.
    - numInnerItermax: int, default=200
         Maximum iterations for inner Sinkhorn algorithm.
    - stopInnerThr: float, default=1e-9
         Convergence threshold for inner Sinkhorn iterations.
    - verbose: bool, default=False
         Print iteration information.
    - log: bool, default=False
         Return optimization log with convergence details.
    
    Returns:
    - transport_plan: ndarray, shape (n_samples_source, n_samples_target)
         Optimal transport plan with label regularization.
    - log: dict (if log=True)
         Contains 'err' (convergence error per outer iteration) and
         'all_err' (full error history).
    """

def ot.da.sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, alpha=0.98, numItermax=10, numInnerItermax=200, stopInnerThr=1e-9, verbose=False, log=False):
    """
    Solve optimal transport with L1-L2 group lasso regularization.
    
    Combines L1 sparsity regularization with L2 group lasso to encourage
    both sparsity and grouping in the transport plan according to class labels.
    
    Parameters:
    - a: array-like, shape (n_samples_source,)
         Source distribution.
    - labels_a: array-like, shape (n_samples_source,)
         Source labels.
    - b: array-like, shape (n_samples_target,)
         Target distribution.
    - M: array-like, shape (n_samples_source, n_samples_target)
         Cost matrix.
    - reg: float
         Entropic regularization parameter.
    - eta: float, default=0.1
         L1 regularization parameter.
    - alpha: float, default=0.98
         Trade-off between L1 and L2 regularization (elastic net parameter).
    - numItermax: int, default=10
         Maximum outer iterations.
    - numInnerItermax: int, default=200
         Maximum inner iterations.
    - stopInnerThr: float, default=1e-9
         Inner convergence threshold.
    - verbose: bool, default=False
    - log: bool, default=False
    
    Returns:
    - transport_plan: ndarray
         L1-L2 regularized transport plan.
    - log: dict (if log=True)
    """

def ot.da.emd_laplace(a, labels_a, b, M, eta=0.1, numItermax=10, verbose=False, log=False):
    """
    Solve optimal transport with Laplacian regularization.
    
    Uses Laplacian regularization to enforce smooth transport plans that
    respect the local structure of the data manifold.
    
    Parameters:
    - a: array-like, shape (n_samples_source,)
         Source distribution.
    - labels_a: array-like, shape (n_samples_source,)
         Source labels for constructing Laplacian.
    - b: array-like, shape (n_samples_target,)
         Target distribution.
    - M: array-like, shape (n_samples_source, n_samples_target)
         Cost matrix.
    - eta: float, default=0.1
         Laplacian regularization parameter.
    - numItermax: int, default=10
         Maximum iterations.
    - verbose: bool, default=False
    - log: bool, default=False
    
    Returns:
    - transport_plan: ndarray
         Laplacian-regularized transport plan.
    - log: dict (if log=True)
    """

def ot.da.distribution_estimation_uniform(X):
    """
    Estimate uniform distribution over samples.
    
    Simple utility to create uniform weights for samples when no
    prior distribution information is available.
    
    Parameters:
    - X: array-like, shape (n_samples, n_features)
         Input samples.
    
    Returns:
    - distribution: ndarray, shape (n_samples,)
         Uniform distribution (each entry equals 1/n_samples).
    """

Transport Classes for Domain Adaptation

Base Transport Class

class ot.da.BaseTransport:
    """
    Base class for optimal transport-based domain adaptation.
    
    Provides common interface and functionality for all transport-based
    domain adaptation methods.
    
    Parameters:
    - log: bool, default=False
         Whether to store optimization logs.
    - verbose: bool, default=False
         Print information during fitting.
    - out_of_sample_map: str, default='ferradans'
         Out-of-sample mapping method for new data points.
    """
    
    def fit(self, Xs=None, Xt=None, ys=None, yt=None):
        """
        Build a coupling matrix from source and target sets.
        
        Parameters:
        - Xs: array-like, shape (n_source_samples, n_features)
             Source domain samples.
        - Xt: array-like, shape (n_target_samples, n_features)
             Target domain samples.
        - ys: array-like, shape (n_source_samples,), optional
             Source domain labels.
        - yt: array-like, shape (n_target_samples,), optional
             Target domain labels.
        
        Returns:
        - self: BaseTransport instance
        """
    
    def transform(self, Xs=None, Xt=None, ys=None, yt=None, batch_size=128):
        """
        Transform source samples to target domain.
        
        Parameters:
        - Xs: array-like, shape (n_samples, n_features), optional
             Source samples to transform.
        - Xt: array-like, shape (n_samples, n_features), optional  
             Target samples to inverse transform.
        - ys: array-like, optional
             Source labels.
        - yt: array-like, optional
             Target labels.
        - batch_size: int, default=128
             Batch size for large-scale transformations.
        
        Returns:
        - transformed_samples: ndarray
             Samples transformed to target domain.
        """
    
    def transform_labels(self, ys=None):
        """
        Propagate source labels to target domain.
        
        Parameters:
        - ys: array-like, shape (n_source_samples,)
             Source labels to propagate.
        
        Returns:
        - target_labels: ndarray, shape (n_target_samples,)
             Labels assigned to target samples.
        """
    
    def inverse_transform(self, Xs=None, Xt=None, ys=None, yt=None, batch_size=128):
        """
        Transform target samples to source domain.
        
        Parameters: Similar to transform()
        
        Returns:
        - inverse_transformed: ndarray
             Target samples transformed to source domain.
        """

Linear Transport Methods

class ot.da.LinearTransport(BaseTransport):
    """
    Linear optimal transport for domain adaptation.
    
    Learns a linear transformation matrix for mapping between domains
    based on optimal transport theory.
    
    Parameters:
    - reg: float, default=1e-8
         Regularization parameter for matrix inversion.
    - bias: bool, default=False
         Whether to estimate bias term.
    - log: bool, default=False
    - verbose: bool, default=False
    """

class ot.da.LinearGWTransport(BaseTransport):
    """
    Linear Gromov-Wasserstein transport for domain adaptation.
    
    Uses Gromov-Wasserstein distance to handle domains with different
    feature spaces by comparing internal structure rather than features directly.
    
    Parameters:
    - reg: float, default=1e-8
         Regularization parameter.
    - alpha: float, default=0.5
         GW optimization step size.
    - max_iter: int, default=100
         Maximum GW iterations.
    - tol: float, default=1e-6
         GW convergence tolerance.
    """

Sinkhorn-based Transport

class ot.da.SinkhornTransport(BaseTransport):
    """
    Sinkhorn transport for domain adaptation.
    
    Uses entropic regularization and Sinkhorn algorithm for efficient
    computation of transport plans between domains.
    
    Parameters:
    - reg_e: float, default=1.0
         Entropic regularization parameter.
    - max_iter: int, default=1000
         Maximum Sinkhorn iterations.
    - tol: float, default=1e-9
         Sinkhorn convergence tolerance.
    - verbose: bool, default=False
    - log: bool, default=False
    - metric: str, default='sqeuclidean'
         Ground metric for cost matrix computation.
    - norm: str, optional
         Cost matrix normalization method.
    - distribution_estimation: callable, default=distribution_estimation_uniform
         Method for estimating sample distributions.
    - out_of_sample_map: str, default='ferradans'
         Out-of-sample mapping technique.
    - limit_max: float, default=np.inf
         Maximum value for cost matrix entries.
    """

class ot.da.EMDTransport(BaseTransport):
    """
    Exact EMD transport for domain adaptation.
    
    Uses exact optimal transport (Earth Mover's Distance) without
    regularization for precise domain adaptation.
    
    Parameters:
    - metric: str, default='sqeuclidean'
         Ground metric for cost computation.
    - norm: str, optional
         Cost normalization method.
    - log: bool, default=False
    - verbose: bool, default=False
    - distribution_estimation: callable, default=distribution_estimation_uniform
    - out_of_sample_map: str, default='ferradans'
    - limit_max: float, default=np.inf
         Cost matrix entry limit.
    """

Label-Regularized Transport Classes

class ot.da.SinkhornLpl1Transport(BaseTransport):
    """
    Sinkhorn transport with L1 label regularization.
    
    Incorporates source domain labels to guide transport using L1 penalty
    on cross-class transport.
    
    Parameters:
    - reg_e: float, default=1.0
         Entropic regularization.
    - reg_cl: float, default=0.1
         Label regularization parameter.
    - max_iter: int, default=10
         Maximum outer iterations.
    - max_inner_iter: int, default=200
         Maximum inner Sinkhorn iterations.
    - log: bool, default=False
    - verbose: bool, default=False
    - metric: str, default='sqeuclidean'
    """

class ot.da.SinkhornL1l2Transport(BaseTransport):
    """
    Sinkhorn transport with L1-L2 group lasso regularization.
    
    Combines L1 sparsity with L2 group penalties for structured
    domain adaptation.
    
    Parameters:
    - reg_e: float, default=1.0
         Entropic regularization.
    - reg_cl: float, default=0.1
         L1 regularization.
    - reg_l: float, default=0.1
         L2 group regularization.
    - max_iter: int, default=10
    - max_inner_iter: int, default=200
    - tol: float, default=1e-9
    """

class ot.da.EMDLaplaceTransport(BaseTransport):
    """
    EMD transport with Laplacian regularization.
    
    Uses Laplacian penalty to ensure smooth transport respecting
    data manifold structure.
    
    Parameters:
    - reg_lap: float, default=1.0
         Laplacian regularization parameter.
    - reg_src: float, default=0.5
         Source regularization.
    - metric: str, default='sqeuclidean'
    - norm: str, optional
    - similarity: str, default='knn'
         Method for similarity matrix construction.
    - similarity_param: int, default=7
         Parameter for similarity computation (e.g., k for knn).
    - max_iter: int, default=10
    """

Advanced Transport Methods

class ot.da.MappingTransport(BaseTransport):
    """
    Optimal transport with learned mappings.
    
    Learns parametric mappings (linear or kernel-based) that approximate
    the optimal transport map.
    
    Parameters:
    - mu: float, default=1e0
         Regularization parameter for mapping learning.
    - eta: float, default=1e-8
         Numerical regularization.
    - bias: bool, default=True
         Whether to learn bias terms.
    - metric: str, default='sqeuclidean'
    - norm: str, optional
    - kernel: str, default='linear'
         Kernel type: 'linear' or 'gaussian' (the Gaussian/RBF kernel).
    - sigma: float, default=1.0
         Kernel bandwidth (for Gaussian/RBF kernels).
    - max_iter: int, default=100
    - tol: float, default=1e-5
    - max_inner_iter: int, default=10
    - inner_tol: float, default=1e-6
    - log: bool, default=False
    - verbose: bool, default=False
    - verbose2: bool, default=False
    """

class ot.da.UnbalancedSinkhornTransport(BaseTransport):
    """
    Unbalanced Sinkhorn transport for domain adaptation.
    
    Handles domain adaptation with different marginal distributions
    using unbalanced optimal transport.
    
    Parameters:
    - reg_e: float, default=1.0
         Entropic regularization.
    - reg_m: float, default=1.0
         Marginal relaxation parameter.
    - method: str, default='sinkhorn'
         Unbalanced algorithm variant.
    - max_iter: int, default=1000
    - tol: float, default=1e-9
    - verbose: bool, default=False
    - log: bool, default=False
    """

class ot.da.JCPOTTransport(BaseTransport):
    """
    Joint Class Proportion and Optimal Transport (JCPOT) for multi-source adaptation.
    
    Handles multiple source domains simultaneously using joint optimal
    transport formulation.
    
    Parameters:
    - reg_e: float, default=1.0
         Entropic regularization.
    - max_iter: int, default=10
    - tol: float, default=1e-6
    - verbose: bool, default=False
    - log: bool, default=False
    - metric: str, default='sqeuclidean'
    """

class ot.da.NearestBrenierPotential(BaseTransport):
    """
    Transport using nearest Brenier potential approximation.
    
    Learns optimal transport maps through Brenier potential estimation
    for smooth and invertible domain adaptation.
    
    Parameters:
    - reg: float, default=1e-3
         Regularization for potential learning.
    - max_iter: int, default=100
    - tol: float, default=1e-6
    """

Usage Examples

Basic Domain Adaptation

import ot
import numpy as np
from sklearn.datasets import make_classification

# Generate source and target domains
n_source, n_target = 150, 100
n_features = 2

# Source domain
Xs, ys = make_classification(n_samples=n_source, n_features=n_features,
                             n_redundant=0, n_informative=2,
                             random_state=1, n_clusters_per_class=1)

# Target domain (shifted and rotated)
Xt, yt = make_classification(n_samples=n_target, n_features=n_features,
                             n_redundant=0, n_informative=2,
                             random_state=42, n_clusters_per_class=1)

# Apply domain shift
angle = np.pi / 6
rotation = np.array([[np.cos(angle), -np.sin(angle)],
                     [np.sin(angle),  np.cos(angle)]])
Xt = Xt @ rotation + [1, 1]

print(f"Source domain shape: {Xs.shape}")
print(f"Target domain shape: {Xt.shape}")

Sinkhorn Transport Adaptation

# Initialize Sinkhorn transport
sinkhorn_adapter = ot.da.SinkhornTransport(reg_e=0.1, verbose=True)

# Fit the transport
sinkhorn_adapter.fit(Xs=Xs, Xt=Xt)

# Transform source samples to target domain
Xs_adapted = sinkhorn_adapter.transform(Xs=Xs)

print("Adaptation completed")
print(f"Adapted source shape: {Xs_adapted.shape}")
print(f"Transport cost: {sinkhorn_adapter.coupling_.sum()}")

Label-Regularized Adaptation

# Use source labels for better adaptation
label_adapter = ot.da.SinkhornLpl1Transport(
    reg_e=0.1, reg_cl=0.1, verbose=True
)

# Fit with source labels
label_adapter.fit(Xs=Xs, ys=ys, Xt=Xt)

# Transform and propagate labels
Xs_label_adapted = label_adapter.transform(Xs=Xs)
yt_predicted = label_adapter.transform_labels(ys=ys)

print(f"Label-adapted source shape: {Xs_label_adapted.shape}")
print(f"Predicted target labels shape: {yt_predicted.shape}")

Multi-Method Comparison

# Compare different adaptation methods
methods = {
    'EMD': ot.da.EMDTransport(),
    'Sinkhorn': ot.da.SinkhornTransport(reg_e=0.1),
    'Linear': ot.da.LinearTransport(),
    'Unbalanced': ot.da.UnbalancedSinkhornTransport(reg_e=0.1, reg_m=1.0)
}

adapted_sources = {}

for name, method in methods.items():
    print(f"\nFitting {name} transport...")
    method.fit(Xs=Xs, Xt=Xt)
    adapted_sources[name] = method.transform(Xs=Xs)
    
    # Compute adaptation quality (distance to target centroid)
    target_center = np.mean(Xt, axis=0)
    adapted_center = np.mean(adapted_sources[name], axis=0)
    distance = np.linalg.norm(target_center - adapted_center)
    print(f"{name} - Distance to target center: {distance:.4f}")

Out-of-Sample Adaptation

# Generate new source samples for out-of-sample testing
Xs_new = np.random.multivariate_normal([0, 0], [[1, 0], [0, 1]], 50)

# Adapt new samples using trained transport
sinkhorn_adapter = ot.da.SinkhornTransport(reg_e=0.1)
sinkhorn_adapter.fit(Xs=Xs, Xt=Xt)

# Transform new samples
Xs_new_adapted = sinkhorn_adapter.transform(Xs=Xs_new)

print(f"New source samples: {Xs_new.shape}")
print(f"Adapted new samples: {Xs_new_adapted.shape}")

JCPOT Multi-Source Adaptation

# Create multiple source domains
n_sources = 3
source_domains = []
source_labels = []

for i in range(n_sources):
    Xs_i, ys_i = make_classification(n_samples=100, n_features=2,
                                     random_state=i, n_clusters_per_class=1)
    # Apply different shifts to each source
    Xs_i = Xs_i + [i*0.5, i*0.3]
    source_domains.append(Xs_i)
    source_labels.append(ys_i)

# JCPOT adaptation
jcpot_adapter = ot.da.JCPOTTransport(reg_e=0.1, verbose=True)

# Fit multiple sources to single target
jcpot_adapter.fit(Xs=source_domains, ys=source_labels, Xt=Xt, yt=yt)

print("JCPOT multi-source adaptation completed")

Advanced Mapping Transport

# Use mapping transport with a Gaussian (RBF) kernel
mapping_adapter = ot.da.MappingTransport(
    kernel='gaussian', sigma=1.0, mu=1e-1, verbose=True
)

mapping_adapter.fit(Xs=Xs, Xt=Xt)
Xs_mapped = mapping_adapter.transform(Xs=Xs)

print("Mapping transport with RBF kernel completed")

# The learned mapping can be applied to new data
Xs_new_mapped = mapping_adapter.transform(Xs=Xs_new)
print(f"New samples mapped: {Xs_new_mapped.shape}")

Performance Evaluation

from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

# Train classifier on adapted source data
knn = KNeighborsClassifier(n_neighbors=3)

# Test different adaptations
results = {}

for name, Xs_adapted in adapted_sources.items():
    # Train on adapted source
    knn.fit(Xs_adapted, ys)
    
    # Predict on target (when labels available)
    if len(np.unique(yt)) > 1:  # Check if target has multiple classes
        yt_pred = knn.predict(Xt)
        accuracy = accuracy_score(yt, yt_pred)
        results[name] = accuracy
        print(f"{name} adaptation accuracy: {accuracy:.3f}")

# Baseline: no adaptation
knn.fit(Xs, ys)
if len(np.unique(yt)) > 1:
    yt_pred_baseline = knn.predict(Xt)
    baseline_acc = accuracy_score(yt, yt_pred_baseline)
    print(f"No adaptation accuracy: {baseline_acc:.3f}")

Applications

Computer Vision

  • Cross-dataset adaptation: Adapting models trained on one image dataset to another
  • Domain shift: Handling changes in lighting, camera, or image style
  • Synthetic-to-real: Adapting from synthetic training data to real images

Natural Language Processing

  • Cross-lingual adaptation: Transferring models between languages
  • Domain-specific text: Adapting from general to domain-specific corpora
  • Temporal adaptation: Handling language evolution over time

Biomedical Applications

  • Cross-study adaptation: Adapting between different clinical studies
  • Multi-site data: Handling batch effects across research sites
  • Cross-species: Transferring knowledge between related organisms

The ot.da module provides comprehensive tools for transport-based domain adaptation, offering both theoretical rigor and practical effectiveness for bridging distribution gaps in machine learning applications.
