Python Optimal Transport Library providing solvers for optimization problems related to Optimal Transport for signal, image processing and machine learning
The ot.gromov module provides algorithms for computing Gromov-Wasserstein (GW) distances and their variants. These methods enable optimal transport between structured data by comparing the internal geometry of metric spaces rather than requiring a common embedding space.
def ot.gromov.gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
"""
Compute Gromov-Wasserstein distance between two metric spaces.
Solves a non-convex quadratic program (a relaxation of the quadratic assignment
problem) with an iterative conditional-gradient solver, returning a locally
optimal coupling that best preserves pairwise distances between the two spaces.
Parameters:
- C1: array-like, shape (n1, n1)
Intra-structure cost matrix for source space (distances/similarities).
- C2: array-like, shape (n2, n2)
Intra-structure cost matrix for target space.
- p: array-like, shape (n1,)
Distribution over source space. Must be positive and sum to 1.
- q: array-like, shape (n2,)
Distribution over target space. Must be positive and sum to 1.
- loss_fun: str or callable, default='square_loss'
Loss function for structure preservation. Options: 'square_loss', 'kl_loss'
or custom function with signature loss_fun(C1, C2, T).
- alpha: float, default=0.5
Step size parameter for the gradient descent algorithm.
- armijo: bool, default=False
Use Armijo line search for adaptive step size.
- log: bool, default=False
Return optimization log with convergence details.
- max_iter: int, default=1000
Maximum number of iterations.
- tol_rel: float, default=1e-9
Relative tolerance for convergence.
- tol_abs: float, default=1e-9
Absolute tolerance for convergence.
Returns:
- transport_plan: ndarray, shape (n1, n2)
Optimal GW transport plan between the two spaces.
- log: dict (if log=True)
Contains 'gw_dist': GW distance, 'err': convergence errors,
'T': transport plans at each iteration.
"""
def ot.gromov.gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
"""
Compute Gromov-Wasserstein squared distance (cost only).
More efficient when only the distance value is needed.
Parameters: Same as gromov_wasserstein()
Returns:
- gw_distance: float
Gromov-Wasserstein distance between the two spaces.
- log: dict (if log=True)
"""
def ot.gromov.solve_gromov_linesearch(C1, C2, p, q, loss_fun, alpha_min=None, alpha_max=None, log=False, numItermax=1000, stopThr=1e-9, verbose=False, **kwargs):
"""
Solve GW problem with automatic line search for optimal step size.
Parameters:
- C1, C2: array-like
Cost matrices for source and target spaces.
- p, q: array-like
Distributions over source and target spaces.
- loss_fun: str or callable
Loss function for GW computation.
- alpha_min: float, optional
Minimum step size for line search.
- alpha_max: float, optional
Maximum step size for line search.
- log: bool, default=False
- numItermax: int, default=1000
- stopThr: float, default=1e-9
- verbose: bool, default=False
Returns:
- transport_plan: ndarray
- log: dict (if log=True)
"""
def ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
"""
Compute Fused Gromov-Wasserstein distance combining structure and features.
Combines standard optimal transport (based on feature cost M) with
Gromov-Wasserstein transport (based on structure costs C1, C2).
Parameters:
- M: array-like, shape (n1, n2)
Feature cost matrix between source and target samples.
- C1: array-like, shape (n1, n1)
Intra-structure cost matrix for source space.
- C2: array-like, shape (n2, n2)
Intra-structure cost matrix for target space.
- p: array-like, shape (n1,)
Source distribution.
- q: array-like, shape (n2,)
Target distribution.
- loss_fun: str or callable, default='square_loss'
Loss function for structure preservation.
- alpha: float, default=0.5
Trade-off parameter between structure (α) and features (1-α).
α=1 gives pure GW, α=0 gives pure Wasserstein.
- armijo: bool, default=False
Use Armijo line search.
- log: bool, default=False
- max_iter: int, default=1000
- tol_rel: float, default=1e-9
- tol_abs: float, default=1e-9
Returns:
- transport_plan: ndarray, shape (n1, n2)
Optimal FGW transport plan.
- log: dict (if log=True)
"""
def ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
"""
Compute Fused Gromov-Wasserstein squared distance (cost only).
Parameters: Same as fused_gromov_wasserstein()
Returns:
- fgw_distance: float
- log: dict (if log=True)
"""
def ot.gromov.gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun='square_loss', max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None, **kwargs):
"""
Compute Gromov-Wasserstein barycenter of multiple metric spaces.
Finds the N x N structure matrix that minimizes the sum of GW distances to
all input spaces, weighted by lambdas. The barycenter distribution p is
fixed; only the structure is optimized.
Parameters:
- N: int
Size of the barycenter space (number of points).
- Cs: list of arrays
List of intra-structure cost matrices for input spaces.
Each Cs[i] has shape (ni, ni).
- ps: list of arrays
List of distributions for input spaces.
Each ps[i] has shape (ni,).
- p: array-like, shape (N,)
Distribution for the barycenter space.
- lambdas: array-like, shape (n_spaces,)
Weights for combining input spaces in barycenter.
- loss_fun: str or callable, default='square_loss'
Loss function for GW computation.
- max_iter: int, default=1000
Maximum iterations for barycenter optimization.
- tol: float, default=1e-9
Convergence tolerance.
- verbose: bool, default=False
Print optimization information.
- log: bool, default=False
Return optimization log.
- init_C: array-like, shape (N, N), optional
Initial barycenter structure matrix. Random if None.
- random_state: int, optional
Random seed for reproducible initialization.
Returns:
- barycenter_structure: ndarray, shape (N, N)
Optimal barycenter intra-structure cost matrix.
- log: dict (if log=True)
Contains convergence information and transport plans.
"""
def ot.gromov.fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False, p=None, loss_fun='square_loss', max_iter=100, tol=1e-9, verbose=False, log=False, init_C=None, init_X=None, random_state=None, **kwargs):
"""
Compute Fused Gromov-Wasserstein barycenter with features and structure.
Parameters:
- N: int
Barycenter size.
- Ys: list of arrays
List of feature matrices for input spaces.
- Cs: list of arrays
List of structure matrices for input spaces.
- ps: list of arrays
List of distributions for input spaces.
- lambdas: array-like
Weights for space combination.
- alpha: float
Trade-off between structure and features.
- fixed_structure: bool, default=False
Whether to fix the barycenter structure.
- fixed_features: bool, default=False
Whether to fix the barycenter features.
- p: array-like, optional
Barycenter distribution.
- loss_fun: str or callable, default='square_loss'
- max_iter: int, default=100
- tol: float, default=1e-9
- verbose: bool, default=False
- log: bool, default=False
- init_C: array-like, optional
Initial barycenter structure.
- init_X: array-like, optional
Initial barycenter features.
- random_state: int, optional
Returns:
- barycenter_features: ndarray, shape (N, d)
- barycenter_structure: ndarray, shape (N, N)
- log: dict (if log=True)
"""
def ot.gromov.entropic_gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', epsilon=0.1, symmetric=None, G0=None, max_iter=1000, tol=1e-9, verbose=False, log=False, **kwargs):
"""
Compute entropic regularized Gromov-Wasserstein distance.
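Adds an entropic penalty, weighted by epsilon, to the GW objective. Each
iteration then reduces to a Sinkhorn problem, and the returned plan is
smoother (denser) and easier to differentiate than the unregularized one.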
Parameters:
- C1: array-like, shape (n1, n1)
Source structure matrix.
- C2: array-like, shape (n2, n2)
Target structure matrix.
- p: array-like, shape (n1,)
Source distribution.
- q: array-like, shape (n2,)
Target distribution.
- loss_fun: str or callable, default='square_loss'
- epsilon: float, default=0.1
Entropic regularization parameter.
- symmetric: bool, optional
Whether C1 and C2 are symmetric; if None, this is checked on the inputs.
- G0: array-like, optional
Initial transport plan.
- max_iter: int, default=1000
- tol: float, default=1e-9
- verbose: bool, default=False
- log: bool, default=False
Returns:
- transport_plan: ndarray, shape (n1, n2)
- log: dict (if log=True)
"""
def ot.gromov.entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', epsilon=0.1, symmetric=None, G0=None, max_iter=1000, tol=1e-9, verbose=False, log=False, **kwargs):
"""
Compute entropic regularized GW distance (cost only).
Parameters: Same as entropic_gromov_wasserstein()
Returns:
- gw_distance: float
- log: dict (if log=True)
"""
def ot.gromov.entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun='square_loss', epsilon=0.1, symmetric=True, max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None):
"""
Compute entropic regularized GW barycenters.
Parameters:
- N: int
Barycenter size.
- Cs: list of arrays
Structure matrices.
- ps: list of arrays
Input distributions.
- p: array-like
Barycenter distribution.
- lambdas: array-like
Combination weights.
- loss_fun: str or callable, default='square_loss'
- epsilon: float, default=0.1
Entropic regularization.
- symmetric: bool, default=True
- max_iter: int, default=1000
- tol: float, default=1e-9
- verbose: bool, default=False
- log: bool, default=False
- init_C: array-like, optional
- random_state: int, optional
Returns:
- barycenter_structure: ndarray, shape (N, N)
- log: dict (if log=True)
"""
def ot.gromov.entropic_fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, G0=None, max_iter=1000, tol=1e-9, verbose=False, log=False):
"""
Compute entropic regularized Fused GW distance.
Parameters:
- M: array-like, shape (n1, n2)
Feature cost matrix.
- C1, C2: array-like
Structure matrices.
- p, q: array-like
Distributions.
- loss_fun: str or callable, default='square_loss'
- epsilon: float, default=0.1
Entropic regularization.
- alpha: float, default=0.5
Structure/feature trade-off.
- symmetric: bool, optional
- G0: array-like, optional
Initial transport plan.
- max_iter: int, default=1000
- tol: float, default=1e-9
- verbose: bool, default=False
- log: bool, default=False
Returns:
- transport_plan: ndarray, shape (n1, n2)
- log: dict (if log=True)
"""
def ot.gromov.entropic_fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, G0=None, max_iter=1000, tol=1e-9, verbose=False, log=False):
"""
Compute entropic regularized FGW distance (cost only).
Parameters: Same as entropic_fused_gromov_wasserstein()
Returns:
- fgw_distance: float
- log: dict (if log=True)
"""
def ot.gromov.entropic_fused_gromov_barycenters(N, Ys, Cs, ps, lambdas, alpha, epsilon=0.1, fixed_structure=False, fixed_features=False, p=None, loss_fun='square_loss', max_iter=100, tol=1e-9, verbose=False, log=False, init_C=None, init_X=None, random_state=None):
"""
Compute entropic regularized FGW barycenters.
Parameters: Similar to fgw_barycenters() with additional epsilon parameter
Returns:
- barycenter_features: ndarray
- barycenter_structure: ndarray
- log: dict (if log=True)
"""
def ot.gromov.semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=None, alpha=0.5, G0=None, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
"""
Compute semi-relaxed Gromov-Wasserstein distance.
Keeps the source marginal p as a hard constraint and drops the target
marginal constraint, so the weights on the target points (the column sums
of the plan) are optimized along with the plan.
Parameters:
- C1: array-like, shape (n1, n1)
Source structure matrix.
- C2: array-like, shape (n2, n2)
Target structure matrix.
- p: array-like, shape (n1,)
Source distribution (enforced as the row sums of the plan).
- loss_fun: str or callable, default='square_loss'
- symmetric: bool, optional
- alpha: float, default=0.5
Step size parameter.
- G0: array-like, optional
Initial transport plan.
- log: bool, default=False
- max_iter: int, default=1000
- tol_rel: float, default=1e-9
- tol_abs: float, default=1e-9
Returns:
- transport_plan: ndarray, shape (n1, n2)
Semi-relaxed transport plan.
- log: dict (if log=True)
"""
def ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=None, alpha=0.5, G0=None, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
"""
Compute semi-relaxed GW distance (cost only).
Parameters: Same as semirelaxed_gromov_wasserstein()
Returns:
- sr_gw_distance: float
- log: dict (if log=True)
"""
def ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', symmetric=None, alpha=0.5, G0=None, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
"""
Compute semi-relaxed Fused GW distance.
Parameters:
- M: array-like, shape (n1, n2)
Feature cost matrix.
- Other parameters same as semirelaxed_gromov_wasserstein()
Returns:
- transport_plan: ndarray, shape (n1, n2)
- log: dict (if log=True)
"""
def ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', symmetric=None, alpha=0.5, G0=None, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
"""
Compute semi-relaxed FGW distance (cost only).
Parameters: Same as semirelaxed_fused_gromov_wasserstein()
Returns:
- sr_fgw_distance: float
- log: dict (if log=True)
"""
def ot.gromov.solve_semirelaxed_gromov_linesearch(C1, C2, p, loss_fun, alpha_min=None, alpha_max=None, log=False, numItermax=1000, stopThr=1e-9, verbose=False, **kwargs):
"""
Solve semi-relaxed GW with line search optimization.
Parameters: Similar to solve_gromov_linesearch()
Returns:
- transport_plan: ndarray
- log: dict (if log=True)
"""
def ot.gromov.partial_gromov_wasserstein(C1, C2, p, q, m=None, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
"""
Compute partial Gromov-Wasserstein distance.
Allows transport of only a fraction of the total mass, useful when
spaces have different sizes or contain outliers.
Parameters:
- C1: array-like, shape (n1, n1)
Source structure matrix.
- C2: array-like, shape (n2, n2)
Target structure matrix.
- p: array-like, shape (n1,)
Source distribution.
- q: array-like, shape (n2,)
Target distribution.
- m: float, optional
Amount of mass to transport (default: min(sum(p), sum(q))).
- loss_fun: str or callable, default='square_loss'
- alpha: float, default=0.5
- armijo: bool, default=False
- log: bool, default=False
- max_iter: int, default=1000
- tol_rel: float, default=1e-9
- tol_abs: float, default=1e-9
Returns:
- transport_plan: ndarray, shape (n1, n2)
Partial GW transport plan.
- log: dict (if log=True)
"""
def ot.gromov.partial_gromov_wasserstein2(C1, C2, p, q, m=None, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, max_iter=1000, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
"""
Compute partial GW distance (cost only).
Parameters: Same as partial_gromov_wasserstein()
Returns:
- partial_gw_distance: float
- log: dict (if log=True)
"""
def ot.gromov.entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, loss_fun='square_loss', G0=None, max_iter=1000, tol=1e-9, verbose=False, log=False):
"""
Compute entropic regularized partial GW distance.
Parameters:
- C1, C2: array-like
Structure matrices.
- p, q: array-like
Distributions.
- reg: float
Entropic regularization parameter.
- m: float, optional
Mass to transport.
- loss_fun: str or callable, default='square_loss'
- G0: array-like, optional
- max_iter: int, default=1000
- tol: float, default=1e-9
- verbose: bool, default=False
- log: bool, default=False
Returns:
- transport_plan: ndarray
- log: dict (if log=True)
"""
def ot.gromov.entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, loss_fun='square_loss', G0=None, max_iter=1000, tol=1e-9, verbose=False, log=False):
"""
Compute entropic regularized partial GW distance (cost only).
Parameters: Same as entropic_partial_gromov_wasserstein()
Returns:
- partial_gw_distance: float
- log: dict (if log=True)
"""
def ot.gromov.gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0.0, ps=None, q=None, epochs=20, batch_size=32, learning_rate=1.0, proj_sparse_regul=0.1, verbose=False, random_state=None, **kwargs):
"""
Learn dictionary of structures using GW distances.
Learns a dictionary of prototype structures that can represent
input structures as sparse combinations.
Parameters:
- Cs: list of arrays
Input structure matrices to learn from.
- D: int
Dictionary size (number of atoms).
- nt: int
Size of each dictionary atom.
- reg: float, default=0.0
Entropic regularization for GW computation.
- ps: list of arrays, optional
Distributions for input structures.
- q: array-like, optional
Distribution for dictionary atoms.
- epochs: int, default=20
Number of learning epochs.
- batch_size: int, default=32
Mini-batch size for learning.
- learning_rate: float, default=1.0
Learning rate for dictionary updates.
- proj_sparse_regul: float, default=0.1
Sparsity regularization for projections.
- verbose: bool, default=False
- random_state: int, optional
Returns:
- dictionary: list of arrays
Learned dictionary of structure matrices.
- log: dict
Learning statistics and convergence information.
"""
def ot.gromov.gromov_wasserstein_linear_unmixing(C, Cdict, reg=0.0, p=None, q=None, tol_outer=1e-6, tol_inner=1e-6, max_iter_outer=20, max_iter_inner=200, verbose=False, **kwargs):
"""
Unmix structure using learned GW dictionary.
Decomposes input structure as sparse combination of dictionary atoms.
Parameters:
- C: array-like, shape (n, n)
Structure matrix to unmix.
- Cdict: list of arrays
Dictionary of structure atoms.
- reg: float, default=0.0
Entropic regularization.
- p: array-like, optional
Distribution for input structure.
- q: array-like, optional
Distribution for dictionary atoms.
- tol_outer: float, default=1e-6
Outer loop tolerance.
- tol_inner: float, default=1e-6
Inner loop tolerance.
- max_iter_outer: int, default=20
- max_iter_inner: int, default=200
- verbose: bool, default=False
Returns:
- coefficients: ndarray
Sparse coefficients for dictionary combination.
- log: dict
Unmixing optimization information.
"""
def ot.gromov.fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, reg=0.0, alpha=0.5, ps=None, q=None, epochs=20, batch_size=32, learning_rate=1.0, proj_sparse_regul=0.1, verbose=False, random_state=None):
"""
Learn dictionary for FGW (structure + features).
Parameters: Extends gromov_wasserstein_dictionary_learning() with:
- Ys: list of arrays
Feature matrices for input data.
- alpha: float, default=0.5
Structure/feature trade-off.
Returns:
- structure_dictionary: list of arrays
- feature_dictionary: list of arrays
- log: dict
"""
def ot.gromov.fused_gromov_wasserstein_linear_unmixing(C, Y, Cdict, Ydict, alpha=0.5, reg=0.0, p=None, q=None, tol_outer=1e-6, tol_inner=1e-6, max_iter_outer=20, max_iter_inner=200, verbose=False):
"""
Unmix FGW data using learned dictionary.
Parameters: Extends gromov_wasserstein_linear_unmixing() with:
- Y: array-like
Feature matrix to unmix.
- Ydict: list of arrays
Feature dictionary atoms.
- alpha: float, default=0.5
Returns:
- coefficients: ndarray
- log: dict
"""
def ot.gromov.init_matrix(C1, C2, p, q, loss_fun='square_loss', random_state=None):
"""
Precompute the loss-decomposition terms used by tensor_product, gwloss and gwggrad.
Parameters:
- C1: array-like, shape (n1, n1)
- C2: array-like, shape (n2, n2)
- p: array-like, shape (n1,)
- q: array-like, shape (n2,)
- loss_fun: str or callable, default='square_loss'
- random_state: int, optional
Returns:
- constC: ndarray, shape (n1, n2)
Constant term of the decomposed loss.
- hC1: ndarray, shape (n1, n1)
Source structure term.
- hC2: ndarray, shape (n2, n2)
Target structure term.
"""
def ot.gromov.tensor_product(constC, hC1, hC2, T):
"""
Compute tensor product for GW gradient computation.
Parameters:
- constC: ndarray
Constant term in GW formulation.
- hC1: ndarray
Source structure term.
- hC2: ndarray
Target structure term.
- T: ndarray
Current transport plan.
Returns:
- tensor_prod: ndarray
Tensor product result.
"""
def ot.gromov.gwloss(constC, hC1, hC2, T):
"""
Compute Gromov-Wasserstein loss function value.
Parameters:
- constC: ndarray
- hC1: ndarray
- hC2: ndarray
- T: ndarray
Transport plan.
Returns:
- loss: float
GW loss value.
"""
def ot.gromov.gwggrad(constC, hC1, hC2, T):
"""
Compute Gromov-Wasserstein gradient.
Parameters:
- constC: ndarray
- hC1: ndarray
- hC2: ndarray
- T: ndarray
Returns:
- gradient: ndarray
GW objective gradient.
"""
def ot.gromov.update_barycenter_structure(Ts, Cs, lambdas, p, loss_fun='square_loss'):
"""
Update barycenter structure matrix.
Parameters:
- Ts: list of arrays
Transport plans to input spaces.
- Cs: list of arrays
Input structure matrices.
- lambdas: array-like
Barycenter weights.
- p: array-like
Barycenter distribution.
- loss_fun: str or callable, default='square_loss'
Returns:
- C_barycenter: ndarray
Updated barycenter structure.
"""
def ot.gromov.update_barycenter_feature(Ts, Ys, lambdas, p):
"""
Update barycenter feature matrix.
Parameters:
- Ts: list of arrays
Transport plans.
- Ys: list of arrays
Input feature matrices.
- lambdas: array-like
- p: array-like
Returns:
- Y_barycenter: ndarray
Updated barycenter features.
"""
import ot
import numpy as np
# Create structure matrices (e.g., distance matrices)
n1, n2 = 10, 15
C1 = np.random.rand(n1, n1)
C1 = (C1 + C1.T) / 2 # Make symmetric
C2 = np.random.rand(n2, n2)
C2 = (C2 + C2.T) / 2
# Create distributions
p = ot.unif(n1)
q = ot.unif(n2)
# Compute GW distance
gw_plan = ot.gromov.gromov_wasserstein(C1, C2, p, q, verbose=True)
gw_dist = ot.gromov.gromov_wasserstein2(C1, C2, p, q)
print(f"GW distance: {gw_dist}")
print(f"Transport plan shape: {gw_plan.shape}")
# Feature cost matrix
d = 3
X1 = np.random.randn(n1, d)
X2 = np.random.randn(n2, d)
M = ot.dist(X1, X2)
# Structure-feature trade-off
alpha = 0.7 # More weight on structure
# Compute FGW
fgw_plan = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, alpha=alpha)
fgw_dist = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, alpha=alpha)
print(f"FGW distance: {fgw_dist}")
# Multiple structures
n_spaces = 5
Cs = [np.random.rand(8, 8) for _ in range(n_spaces)]
Cs = [(C + C.T)/2 for C in Cs] # Make symmetric
ps = [ot.unif(8) for _ in range(n_spaces)]
lambdas = ot.unif(n_spaces)
# Barycenter parameters
N = 6 # Barycenter size
p_barycenter = ot.unif(N)
# Compute barycenter
C_barycenter = ot.gromov.gromov_barycenters(N, Cs, ps, p_barycenter, lambdas, verbose=True)
print(f"Barycenter structure shape: {C_barycenter.shape}")
# Add entropic regularization
epsilon = 0.05
# Compute entropic GW
egw_plan = ot.gromov.entropic_gromov_wasserstein(C1, C2, p, q, epsilon=epsilon)
egw_dist = ot.gromov.entropic_gromov_wasserstein2(C1, C2, p, q, epsilon=epsilon)
print(f"Entropic GW distance: {egw_dist}")
# Transport only 70% of mass
m = 0.7
# Compute partial GW
pgw_plan = ot.gromov.partial_gromov_wasserstein(C1, C2, p, q, m=m)
pgw_dist = ot.gromov.partial_gromov_wasserstein2(C1, C2, p, q, m=m)
print(f"Partial GW distance: {pgw_dist}")
print(f"Transported mass: {np.sum(pgw_plan)}")
Large-scale methods based on graph partitioning, quantization, low-rank factorization and sampling:
def ot.gromov.quantized_fused_gromov_wasserstein(C1, C2, Y1, Y2, a=None, b=None, alpha=0.5, reg=0.1, num_node_class=8, **kwargs):
"""
Solve quantized FGW using graph partitioning for computational efficiency.
Parameters:
- C1, C2: array-like, structure matrices
- Y1, Y2: array-like, feature matrices
- a, b: array-like, distributions
- alpha: float, structure/feature weight
- reg: float, regularization parameter
- num_node_class: int, number of partitions
Returns:
- quantized transport plan
"""
def ot.gromov.lowrank_gromov_wasserstein_samples(X_s, X_t, a=None, b=None, reg=0.0, rank=10, numItermax=100, stopThr=1e-5, log=False):
"""
Solve GW using low-rank factorization for large-scale problems.
"""
def ot.gromov.sampled_gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', nb_samples_grad=100, log=False, **kwargs):
"""
Solve GW using sampling for gradient computation.
"""
def ot.gromov.get_graph_partition(C1, num_node_class=8, part_method='louvain'):
"""
Partition graph for quantized methods.
"""
Unbalanced variants, which allow the input measures to have different total masses:
def ot.gromov.fused_unbalanced_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', epsilon=0.1, alpha=0.5, rho=1.0, rho2=1.0, **kwargs):
"""
Solve unbalanced FGW with marginal relaxation penalties.
Parameters:
- M: array-like, feature cost matrix
- C1, C2: array-like, structure matrices
- p, q: array-like, measures (can have different masses)
- epsilon: float, entropic regularization
- alpha: float, structure/feature trade-off
- rho, rho2: float, marginal relaxation penalties
Returns:
- unbalanced transport plan
"""
def ot.gromov.unbalanced_co_optimal_transport(X_s, X_t, C1, C2, p, q, epsilon=0.1, rho=1.0, rho2=1.0, **kwargs):
"""
Solve unbalanced co-optimal transport.
"""
import ot.gromov
from ot.gromov import gromov_wasserstein, gromov_wasserstein2
from ot.gromov import fused_gromov_wasserstein, fused_gromov_wasserstein2
from ot.gromov import gromov_barycenters, fgw_barycenters
from ot.gromov import entropic_gromov_wasserstein, entropic_fused_gromov_wasserstein
from ot.gromov import semirelaxed_gromov_wasserstein, partial_gromov_wasserstein
from ot.gromov import gromov_wasserstein_dictionary_learning, quantized_fused_gromov_wasserstein
The ot.gromov module provides tools for structured optimal transport: it compares data with internal geometric structure, such as graphs, point clouds and other metric spaces, in settings where standard optimal transport does not apply directly because the samples do not share a common space.