
tessl/pypi-pot

Python Optimal Transport library providing solvers for optimization problems related to optimal transport in signal processing, image processing, and machine learning


docs/utilities.md

Utility Functions and Tools

The ot.utils and ot.datasets modules provide essential utility functions and data generation tools that support optimal transport computations. These include distance calculations, distribution generators, timing functions, array manipulations, and synthetic datasets for testing and benchmarking.

Timing Functions

def ot.utils.tic():
    """
    Start timer for performance measurement.
    
    Initializes a global timer to measure elapsed time for code execution.
    Use in combination with toc() or toq() for timing code blocks.
    
    Example:
        ot.tic()
        # ... code to time ...
        elapsed = ot.toq()
    """

def ot.utils.toc(message="Elapsed time : {} s"):
    """
    End timer and print elapsed time with custom message.
    
    Prints the elapsed time since the last tic() call with a customizable
    message format.
    
    Parameters:
    - message: str, default="Elapsed time : {} s"
         Format string for the elapsed time message. Should contain {} placeholder
         for the time value.
    
    Example:
        ot.tic()
        # ... computation ...
        ot.toc("Computation took: {:.3f} seconds")
    """

def ot.utils.toq():
    """
    End timer and return elapsed time without printing.
    
    Returns the elapsed time since the last tic() call as a float value
    without printing any message.
    
    Returns:
    - elapsed_time: float
         Elapsed time in seconds.
    
    Example:
        ot.tic()
        result = expensive_computation()
        time_taken = ot.toq()
        print(f"Computation took {time_taken:.2f} seconds")
    """

Distribution Functions

def ot.utils.unif(n, type_as=None):
    """
    Generate uniform distribution over n points.
    
    Creates a uniform probability distribution (histogram) with equal mass
    on each of n support points.
    
    Parameters:
    - n: int
         Number of points in the distribution.
    - type_as: array-like, optional
         Reference array to determine the output array type and backend.
         If None, returns numpy array.
    
    Returns:
    - distribution: ndarray, shape (n,)
         Uniform distribution with each entry equal to 1/n.
    
    Example:
        uniform_dist = ot.unif(5)  # [0.2, 0.2, 0.2, 0.2, 0.2]
    """

def ot.utils.clean_zeros(a, b, M):
    """
    Remove zero entries from distributions and corresponding cost matrix entries.
    
    Filters out zero-weight points from source and target distributions
    and removes corresponding rows/columns from the cost matrix to avoid
    numerical issues and reduce computation.
    
    Parameters:
    - a: array-like, shape (n_source,)
         Source distribution (may contain zeros).
    - b: array-like, shape (n_target,)
         Target distribution (may contain zeros).
    - M: array-like, shape (n_source, n_target)
         Cost matrix.
    
    Returns:
    - a_clean: ndarray
         Source distribution with zeros removed.
    - b_clean: ndarray
         Target distribution with zeros removed.
    - M_clean: ndarray
         Cost matrix with corresponding rows/columns removed.
    
    Example:
        a = [0.5, 0.0, 0.5]
        b = [0.3, 0.7]
        M = [[1, 2], [3, 4], [5, 6]]
        a_clean, b_clean, M_clean = ot.utils.clean_zeros(a, b, M)
        # Returns: [0.5, 0.5], [0.3, 0.7], [[1, 2], [5, 6]]
    """

Distance Functions

def ot.utils.dist(x1, x2=None, metric='sqeuclidean'):
    """
    Compute distance matrix between sample sets.
    
    Computes pairwise distances between points in x1 and x2 using the
    specified metric. This is the primary function for generating cost
    matrices from sample coordinates.
    
    Parameters:
    - x1: array-like, shape (n1, d)
         First set of samples (source points).
    - x2: array-like, shape (n2, d), optional
         Second set of samples (target points). If None, computes distances
         within x1 (i.e., x2 = x1).
    - metric: str, default='sqeuclidean'
         Distance metric to use. Options include:
         'sqeuclidean', 'euclidean', 'cityblock', 'cosine', 'correlation',
         'hamming', 'jaccard', 'chebyshev', 'minkowski', 'mahalanobis'
    
    Returns:
    - distance_matrix: ndarray, shape (n1, n2)
         Matrix of pairwise distances. Entry (i,j) is the distance between
         x1[i] and x2[j].
    
    Example:
        X1 = np.array([[0, 0], [1, 1]])
        X2 = np.array([[0, 1], [1, 0]])
        M = ot.dist(X1, X2)  # [[1, 1], [1, 1]]
    """

def ot.utils.euclidean_distances(X, Y, squared=False):
    """
    Compute Euclidean distances between samples.
    
    Efficient computation of Euclidean distances with option for squared distances.
    
    Parameters:
    - X: array-like, shape (n_samples_X, n_features)
         First sample set.
    - Y: array-like, shape (n_samples_Y, n_features)
         Second sample set.
    - squared: bool, default=False
         If True, return squared Euclidean distances.
    
    Returns:
    - distances: ndarray, shape (n_samples_X, n_samples_Y)
         Euclidean distance matrix.
    """

def ot.utils.dist0(n, method='lin_square'):
    """
    Generate ground cost matrix for n points on a grid.
    
    Creates standard cost matrices for points arranged on 1D or 2D grids,
    commonly used for image processing and discrete optimal transport.
    
    Parameters:
    - n: int
         Number of points (for 1D) or side length (for 2D grid).
    - method: str, default='lin_square'
         Grid arrangement and distance metric. Options:
         'lin_square': 1D grid with squared distances
         'lin': 1D grid with linear distances
         'square': 2D square grid
    
    Returns:
    - cost_matrix: ndarray, shape (n, n) or (n*n, n*n)
         Ground cost matrix for the specified grid arrangement.
    
    Example:
        M = ot.utils.dist0(3, method='lin_square')
        # Returns 3x3 matrix with squared distances on 1D line
    """

Projection Functions

def ot.utils.proj_simplex(v, z=1):
    """
    Projection onto the probability simplex.
    
    Projects a vector onto the probability simplex: {x : x_i >= 0, sum(x) = z}.
    Essential for many optimization algorithms in optimal transport.
    
    Parameters:
    - v: array-like, shape (n,)
         Input vector to project.
    - z: float, default=1
         Sum constraint for the simplex.
    
    Returns:
    - projected_vector: ndarray, shape (n,)
         Projection of v onto the simplex.
    
    Example:
        v = np.array([2.0, -1.0, 3.0])
        p = ot.utils.proj_simplex(v)  # Projects to valid probability distribution
    """

def ot.utils.projection_sparse_simplex(V, max_nz, z=1):
    """
    Projection onto sparse simplex with cardinality constraint.
    
    Projects onto the intersection of probability simplex and sparsity constraint
    (at most max_nz non-zero entries).
    
    Parameters:
    - V: array-like, shape (n,)
         Input vector.
    - max_nz: int
         Maximum number of non-zero entries.
    - z: float, default=1
         Sum constraint.
    
    Returns:
    - projected_vector: ndarray, shape (n,)
         Sparse simplex projection.
    """

def ot.utils.proj_SDP(S, nx=None, vmin=0.0):
    """
    Projection onto positive semidefinite cone.
    
    Projects a symmetric matrix onto the cone of positive semidefinite matrices
    by eigendecomposition and thresholding negative eigenvalues.
    
    Parameters:
    - S: array-like, shape (n, n)
         Symmetric matrix to project.
    - nx: backend, optional
         Numerical backend to use.
    - vmin: float, default=0.0
         Minimum eigenvalue threshold.
    
    Returns:
    - S_projected: ndarray, shape (n, n)
         Positive semidefinite projection of S.
    """

Array Manipulation Functions

def ot.utils.list_to_array(*lst, nx=None):
    """
    Convert lists or mixed types to arrays with consistent backend.
    
    Standardizes input data to arrays using the specified backend,
    handling mixed input types and ensuring compatibility.
    
    Parameters:
    - lst: sequence of array-like objects
         Input data to convert to arrays.
    - nx: backend, optional
         Target backend for conversion.
    
    Returns:
    - arrays: tuple of ndarrays
         Converted arrays in the target backend format.
    """

def ot.utils.cost_normalization(C, norm=None, nx=None):
    """
    Normalize cost matrix using various normalization schemes.
    
    Applies normalization to cost matrices to improve numerical stability
    and algorithm convergence.
    
    Parameters:
    - C: array-like, shape (n, m)
         Cost matrix to normalize.
    - norm: str, optional
         Normalization method. Options: 'median', 'max', 'log', 'loglog'
    - nx: backend, optional
         Numerical backend.
    
    Returns:
    - C_normalized: ndarray
         Normalized cost matrix.
    """

def ot.utils.dots(*args):
    """
    Compute chained dot products.
    
    Computes the dot product of a sequence of matrices, chaining the
    multiplications from left to right.
    
    Parameters:
    - args: sequence of arrays
         Matrices to multiply in sequence.
    
    Returns:
    - result: ndarray
         Result of chained matrix multiplication.
    
    Example:
        A, B, C = random_matrices()
        result = ot.utils.dots(A, B, C)  # Equivalent to A @ B @ C
    """

def ot.utils.is_all_finite(*args):
    """
    Check if all elements in arrays are finite.
    
    Validates that arrays contain only finite values (no NaN or infinity),
    useful for debugging numerical issues.
    
    Parameters:
    - args: sequence of arrays
         Arrays to check.
    
    Returns:
    - all_finite: bool
         True if all elements in all arrays are finite.
    """

Label Processing Functions

def ot.utils.label_normalization(y, start=0, nx=None):
    """
    Normalize integer labels to start at a specified value.
    
    Shifts integer labels so that the smallest label equals 'start',
    useful for domain adaptation and classification tasks.
    
    Parameters:
    - y: array-like, shape (n,)
         Input integer labels.
    - start: int, default=0
         Desired starting value for the normalized labels.
    - nx: backend, optional
         Numerical backend for array operations.
    
    Returns:
    - y_normalized: ndarray, shape (n,)
         Labels shifted so that the minimum equals 'start'.
    
    Example:
        y = np.array([2, 3, 2, 5])
        y_norm = ot.utils.label_normalization(y)
        # y_norm: [0, 1, 0, 3]
    """

def ot.utils.labels_to_masks(y, type_as=None, nx=None):
    """
    Convert label array to binary mask matrix.
    
    Creates one-hot encoded masks from categorical labels, where each column
    corresponds to one class.
    
    Parameters:
    - y: array-like, shape (n,)
         Integer labels.
    - type_as: array-like, optional
         Reference array for output type.
    - nx: backend, optional
         Numerical backend.
    
    Returns:
    - masks: ndarray, shape (n, n_classes)
         Binary mask matrix where masks[i, j] = 1 if y[i] equals the j-th unique label.
    
    Example:
        y = [0, 1, 0, 2]
        masks = ot.utils.labels_to_masks(y)
        # masks: [[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1]]
    """

Geometric and Kernel Functions

def ot.utils.kernel(x1, x2, method='gaussian', sigma=1.0):
    """
    Compute kernel matrix between sample sets.
    
    Generates kernel matrices for various kernel functions, useful for
    kernel-based optimal transport methods.
    
    Parameters:
    - x1: array-like, shape (n1, d)
         First sample set.
    - x2: array-like, shape (n2, d)
         Second sample set.
    - method: str, default='gaussian'
         Kernel type. Options: 'gaussian', 'linear', 'polynomial'
    - sigma: float, default=1.0
         Kernel bandwidth parameter (for Gaussian kernel).
    
    Returns:
    - kernel_matrix: ndarray, shape (n1, n2)
         Kernel values between samples.
    """

def ot.utils.laplacian(x):
    """
    Compute graph Laplacian matrix.
    
    Constructs the graph Laplacian for samples, used in graph-based
    optimal transport and manifold learning.
    
    Parameters:
    - x: array-like, shape (n, d)
         Sample coordinates.
    
    Returns:
    - laplacian: ndarray, shape (n, n)
         Graph Laplacian matrix.
    """

def ot.utils.get_coordinate_circle(x):
    """
    Get angular coordinates of points on the unit circle.
    
    Maps 2D points lying on the unit circle to normalized angular
    coordinates in [0, 1), used for circular/periodic optimal
    transport problems.
    
    Parameters:
    - x: array-like, shape (n, 2)
         2D coordinates of points on the unit circle.
    
    Returns:
    - angles: ndarray, shape (n,)
         Angular coordinates, normalized to [0, 1).
    """

Parallel and Random Utilities

def ot.utils.parmap(f, X, nprocs='default'):
    """
    Parallel map function for multiprocessing.
    
    Applies function f to elements of X in parallel using multiple processes.
    
    Parameters:
    - f: callable
         Function to apply to each element.
    - X: iterable
         Input data to process.
    - nprocs: int or 'default'
         Number of processes. If 'default', uses all available cores.
    
    Returns:
    - results: list
         Results of applying f to each element of X.
    """

def ot.utils.check_random_state(seed):
    """
    Validate and convert random seed to RandomState object.
    
    Ensures consistent random number generation across different input types.
    
    Parameters:
    - seed: int, RandomState, or None
         Random seed specification.
    
    Returns:
    - random_state: numpy.random.RandomState
         Validated random state object.
    """

def ot.utils.check_params(**kwargs):
    """
    Check that all named parameters have been provided.
    
    Generic parameter validation utility for POT functions: verifies
    that none of the given keyword arguments is None and warns about
    each missing one.
    
    Parameters:
    - kwargs: dict
         Named parameters to check.
    
    Returns:
    - check: bool
         True if all parameters are not None.
    """

Backend Utilities

def ot.utils.reduce_lazytensor(a, func, dim=None, **kwargs):
    """
    Reduce lazy tensor along specified dimensions.
    
    Efficient reduction operations for lazy tensor backends like KeOps.
    
    Parameters:
    - a: LazyTensor
         Input lazy tensor.
    - func: str
         Reduction function ('sum', 'max', 'min', etc.).
    - dim: int, optional
         Dimension along which to reduce.
    - kwargs: dict
         Additional arguments for reduction.
    
    Returns:
    - result: array
         Result of reduction operation.
    """

def ot.utils.get_lowrank_lazytensor(Q, R, X, Y):
    """
    Create low-rank lazy tensor representation.
    
    Constructs efficient lazy tensor for low-rank matrix operations.
    
    Parameters:
    - Q: array-like
         Left factor matrix.
    - R: array-like
         Right factor matrix.
    - X: array-like
         Source coordinates.
    - Y: array-like
         Target coordinates.
    
    Returns:
    - lazy_tensor: LazyTensor
         Low-rank lazy tensor representation.
    """

def ot.utils.get_parameter_pair(parameter):
    """
    Convert single parameter to parameter pair for source/target.
    
    Utility for handling parameters that can be specified as single values
    or pairs for source and target separately.
    
    Parameters:
    - parameter: float or tuple
         Parameter value(s).
    
    Returns:
    - param_source: float
    - param_target: float
    """

Dataset Generation (ot.datasets)

def ot.datasets.make_1D_gauss(n, m, s):
    """
    Generate 1D Gaussian histogram.
    
    Creates a discrete 1D Gaussian distribution on a regular grid.
    
    Parameters:
    - n: int
         Number of bins/points in the histogram.
    - m: float
         Mean of the Gaussian, in bin coordinates (on the grid 0..n-1).
    - s: float
         Standard deviation of the Gaussian, in bin units.
    
    Returns:
    - histogram: ndarray, shape (n,)
         Normalized 1D Gaussian histogram.
    
    Example:
        hist = ot.datasets.make_1D_gauss(100, m=50, s=10)
    """

def ot.datasets.make_2D_samples_gauss(n, m, sigma, random_state=None):
    """
    Generate 2D Gaussian samples.
    
    Creates n samples from a 2D Gaussian distribution with specified
    mean and covariance matrix.
    
    Parameters:
    - n: int
         Number of samples to generate.
    - m: array-like, shape (2,)
         Mean vector of the Gaussian.
    - sigma: array-like, shape (2, 2)
         Covariance matrix of the Gaussian.
    - random_state: int, optional
         Random seed for reproducibility.
    
    Returns:
    - samples: ndarray, shape (n, 2)
         Generated 2D Gaussian samples.
    
    Example:
        mean = [0, 0]
        cov = [[1, 0.5], [0.5, 1]]
        X = ot.datasets.make_2D_samples_gauss(1000, mean, cov)
    """

def ot.datasets.make_data_classif(dataset, n, nz=0.5, theta=0, p=0.5, random_state=None, **kwargs):
    """
    Generate classification datasets for domain adaptation.
    
    Creates synthetic datasets commonly used for testing domain adaptation
    algorithms with optimal transport.
    
    Parameters:
    - dataset: str
         Dataset type. Options include '3gauss', '3gauss2', 'gaussrot'.
    - n: int
         Number of samples per class.
    - nz: float, default=0.5
         Noise level.
    - theta: float, default=0
         Rotation angle for domain shift.
    - p: float, default=0.5
         Proportion parameter.
    - random_state: int, optional
         Random seed.
    - kwargs: dict
         Additional dataset-specific parameters.
    
    Returns:
    - X: ndarray, shape (n_total, n_features)
         Sample coordinates.
    - y: ndarray, shape (n_total,)
         Class labels.
    
    Example:
        X, y = ot.datasets.make_data_classif('3gauss', 100, nz=0.1)
    """

Usage Examples

Basic Utility Usage

import ot
import numpy as np

# Timing code execution
ot.tic()
result = np.linalg.eig(np.random.rand(1000, 1000))
elapsed = ot.toq()
print(f"Eigendecomposition took {elapsed:.3f} seconds")

# Generate uniform distribution
uniform_dist = ot.unif(10)
print("Uniform distribution:", uniform_dist)

# Compute distance matrix
X = np.random.rand(5, 2)
Y = np.random.rand(3, 2)
distances = ot.dist(X, Y)
print("Distance matrix shape:", distances.shape)

Working with Labels

# Label normalization (shift integer labels to start at 0)
labels = np.array([2, 3, 2, 5])
normalized_labels = ot.utils.label_normalization(labels)
print("Normalized labels:", normalized_labels)

# Convert to one-hot masks
masks = ot.utils.labels_to_masks(normalized_labels)
print("One-hot masks shape:", masks.shape)

Dataset Generation

# 1D Gaussian histogram
hist = ot.datasets.make_1D_gauss(50, m=25, s=5)
print("1D histogram sum:", np.sum(hist))

# 2D Gaussian samples
mean = [1, -1]
cov = [[0.5, 0.2], [0.2, 0.8]]
samples = ot.datasets.make_2D_samples_gauss(200, mean, cov)
print("2D samples shape:", samples.shape)

# Classification dataset
X_c, y_c = ot.datasets.make_data_classif('3gauss', 100, nz=0.2)
print("Dataset shape:", X_c.shape, "Classes:", np.unique(y_c))

Projections and Normalizations

# Simplex projection
v = np.array([2.0, -1.0, 3.0, 0.5])
projected = ot.utils.proj_simplex(v)
print("Original vector:", v)
print("Projected (simplex):", projected)
print("Sum after projection:", np.sum(projected))

# Cost matrix normalization
C = np.random.rand(10, 10) * 100
C_normalized = ot.utils.cost_normalization(C, norm='median')
print("Original cost range:", [np.min(C), np.max(C)])
print("Normalized cost range:", [np.min(C_normalized), np.max(C_normalized)])

The utilities and datasets modules provide the foundational tools needed for most optimal transport applications, from basic array manipulations to specialized dataset generation for research and benchmarking.

Install with Tessl CLI

npx tessl i tessl/pypi-pot
