
tessl/pypi-pot

Python Optimal Transport Library providing solvers for optimization problems related to Optimal Transport for signal, image processing and machine learning


Gromov-Wasserstein Distances

The ot.gromov module provides algorithms for computing Gromov-Wasserstein (GW) distances and their variants. These methods enable optimal transport between structured data by comparing the internal geometry of metric spaces rather than requiring a common embedding space.

Core Gromov-Wasserstein Functions

Basic GW Distance Computation

def ot.gromov.gromov_wasserstein(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    """
    Compute Gromov-Wasserstein distance between two metric spaces.
    
    Solves a non-convex quadratic program (a relaxation of the quadratic
    assignment problem) with a conditional gradient (Frank-Wolfe) solver,
    looking for a coupling that best preserves pairwise distances across
    the two spaces.
    
    Parameters:
    - C1: array-like, shape (n1, n1)
         Intra-structure cost matrix for source space (distances/similarities).
    - C2: array-like, shape (n2, n2)
         Intra-structure cost matrix for target space.
    - p: array-like, shape (n1,), optional
         Distribution over source space: non-negative weights summing to 1.
         Uniform if None.
    - q: array-like, shape (n2,), optional
         Distribution over target space: non-negative weights summing to 1.
         Uniform if None.
    - loss_fun: str, default='square_loss'
         Loss function for structure preservation: 'square_loss' or 'kl_loss'.
    - symmetric: bool, optional
         Whether C1 and C2 are assumed symmetric. If None, a symmetry test is run.
    - log: bool, default=False
         Return optimization log with convergence details.
    - armijo: bool, default=False
         Use Armijo line search instead of the exact line search.
    - G0: array-like, shape (n1, n2), optional
         Initial transport plan. Defaults to the product coupling p q^T.
    - max_iter: int, default=1e4
         Maximum number of iterations.
    - tol_rel: float, default=1e-9
         Relative tolerance for convergence.
    - tol_abs: float, default=1e-9
         Absolute tolerance for convergence.
    
    Returns:
    - transport_plan: ndarray, shape (n1, n2)
         Optimal GW transport plan between the two spaces.
    - log: dict (if log=True)
         Contains 'gw_dist' (GW loss at the returned plan) and the
         per-iteration 'loss' values recorded by the solver.
    """

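def ot.gromov.gromov_wasserstein2(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    """
    Compute the Gromov-Wasserstein loss (objective value only).
    
    Runs the same solver as gromov_wasserstein() but returns the loss value
    instead of the plan. With an autodiff backend (e.g. PyTorch), the value
    is differentiable with respect to C1, C2, p and q.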
    
    Parameters: Same as gromov_wasserstein()
    
    Returns:
    - gw_distance: float
         Gromov-Wasserstein distance between the two spaces.
    - log: dict (if log=True)
    """

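def ot.gromov.solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg, alpha_min=None, alpha_max=None, nx=None, symmetric=False, **kwargs):
    """
    Exact line search used inside the conditional gradient (Frank-Wolfe)
    GW and FGW solvers.
    
    Along the direction deltaG the objective is quadratic in the step size,
    so the optimal step in [0, 1] has a closed form.
    
    Parameters:
    - G: array-like, shape (n1, n2)
         Transport plan at the current iteration.
    - deltaG: array-like, shape (n1, n2)
         Search direction: difference between the plan found by the
         linearized problem and G.
    - cost_G: float
         Objective value at G.
    - C1, C2: array-like
         Transformed structure matrices (hC1, hC2 from init_matrix()
         for 'square_loss' and 'kl_loss').
    - M: array-like, shape (n1, n2)
         Feature cost matrix.
    - reg: float
         Weight of the structure (GW) term.
    - alpha_min: float, optional
         Minimum step size.
    - alpha_max: float, optional
         Maximum step size.
    - nx: backend, optional
    - symmetric: bool, default=False
         Whether C1 and C2 are symmetric.
    
    Returns:
    - alpha: float
         Optimal step size.
    - fc: int
         Number of function calls.
    - cost_G: float
         Objective value at the next iterate.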
    """

Fused Gromov-Wasserstein

def ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    """
    Compute Fused Gromov-Wasserstein distance combining structure and features.
    
    Combines standard optimal transport (based on feature cost M) with 
    Gromov-Wasserstein transport (based on structure costs C1, C2).
    
    Parameters:
    - M: array-like, shape (n1, n2)
         Feature cost matrix between source and target samples.
    - C1: array-like, shape (n1, n1)
         Intra-structure cost matrix for source space.
    - C2: array-like, shape (n2, n2)
         Intra-structure cost matrix for target space.
    - p: array-like, shape (n1,)
         Source distribution.
    - q: array-like, shape (n2,)
         Target distribution.
    - loss_fun: str or callable, default='square_loss'
         Loss function for structure preservation.
    - alpha: float, default=0.5
         Trade-off parameter between structure (α) and features (1-α).
         α=1 gives pure GW, α=0 gives pure Wasserstein.
    - armijo: bool, default=False
         Use Armijo line search.
    - log: bool, default=False
    - max_iter: int, default=1000
    - tol_rel: float, default=1e-9
    - tol_abs: float, default=1e-9
    
    Returns:
    - transport_plan: ndarray, shape (n1, n2)
         Optimal FGW transport plan.
    - log: dict (if log=True)
    """

def ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    """
    Compute Fused Gromov-Wasserstein squared distance (cost only).
    
    Parameters: Same as fused_gromov_wasserstein()
    
    Returns:
    - fgw_distance: float
    - log: dict (if log=True)
    """

Barycenter Algorithms

def ot.gromov.gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun='square_loss', max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None, **kwargs):
    """
    Compute Gromov-Wasserstein barycenter of multiple metric spaces.
    
    Finds the barycenter structure that minimizes the weighted sum of GW
    losses to all input spaces. The barycenter distribution p is fixed;
    only the structure matrix is optimized.
    
    Parameters:
    - N: int
         Size of the barycenter space (number of points).
    - Cs: list of arrays
         List of intra-structure cost matrices for input spaces.
         Each Cs[i] has shape (ni, ni).
    - ps: list of arrays
         List of distributions for input spaces.
         Each ps[i] has shape (ni,).
    - p: array-like, shape (N,)
         Distribution for the barycenter space.
    - lambdas: array-like, shape (n_spaces,)
         Weights for combining input spaces in barycenter.
    - loss_fun: str or callable, default='square_loss'
         Loss function for GW computation.
    - max_iter: int, default=1000
         Maximum iterations for barycenter optimization.
    - tol: float, default=1e-9
         Convergence tolerance.
    - verbose: bool, default=False
         Print optimization information.
    - log: bool, default=False
         Return optimization log.
    - init_C: array-like, shape (N, N), optional
         Initial barycenter structure matrix. Random if None.
    - random_state: int, optional
         Random seed for reproducible initialization.
    
    Returns:
    - barycenter_structure: ndarray, shape (N, N)
         Optimal barycenter intra-structure cost matrix.
    - log: dict (if log=True)
         Contains convergence information and transport plans.
    """

def ot.gromov.fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False, p=None, loss_fun='square_loss', max_iter=100, tol=1e-9, verbose=False, log=False, init_C=None, init_X=None, random_state=None, **kwargs):
    """
    Compute Fused Gromov-Wasserstein barycenter with features and structure.
    
    Parameters:
    - N: int
         Barycenter size.
    - Ys: list of arrays
         List of feature matrices for input spaces.
    - Cs: list of arrays  
         List of structure matrices for input spaces.
    - ps: list of arrays
         List of distributions for input spaces.
    - lambdas: array-like
         Weights for space combination.
    - alpha: float
         Trade-off between structure and features.
    - fixed_structure: bool, default=False
         Whether to fix the barycenter structure.
    - fixed_features: bool, default=False
         Whether to fix the barycenter features.
    - p: array-like, optional
         Barycenter distribution.
    - loss_fun: str or callable, default='square_loss'
    - max_iter: int, default=100
    - tol: float, default=1e-9
    - verbose: bool, default=False
    - log: bool, default=False
    - init_C: array-like, optional
         Initial barycenter structure.
    - init_X: array-like, optional
         Initial barycenter features.
    - random_state: int, optional
    
    Returns:
    - barycenter_features: ndarray, shape (N, d)
    - barycenter_structure: ndarray, shape (N, N)  
    - log: dict (if log=True)
    """

Entropic Regularized Methods

def ot.gromov.entropic_gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', epsilon=0.1, symmetric=None, G0=None, max_iter=1000, tol=1e-9, verbose=False, log=False, **kwargs):
    """
    Compute entropic regularized Gromov-Wasserstein distance.
    
    Adds an entropic penalty (weight epsilon) to the GW objective. In the
    projected-gradient scheme, each iteration solves a Sinkhorn problem,
    which gives a smoother objective and a dense transport plan.
    
    Parameters:
    - C1: array-like, shape (n1, n1)
         Source structure matrix.
    - C2: array-like, shape (n2, n2)
         Target structure matrix.
    - p: array-like, shape (n1,)
         Source distribution.
    - q: array-like, shape (n2,)
         Target distribution.
    - loss_fun: str or callable, default='square_loss'
    - epsilon: float, default=0.1
         Entropic regularization parameter.
    - symmetric: bool, optional
         Whether C1 and C2 are assumed symmetric. If None, a symmetry test is run.
    - G0: array-like, optional
         Initial transport plan.
    - max_iter: int, default=1000
    - tol: float, default=1e-9
    - verbose: bool, default=False
    - log: bool, default=False
    
    Returns:
    - transport_plan: ndarray, shape (n1, n2)
    - log: dict (if log=True)
    """

def ot.gromov.entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', epsilon=0.1, symmetric=None, G0=None, max_iter=1000, tol=1e-9, verbose=False, log=False, **kwargs):
    """
    Compute entropic regularized GW distance (cost only).
    
    Parameters: Same as entropic_gromov_wasserstein()
    
    Returns:
    - gw_distance: float
    - log: dict (if log=True)
    """

def ot.gromov.entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun='square_loss', epsilon=0.1, symmetric=True, max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None):
    """
    Compute entropic regularized GW barycenters.
    
    Parameters:
    - N: int
         Barycenter size.
    - Cs: list of arrays
         Structure matrices.
    - ps: list of arrays
         Input distributions.
    - p: array-like
         Barycenter distribution.
    - lambdas: array-like
         Combination weights.
    - loss_fun: str or callable, default='square_loss'
    - epsilon: float, default=0.1
         Entropic regularization.
    - symmetric: bool, default=True
    - max_iter: int, default=1000
    - tol: float, default=1e-9
    - verbose: bool, default=False
    - log: bool, default=False
    - init_C: array-like, optional
    - random_state: int, optional
    
    Returns:
    - barycenter_structure: ndarray, shape (N, N)
    - log: dict (if log=True)
    """

def ot.gromov.entropic_fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, G0=None, max_iter=1000, tol=1e-9, verbose=False, log=False):
    """
    Compute entropic regularized Fused GW distance.
    
    Parameters:
    - M: array-like, shape (n1, n2)
         Feature cost matrix.
    - C1, C2: array-like
         Structure matrices.
    - p, q: array-like
         Distributions.
    - loss_fun: str or callable, default='square_loss'
    - epsilon: float, default=0.1
         Entropic regularization.
    - alpha: float, default=0.5
         Structure/feature trade-off.
    - symmetric: bool, optional
    - G0: array-like, optional
         Initial transport plan.
    - max_iter: int, default=1000
    - tol: float, default=1e-9
    - verbose: bool, default=False
    - log: bool, default=False
    
    Returns:
    - transport_plan: ndarray, shape (n1, n2)
    - log: dict (if log=True)
    """

def ot.gromov.entropic_fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, G0=None, max_iter=1000, tol=1e-9, verbose=False, log=False):
    """
    Compute entropic regularized FGW distance (cost only).
    
    Parameters: Same as entropic_fused_gromov_wasserstein()
    
    Returns:
    - fgw_distance: float
    - log: dict (if log=True)
    """

def ot.gromov.entropic_fused_gromov_barycenters(N, Ys, Cs, ps, lambdas, alpha, epsilon=0.1, fixed_structure=False, fixed_features=False, p=None, loss_fun='square_loss', max_iter=100, tol=1e-9, verbose=False, log=False, init_C=None, init_X=None, random_state=None):
    """
    Compute entropic regularized FGW barycenters.
    
    Parameters: Similar to fgw_barycenters() with additional epsilon parameter
    
    Returns:
    - barycenter_features: ndarray
    - barycenter_structure: ndarray
    - log: dict (if log=True)
    """

Semi-relaxed Methods

def ot.gromov.semirelaxed_gromov_wasserstein(C1, C2, p=None, loss_fun='square_loss', symmetric=None, log=False, G0=None, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    """
    Compute semi-relaxed Gromov-Wasserstein distance.
    
    Keeps the source marginal constraint (T 1 = p) but drops the target
    marginal constraint, so the target weights are optimized together
    with the plan.
    
    Parameters:
    - C1: array-like, shape (n1, n1)
         Source structure matrix.
    - C2: array-like, shape (n2, n2)
         Target structure matrix.
    - p: array-like, shape (n1,), optional
         Source distribution (enforced). Uniform if None.
    - loss_fun: str, default='square_loss'
    - symmetric: bool, optional
    - G0: array-like, optional
         Initial transport plan.
    - log: bool, default=False
    - max_iter: int, default=1e4
    - tol_rel: float, default=1e-9
    - tol_abs: float, default=1e-9
    
    Returns:
    - transport_plan: ndarray, shape (n1, n2)
         Semi-relaxed transport plan.
    - log: dict (if log=True)
    """

def ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p=None, loss_fun='square_loss', symmetric=None, log=False, G0=None, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    """
    Compute semi-relaxed GW distance (cost only).
    
    Parameters: Same as semirelaxed_gromov_wasserstein()
    
    Returns:
    - sr_gw_distance: float
    - log: dict (if log=True)
    """

def ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', symmetric=None, alpha=0.5, G0=None, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    """
    Compute semi-relaxed Fused GW distance.
    
    Parameters:
    - M: array-like, shape (n1, n2)
         Feature cost matrix.
    - alpha: float, default=0.5
         Structure/feature trade-off, as in fused_gromov_wasserstein().
    - Other parameters same as semirelaxed_gromov_wasserstein()
    
    Returns:
    - transport_plan: ndarray, shape (n1, n2)
    - log: dict (if log=True)
    """

def ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', symmetric=None, alpha=0.5, G0=None, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    """
    Compute semi-relaxed FGW distance (cost only).
    
    Parameters: Same as semirelaxed_fused_gromov_wasserstein()
    
    Returns:
    - sr_fgw_distance: float
    - log: dict (if log=True)
    """

def ot.gromov.solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p, M, reg, fC2t=None, alpha_min=None, alpha_max=None, nx=None, **kwargs):
    """
    Exact line search used inside the semi-relaxed conditional gradient solvers.
    
    Parameters: Similar to solve_gromov_linesearch(), plus ones_p (a vector
    of ones of size n1) and an optional precomputed term fC2t.
    
    Returns:
    - alpha: float
    - fc: int
    - cost_G: float
    """

Partial Methods

def ot.gromov.partial_gromov_wasserstein(C1, C2, p, q, m=None, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    """
    Compute partial Gromov-Wasserstein distance.
    
    Transports only a total mass m and leaves the rest unmatched. This is useful
    when one space contains outliers or only partially overlaps with the other.
    
    Parameters:
    - C1: array-like, shape (n1, n1)
         Source structure matrix.
    - C2: array-like, shape (n2, n2)
         Target structure matrix.
    - p: array-like, shape (n1,)
         Source distribution.
    - q: array-like, shape (n2,)
         Target distribution.
    - m: float, optional
         Total mass to transport; must not exceed min(sum(p), sum(q))
         (default: min(sum(p), sum(q))).
    - loss_fun: str or callable, default='square_loss'
    - alpha: float, default=0.5
    - armijo: bool, default=False
    - log: bool, default=False
    - max_iter: int, default=1000
    - tol_rel: float, default=1e-9
    - tol_abs: float, default=1e-9
    
    Returns:
    - transport_plan: ndarray, shape (n1, n2)
         Partial GW transport plan.
    - log: dict (if log=True)
    """

def ot.gromov.partial_gromov_wasserstein2(C1, C2, p, q, m=None, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    """
    Compute partial GW distance (cost only).
    
    Parameters: Same as partial_gromov_wasserstein()
    
    Returns:
    - partial_gw_distance: float
    - log: dict (if log=True)
    """

def ot.gromov.entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, loss_fun='square_loss', G0=None, max_iter=1000, tol=1e-9, verbose=False, log=False):
    """
    Compute entropic regularized partial GW distance.
    
    Parameters:
    - C1, C2: array-like
         Structure matrices.
    - p, q: array-like
         Distributions.
    - reg: float
         Entropic regularization parameter.
    - m: float, optional
         Mass to transport.
    - loss_fun: str or callable, default='square_loss'
    - G0: array-like, optional
    - max_iter: int, default=1000
    - tol: float, default=1e-9
    - verbose: bool, default=False
    - log: bool, default=False
    
    Returns:
    - transport_plan: ndarray
    - log: dict (if log=True)
    """

def ot.gromov.entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, loss_fun='square_loss', G0=None, max_iter=1000, tol=1e-9, verbose=False, log=False):
    """
    Compute entropic regularized partial GW distance (cost only).
    
    Parameters: Same as entropic_partial_gromov_wasserstein()
    
    Returns:
    - partial_gw_distance: float
    - log: dict (if log=True)
    """

Dictionary Learning Methods

def ot.gromov.gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0.0, ps=None, q=None, epochs=20, batch_size=32, learning_rate=1.0, proj_sparse_regul=0.1, verbose=False, random_state=None, **kwargs):
    """
    Learn dictionary of structures using GW distances.
    
    Learns a dictionary of prototype structures that can represent
    input structures as sparse combinations.
    
    Parameters:
    - Cs: list of arrays
         Input structure matrices to learn from.
    - D: int
         Dictionary size (number of atoms).
    - nt: int
         Size of each dictionary atom.
    - reg: float, default=0.0
         Entropic regularization for GW computation.
    - ps: list of arrays, optional
         Distributions for input structures.
    - q: array-like, optional
         Distribution for dictionary atoms.
    - epochs: int, default=20
         Number of learning epochs.
    - batch_size: int, default=32
         Mini-batch size for learning.
    - learning_rate: float, default=1.0
         Learning rate for dictionary updates.
    - proj_sparse_regul: float, default=0.1
         Sparsity regularization for projections.
    - verbose: bool, default=False
    - random_state: int, optional
    
    Returns:
    - dictionary: list of arrays
         Learned dictionary of structure matrices.
    - log: dict
         Learning statistics and convergence information.
    """

def ot.gromov.gromov_wasserstein_linear_unmixing(C, Cdict, reg=0.0, p=None, q=None, tol_outer=1e-6, tol_inner=1e-6, max_iter_outer=20, max_iter_inner=200, verbose=False, **kwargs):
    """
    Unmix structure using learned GW dictionary.
    
    Decomposes input structure as sparse combination of dictionary atoms.
    
    Parameters:
    - C: array-like, shape (n, n)
         Structure matrix to unmix.
    - Cdict: list of arrays
         Dictionary of structure atoms.
    - reg: float, default=0.0
         Entropic regularization.
    - p: array-like, optional
         Distribution for input structure.
    - q: array-like, optional
         Distribution for dictionary atoms.
    - tol_outer: float, default=1e-6
         Outer loop tolerance.
    - tol_inner: float, default=1e-6
         Inner loop tolerance.
    - max_iter_outer: int, default=20
    - max_iter_inner: int, default=200
    - verbose: bool, default=False
    
    Returns:
    - coefficients: ndarray
         Sparse coefficients for dictionary combination.
    - log: dict
         Unmixing optimization information.
    """

def ot.gromov.fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, reg=0.0, alpha=0.5, ps=None, q=None, epochs=20, batch_size=32, learning_rate=1.0, proj_sparse_regul=0.1, verbose=False, random_state=None):
    """
    Learn dictionary for FGW (structure + features).
    
    Parameters: Extends gromov_wasserstein_dictionary_learning() with:
    - Ys: list of arrays
         Feature matrices for input data.
    - alpha: float, default=0.5
         Structure/feature trade-off.
    
    Returns:
    - structure_dictionary: list of arrays
    - feature_dictionary: list of arrays
    - log: dict
    """

def ot.gromov.fused_gromov_wasserstein_linear_unmixing(C, Y, Cdict, Ydict, alpha=0.5, reg=0.0, p=None, q=None, tol_outer=1e-6, tol_inner=1e-6, max_iter_outer=20, max_iter_inner=200, verbose=False):
    """
    Unmix FGW data using learned dictionary.
    
    Parameters: Extends gromov_wasserstein_linear_unmixing() with:
    - Y: array-like
         Feature matrix to unmix.
    - Ydict: list of arrays
         Feature dictionary atoms.
    - alpha: float, default=0.5
    
    Returns:
    - coefficients: ndarray
    - log: dict
    """

Utility Functions

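def ot.gromov.init_matrix(C1, C2, p, q, loss_fun='square_loss', nx=None):
    """
    Precompute the terms of the loss decomposition used by the GW solvers.
    
    The loss is split as L(a, b) = f1(a) + f2(b) - h1(a) h2(b), so the GW
    tensor product can be computed with matrix products.
    
    Parameters:
    - C1: array-like, shape (n1, n1)
    - C2: array-like, shape (n2, n2)
    - p: array-like, shape (n1,)
    - q: array-like, shape (n2,)
    - loss_fun: str, default='square_loss'
         'square_loss' or 'kl_loss'.
    - nx: backend, optional
    
    Returns:
    - constC: ndarray, shape (n1, n2)
         Constant term f1(C1) p 1^T + 1 q^T f2(C2)^T.
    - hC1: ndarray, shape (n1, n1)
         h1(C1).
    - hC2: ndarray, shape (n2, n2)
         h2(C2).
    """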

def ot.gromov.tensor_product(constC, hC1, hC2, T):
    """
    Compute tensor product for GW gradient computation.
    
    Parameters:
    - constC: ndarray
         Constant term in GW formulation.
    - hC1: ndarray
         Source structure term.
    - hC2: ndarray
         Target structure term.
    - T: ndarray
         Current transport plan.
    
    Returns:
    - tensor_prod: ndarray
         Tensor product result.
    """

def ot.gromov.gwloss(constC, hC1, hC2, T):
    """
    Compute Gromov-Wasserstein loss function value.
    
    Parameters:
    - constC: ndarray
    - hC1: ndarray
    - hC2: ndarray  
    - T: ndarray
         Transport plan.
    
    Returns:
    - loss: float
         GW loss value.
    """

def ot.gromov.gwggrad(constC, hC1, hC2, T):
    """
    Compute Gromov-Wasserstein gradient.
    
    Parameters:
    - constC: ndarray
    - hC1: ndarray
    - hC2: ndarray
    - T: ndarray
    
    Returns:
    - gradient: ndarray
         GW objective gradient.
    """

def ot.gromov.update_barycenter_structure(Ts, Cs, lambdas, p, loss_fun='square_loss'):
    """
    Update barycenter structure matrix.
    
    Parameters:
    - Ts: list of arrays
         Transport plans to input spaces.
    - Cs: list of arrays
         Input structure matrices.
    - lambdas: array-like
         Barycenter weights.
    - p: array-like
         Barycenter distribution.
    - loss_fun: str or callable, default='square_loss'
    
    Returns:
    - C_barycenter: ndarray
         Updated barycenter structure.
    """

def ot.gromov.update_barycenter_feature(Ts, Ys, lambdas, p):
    """
    Update barycenter feature matrix.
    
    Parameters:
    - Ts: list of arrays
         Transport plans.
    - Ys: list of arrays
         Input feature matrices.
    - lambdas: array-like
    - p: array-like
    
    Returns:
    - Y_barycenter: ndarray
         Updated barycenter features.
    """

Usage Examples

Basic Gromov-Wasserstein

import ot
import numpy as np

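# Create structure matrices: pairwise distances within each point cloud
np.random.seed(0)
n1, n2 = 10, 15
xs = np.random.randn(n1, 2)
xt = np.random.randn(n2, 3)
C1 = ot.dist(xs, xs)  # squared Euclidean distances (ot.dist default)
C2 = ot.dist(xt, xt)
C1 /= C1.max()  # rescale to [0, 1]
C2 /= C2.max()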

# Create distributions
p = ot.unif(n1)
q = ot.unif(n2)

# Compute GW distance
gw_plan = ot.gromov.gromov_wasserstein(C1, C2, p, q, verbose=True)
gw_dist = ot.gromov.gromov_wasserstein2(C1, C2, p, q)

print(f"GW distance: {gw_dist}")
print(f"Transport plan shape: {gw_plan.shape}")

Fused Gromov-Wasserstein

# Feature cost matrix
d = 3
X1 = np.random.randn(n1, d)
X2 = np.random.randn(n2, d)
M = ot.dist(X1, X2)

# Structure-feature trade-off
alpha = 0.7  # More weight on structure

# Compute FGW
fgw_plan = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, alpha=alpha)
fgw_dist = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, alpha=alpha)

print(f"FGW distance: {fgw_dist}")

GW Barycenter

# Multiple structures
n_spaces = 5
Cs = [np.random.rand(8, 8) for _ in range(n_spaces)]
Cs = [(C + C.T)/2 for C in Cs]  # Make symmetric

ps = [ot.unif(8) for _ in range(n_spaces)]
lambdas = ot.unif(n_spaces)

# Barycenter parameters
N = 6  # Barycenter size
p_barycenter = ot.unif(N)

# Compute barycenter
C_barycenter = ot.gromov.gromov_barycenters(N, Cs, ps, p_barycenter, lambdas, verbose=True)

print(f"Barycenter structure shape: {C_barycenter.shape}")

Entropic GW

# Add entropic regularization
epsilon = 0.05

# Compute entropic GW
egw_plan = ot.gromov.entropic_gromov_wasserstein(C1, C2, p, q, epsilon=epsilon)
egw_dist = ot.gromov.entropic_gromov_wasserstein2(C1, C2, p, q, epsilon=epsilon)

print(f"Entropic GW distance: {egw_dist}")

Partial GW for Outlier Robustness

# Transport only 70% of mass
m = 0.7

# Compute partial GW
pgw_plan = ot.gromov.partial_gromov_wasserstein(C1, C2, p, q, m=m)
pgw_dist = ot.gromov.partial_gromov_wasserstein2(C1, C2, p, q, m=m)

print(f"Partial GW distance: {pgw_dist}")
print(f"Transported mass: {np.sum(pgw_plan)}")

Quantized and Sampling Methods

These variants scale GW to larger problems through graph partitioning (quantization), low-rank factorization, or stochastic sampling.

def ot.gromov.quantized_fused_gromov_wasserstein(C1, C2, Y1, Y2, a=None, b=None, alpha=0.5, reg=0.1, num_node_class=8, **kwargs):
    """
    Solve quantized FGW using graph partitioning for computational efficiency.
    
    Parameters:
    - C1, C2: array-like, structure matrices
    - Y1, Y2: array-like, feature matrices
    - a, b: array-like, distributions
    - alpha: float, structure/feature weight
    - reg: float, regularization parameter
    - num_node_class: int, number of partitions
    
    Returns:
    - quantized transport plan
    """

def ot.gromov.lowrank_gromov_wasserstein_samples(X_s, X_t, a=None, b=None, reg=0.0, rank=10, numItermax=100, stopThr=1e-5, log=False):
    """
    Solve GW using low-rank factorization for large-scale problems.
    """

def ot.gromov.sampled_gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', nb_samples_grad=100, log=False, **kwargs):
    """
    Solve GW using sampling for gradient computation.
    """

def ot.gromov.get_graph_partition(C1, num_node_class=8, part_method='louvain'):
    """
    Partition graph for quantized methods.
    """

Unbalanced Methods

Unbalanced variants relax the marginal constraints with penalty terms, so p and q may have different total masses.

def ot.gromov.fused_unbalanced_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', epsilon=0.1, alpha=0.5, rho=1.0, rho2=1.0, **kwargs):
    """
    Solve unbalanced FGW with marginal relaxation penalties.
    
    Parameters:
    - M: array-like, feature cost matrix
    - C1, C2: array-like, structure matrices
    - p, q: array-like, measures (can have different masses)
    - epsilon: float, entropic regularization
    - alpha: float, structure/feature trade-off
    - rho, rho2: float, marginal relaxation penalties
    
    Returns:
    - unbalanced transport plan
    """

def ot.gromov.unbalanced_co_optimal_transport(X_s, X_t, C1, C2, p, q, epsilon=0.1, rho=1.0, rho2=1.0, **kwargs):
    """
    Solve unbalanced co-optimal transport.
    """

Import Statements

import ot.gromov
from ot.gromov import gromov_wasserstein, gromov_wasserstein2
from ot.gromov import fused_gromov_wasserstein, fused_gromov_wasserstein2
from ot.gromov import gromov_barycenters, fgw_barycenters
from ot.gromov import entropic_gromov_wasserstein, entropic_fused_gromov_wasserstein
from ot.gromov import semirelaxed_gromov_wasserstein, partial_gromov_wasserstein
from ot.gromov import gromov_wasserstein_dictionary_learning, quantized_fused_gromov_wasserstein

The ot.gromov module supports structured optimal transport. It compares data with internal geometry, such as graphs, point clouds, or other metric spaces, in cases where classical optimal transport does not apply directly because the samples do not live in a common space.

Install with Tessl CLI

npx tessl i tessl/pypi-pot
