tessl/pypi-pot

Python Optimal Transport (POT) library providing solvers for optimization problems related to optimal transport in signal processing, image processing, and machine learning.

Backend System

The ot.backend module provides a unified interface for multi-framework computation, enabling POT to work seamlessly with NumPy, PyTorch, JAX, TensorFlow, and CuPy. This backend system allows for automatic differentiation, GPU acceleration, and integration with deep learning frameworks.
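The dispatch idea behind this system can be sketched with a tiny type-based registry. This is an illustrative, pure-Python sketch only; the names (`MiniBackend`, `_REGISTRY`) are hypothetical and not POT's actual internals:

```python
# Illustrative sketch of type-based backend dispatch (hypothetical names,
# not POT's real implementation): each backend advertises the array types
# it handles, and get_backend returns the first registered match.

class MiniBackend:
    name = 'mini'
    array_types = (list,)  # types this toy backend accepts

    def sum(self, a):
        return sum(a)

_REGISTRY = [MiniBackend()]

def get_backend(*args):
    # Return the first registered backend whose array types match all inputs.
    for nx in _REGISTRY:
        if all(isinstance(a, nx.array_types) for a in args):
            return nx
    raise ValueError("no backend matches the input array types")

nx = get_backend([1, 2, 3])
print(nx.name, nx.sum([1, 2, 3]))  # mini 6
```

POT's real registry additionally checks installed packages and caches one instance per backend, but the dispatch-on-array-type principle is the same.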

Core Backend Functions

Backend Discovery and Management

def ot.backend.get_backend(*args):
    """
    Get appropriate backend for input arrays.
    
    Automatically detects the backend type from input arrays and returns
    the corresponding backend instance. Enables framework-agnostic code
    that works across NumPy, PyTorch, JAX, etc.
    
    Parameters:
    - args: sequence of arrays
         Input arrays to determine backend from. The function analyzes
         the type of these arrays to select the appropriate backend.
    
    Returns:
    - backend: Backend instance
         Backend object providing unified array operations interface.
    
    Example:
        import ot
        import numpy as np
        import torch
        
        # NumPy arrays -> NumPy backend
        a_np = np.array([1, 2, 3])
        nx = ot.backend.get_backend(a_np)
        print(type(nx))  # <class 'ot.backend.NumpyBackend'>
        
        # PyTorch tensors -> PyTorch backend  
        a_torch = torch.tensor([1, 2, 3])
        nx = ot.backend.get_backend(a_torch)
        print(type(nx))  # <class 'ot.backend.TorchBackend'>
    """

def ot.backend.get_backend_list():
    """
    List all available backends in the current environment.
    
    Returns list of backend names that are available based on installed
    packages and environment configuration.
    
    Returns:
    - backends: list of str
         Names of available backends: 'numpy', 'torch', 'jax', 'tf', 'cupy'
    
    Example:
        available = ot.backend.get_backend_list()
        print("Available backends:", available)
        # Output: ['numpy', 'torch', 'jax'] (depending on installation)
    """

def ot.backend.get_available_backend_implementations():
    """
    Get detailed information about available backend implementations.
    
    Returns:
    - implementations: dict
         Dictionary mapping backend names to their implementation details,
         versions, and availability status.
    """

def ot.backend.to_numpy(*args):
    """
    Convert arrays to numpy format regardless of input backend.
    
    Universal converter that transforms arrays from any supported backend
    (PyTorch tensors, JAX arrays, etc.) to NumPy arrays.
    
    Parameters:
    - args: sequence of arrays
         Arrays in any supported backend format.
    
    Returns:
    - numpy_arrays: numpy.ndarray or tuple of numpy.ndarray
         A single NumPy array if one input is given, otherwise a tuple of
         arrays converted to NumPy format.
    
    Example:
        import ot
        import torch
        a_torch = torch.tensor([1.0, 2.0, 3.0])
        b_torch = torch.tensor([4.0, 5.0, 6.0])
        
        a_np, b_np = ot.backend.to_numpy(a_torch, b_torch)
        print(type(a_np))  # <class 'numpy.ndarray'>
    """

Backend Classes

Base Backend Class

class ot.backend.Backend:
    """
    Base backend class defining the unified interface for array operations.
    
    All backend implementations inherit from this class and implement
    the required methods for array manipulation, linear algebra, and
    optimization operations needed by POT algorithms.
    """
    
    def zeros(self, shape, type_as=None):
        """Create array filled with zeros."""
    
    def ones(self, shape, type_as=None):
        """Create array filled with ones."""
    
    def full(self, shape, fill_value, type_as=None):
        """Create array filled with specified value."""
    
    def eye(self, N, M=None, type_as=None):
        """Create identity matrix."""
    
    def sum(self, a, axis=None, keepdims=False):
        """Sum array elements along specified axis."""
    
    def cumsum(self, a, axis=None):
        """Cumulative sum along axis."""
    
    def max(self, a, axis=None, keepdims=False):
        """Maximum values along axis."""
    
    def min(self, a, axis=None, keepdims=False):
        """Minimum values along axis."""
    
    def dot(self, a, b):
        """Matrix multiplication."""
    
    def norm(self, a, axis=None, keepdims=False):
        """Compute norm of array."""
    
    def exp(self, a):
        """Element-wise exponential."""
    
    def log(self, a):
        """Element-wise natural logarithm."""
    
    def sqrt(self, a):
        """Element-wise square root."""
    
    def sort(self, a, axis=-1):
        """Sort array along axis."""
    
    def argsort(self, a, axis=-1):
        """Indices that would sort array."""
    
    def searchsorted(self, a, v, side='left'):
        """Find indices for sorted insertion."""
    
    def flip(self, a, axis=None):
        """Reverse array along axis."""
    
    def clip(self, a, a_min, a_max):
        """Clip array values to range."""
    
    def repeat(self, a, repeats, axis=None):
        """Repeat array elements."""
    
    def take_along_axis(self, a, indices, axis):
        """Take values along axis using indices."""
    
    def concatenate(self, arrays, axis=0):
        """Join arrays along axis."""
    
    def zero_pad(self, a, pad_width):
        """Pad array with zeros."""
    
    def argmax(self, a, axis=None):
        """Indices of maximum values."""
    
    def argmin(self, a, axis=None):
        """Indices of minimum values."""
    
    def mean(self, a, axis=None):
        """Mean along axis."""
    
    def std(self, a, axis=None):
        """Standard deviation along axis."""
    
    def linspace(self, start, stop, num):
        """Create evenly spaced numbers."""
    
    def meshgrid(self, a, b):
        """Create coordinate matrices."""
    
    def diag(self, a, k=0):
        """Extract or create diagonal."""
    
    def unique(self, a):
        """Find unique elements."""
    
    def logsumexp(self, a, axis=None):
        """Log of sum of exponentials."""
    
    def stack(self, arrays, axis=0):
        """Join arrays along new axis."""
    
    def reshape(self, a, shape):
        """Change array shape."""
    
    def seed(self, seed=None):
        """Set random seed."""
    
    def rand(self, *args, **kwargs):
        """Random values in [0,1)."""
    
    def randn(self, *args, **kwargs):
        """Random values from standard normal."""
    
    def coo_matrix(self, S, shape=None):
        """Create sparse COO matrix."""
    
    def issparse(self, a):
        """Check if array is sparse."""
    
    def tocsr(self, a):
        """Convert to CSR sparse format."""
    
    def eliminate_zeros(self, a):
        """Remove explicit zeros from sparse matrix."""
    
    def todense(self, a):
        """Convert sparse to dense."""
    
    def where(self, condition, x=None, y=None):
        """Select elements based on condition."""
    
    def copy(self, a):
        """Create copy of array."""
    
    def allclose(self, a, b, rtol=1e-05, atol=1e-08):
        """Test if arrays are element-wise equal within tolerance."""
    
    def dtype_device(self, a):
        """Get dtype and device info."""
    
    def assert_same_dtype_device(self, a, b):
        """Assert arrays have same dtype and device."""
    
    def squeeze(self, a, axis=None):
        """Remove single-dimensional entries."""
    
    def bitsize(self, type_as):
        """Get bit size of data type."""
    
    def device_type(self, type_as):
        """Get device type (cpu, cuda, etc.)."""
    
    def _bench(self, callable, *args, **kwargs):
        """Benchmark function execution."""
    
    def solve(self, a, b):
        """Solve linear system ax = b."""
    
    def trace(self, a):
        """Sum along diagonals."""
    
    def inv(self, a):
        """Matrix inverse."""
    
    def sqrtm(self, a):
        """Matrix square root."""
    
    def isfinite(self, a):
        """Test for finite elements."""
    
    def array_equal(self, a, b):
        """Test if arrays are equal."""
    
    def is_floating_point(self, a):
        """Test if array has floating point dtype."""

NumPy Backend

class ot.backend.NumpyBackend(Backend):
    """
    NumPy backend implementation.
    
    Provides NumPy-based implementations of all backend operations.
    This is the default backend and reference implementation.
    
    Features:
    - CPU computation using NumPy
    - Full scipy sparse matrix support
    - Mature and stable implementation
    - No automatic differentiation
    """
    
    def __init__(self):
        """Initialize NumPy backend."""
        self.name = 'numpy'
        self.__name__ = 'numpy'
        
    # All methods delegate to NumPy equivalents
    def zeros(self, shape, type_as=None):
        return np.zeros(shape, dtype=self._get_dtype(type_as))
        
    def ones(self, shape, type_as=None):
        return np.ones(shape, dtype=self._get_dtype(type_as))
        
    # ... (implements all Backend methods using NumPy)

PyTorch Backend

class ot.backend.TorchBackend(Backend):
    """
    PyTorch backend implementation.
    
    Enables GPU acceleration and automatic differentiation for POT algorithms
    when using PyTorch tensors as input.
    
    Features:
    - GPU computation via CUDA
    - Automatic differentiation support
    - Integration with PyTorch ecosystem
    - Batch operations for deep learning
    
    Note: Requires PyTorch installation. Some operations may have
    different numerical behavior compared to NumPy.
    """
    
    def __init__(self):
        """Initialize PyTorch backend."""
        self.name = 'torch'
        self.__name__ = 'torch'
        
    def zeros(self, shape, type_as=None):
        if type_as is not None:
            return torch.zeros(shape, dtype=type_as.dtype, device=type_as.device)
        return torch.zeros(shape)
        
    def dot(self, a, b):
        return torch.mm(a, b)
        
    def solve(self, a, b):
        return torch.linalg.solve(a, b)
        
    # Gradient computation support
    def requires_grad(self, a, requires_grad=True):
        """Enable/disable gradient computation for tensor."""
        a.requires_grad_(requires_grad)
        
    def detach(self, a):
        """Detach tensor from computation graph."""
        return a.detach()

JAX Backend

class ot.backend.JaxBackend(Backend):
    """
    JAX backend implementation.
    
    Provides JAX-based computation with automatic differentiation,
    just-in-time compilation, and GPU/TPU support.
    
    Features:
    - JIT compilation for performance
    - Automatic differentiation (forward/reverse mode)
    - GPU/TPU acceleration
    - Functional programming paradigm
    - NumPy-compatible API
    
    Note: Requires JAX installation. Arrays are immutable.
    """
    
    def __init__(self):
        """Initialize JAX backend."""
        self.name = 'jax'
        self.__name__ = 'jax'
        
    def zeros(self, shape, type_as=None):
        return jnp.zeros(shape, dtype=self._get_dtype(type_as))
        
    def dot(self, a, b):
        return jnp.dot(a, b)
        
    def solve(self, a, b):
        return jnp.linalg.solve(a, b)
        
    # JAX-specific features
    def jit(self, fun):
        """Apply JIT compilation to function."""
        return jax.jit(fun)
        
    def grad(self, fun):
        """Compute gradient of function."""
        return jax.grad(fun)

TensorFlow Backend

class ot.backend.TensorflowBackend(Backend):
    """
    TensorFlow backend implementation.
    
    Enables integration with TensorFlow ecosystem including Keras models
    and TensorFlow Probability.
    
    Features:
    - GPU acceleration
    - Automatic differentiation via GradientTape
    - Integration with Keras/TensorFlow ecosystem  
    - Graph execution and eager mode support
    
    Note: Requires TensorFlow installation.
    """
    
    def __init__(self):
        """Initialize TensorFlow backend."""
        self.name = 'tf'
        self.__name__ = 'tf'
        
    def zeros(self, shape, type_as=None):
        return tf.zeros(shape, dtype=self._get_dtype(type_as))
        
    def dot(self, a, b):
        return tf.linalg.matmul(a, b)
        
    def solve(self, a, b):
        return tf.linalg.solve(a, b)

CuPy Backend

class ot.backend.CupyBackend(Backend):
    """
    CuPy backend implementation.
    
    Provides NumPy-compatible GPU acceleration using CuPy for CUDA GPUs.
    
    Features:
    - NumPy-compatible API on GPU
    - High performance GPU kernels
    - Sparse matrix support on GPU
    - Memory pool management
    
    Note: Requires CuPy installation and NVIDIA GPU.
    """
    
    def __init__(self):
        """Initialize CuPy backend."""
        self.name = 'cupy'
        self.__name__ = 'cupy'
        
    def zeros(self, shape, type_as=None):
        return cp.zeros(shape, dtype=self._get_dtype(type_as))
        
    def dot(self, a, b):
        return cp.dot(a, b)
        
    def solve(self, a, b):
        return cp.linalg.solve(a, b)

Backend Configuration

Environment Variables

# Backend configuration constants
DISABLE_TORCH_KEY = "POT_BACKEND_DISABLE_PYTORCH"
DISABLE_JAX_KEY = "POT_BACKEND_DISABLE_JAX"  
DISABLE_CUPY_KEY = "POT_BACKEND_DISABLE_CUPY"
DISABLE_TF_KEY = "POT_BACKEND_DISABLE_TENSORFLOW"

"""
Environment variables to disable specific backends:

export POT_BACKEND_DISABLE_PYTORCH=1    # Disable PyTorch backend
export POT_BACKEND_DISABLE_JAX=1        # Disable JAX backend  
export POT_BACKEND_DISABLE_CUPY=1       # Disable CuPy backend
export POT_BACKEND_DISABLE_TENSORFLOW=1 # Disable TensorFlow backend

This is useful for:
- Avoiding import errors when packages are not installed
- Forcing use of specific backends
- Debugging backend-related issues
"""

Internal Backend Functions

def ot.backend._register_backend_implementation(backend_impl):
    """
    Register a new backend implementation.
    
    Internal function for adding custom backend implementations
    to the POT backend system.
    
    Parameters:
    - backend_impl: Backend class
         Backend implementation inheriting from Backend base class.
    """

def ot.backend._get_backend_instance(backend_impl):
    """
    Get singleton instance of backend implementation.
    
    Parameters:
    - backend_impl: str or class
         Backend identifier or implementation class.
    
    Returns:
    - backend: Backend instance
    """

def ot.backend._check_args_backend(backend_impl, args):
    """
    Check if arrays are compatible with specified backend.
    
    Parameters:
    - backend_impl: Backend class
    - args: sequence of arrays
         Arrays to check compatibility for.
    
    Returns:
    - compatible: bool
         Whether arrays are compatible with backend.
    """

Usage Examples

Automatic Backend Detection

import ot
import numpy as np

# NumPy arrays - automatic backend selection
a_np = np.array([0.5, 0.5])
b_np = np.array([0.3, 0.7])
M_np = np.array([[1.0, 2.0], [2.0, 1.0]])

# POT automatically uses NumPy backend
plan_np = ot.sinkhorn(a_np, b_np, M_np, reg=0.1)
print(f"NumPy result type: {type(plan_np)}")

# Check which backend was used
nx = ot.backend.get_backend(a_np, b_np, M_np)
print(f"Backend used: {nx.__name__}")

PyTorch Integration with GPU

import torch

# Check if CUDA is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Create PyTorch tensors on GPU
a_torch = torch.tensor([0.5, 0.5], device=device, dtype=torch.float64)
b_torch = torch.tensor([0.3, 0.7], device=device, dtype=torch.float64)
M_torch = torch.tensor([[1.0, 2.0], [2.0, 1.0]], device=device, dtype=torch.float64)

# Enable gradients for optimization
a_torch.requires_grad_(True)
M_torch.requires_grad_(True)

# POT automatically uses PyTorch backend
plan_torch = ot.sinkhorn(a_torch, b_torch, M_torch, reg=0.1)
print(f"PyTorch result type: {type(plan_torch)}")
print(f"Result device: {plan_torch.device}")

# Compute gradients
loss = torch.sum(plan_torch * M_torch)
loss.backward()

print(f"Gradient w.r.t. a: {a_torch.grad}")
print(f"Gradient w.r.t. M: {M_torch.grad}")

JAX Integration with JIT

import jax.numpy as jnp
import jax

# Create JAX arrays
a_jax = jnp.array([0.5, 0.5])
b_jax = jnp.array([0.3, 0.7])
M_jax = jnp.array([[1.0, 2.0], [2.0, 1.0]])

# Define function for JIT compilation
def compute_sinkhorn(a, b, M):
    return ot.sinkhorn(a, b, M, reg=0.1)

# JIT compile the function
jit_sinkhorn = jax.jit(compute_sinkhorn)

# First call compiles, subsequent calls are fast
plan_jax = jit_sinkhorn(a_jax, b_jax, M_jax)
print(f"JAX result type: {type(plan_jax)}")

# Compute gradients
grad_fn = jax.grad(lambda a, b, M: jnp.sum(ot.sinkhorn(a, b, M, reg=0.1)), argnums=0)
gradient = grad_fn(a_jax, b_jax, M_jax)
print(f"Gradient w.r.t. a: {gradient}")

Backend Conversion

# Convert between different backends
a_np = np.array([0.5, 0.5])

# NumPy -> PyTorch, moving to GPU when one is available
a_torch = torch.from_numpy(a_np)
if torch.cuda.is_available():
    a_torch = a_torch.cuda()
    print(f"Converted to PyTorch GPU: {a_torch.device}")

# Universal conversion back to NumPy
a_converted = ot.backend.to_numpy(a_torch)
print(f"Converted back to NumPy: {type(a_converted)}")

Multi-Backend Computation

# Compare results across backends
backends_to_test = []

# NumPy (always available)
backends_to_test.append(('numpy', a_np, b_np, M_np))

# PyTorch (if available)
try:
    backends_to_test.append(('torch', a_torch.cpu(), b_torch.cpu(), M_torch.cpu()))
except NameError:  # skipped if the PyTorch tensors above were not created
    pass

# JAX (if available)
try:
    backends_to_test.append(('jax', a_jax, b_jax, M_jax))
except NameError:  # skipped if the JAX arrays above were not created
    pass

results = {}
for backend_name, a, b, M in backends_to_test:
    plan = ot.sinkhorn(a, b, M, reg=0.1)
    # Convert to numpy for comparison
    plan_np = ot.backend.to_numpy(plan)
    results[backend_name] = plan_np
    print(f"{backend_name} backend - sum: {np.sum(plan_np):.6f}")

# Check consistency across backends
if len(results) > 1:
    backend_names = list(results.keys())
    for i in range(len(backend_names)):
        for j in range(i+1, len(backend_names)):
            name1, name2 = backend_names[i], backend_names[j]
            diff = np.max(np.abs(results[name1] - results[name2]))
            print(f"Max difference {name1} vs {name2}: {diff:.2e}")

Custom Backend Implementation

# Example of extending backend system (advanced usage)
class CustomBackend(ot.backend.Backend):
    """Custom backend example."""
    
    def __init__(self):
        self.name = 'custom'
        
    def zeros(self, shape, type_as=None):
        # Custom implementation
        return np.zeros(shape)  # Simplified example
        
    # Implement other required methods...

# Register custom backend (internal API)
# ot.backend._register_backend_implementation(CustomBackend)

Performance Benchmarking

import time

def benchmark_sinkhorn(a, b, M, backend_name, n_runs=10):
    """Benchmark the Sinkhorn algorithm for the given backend label."""
    times = []
    
    for _ in range(n_runs):
        start = time.perf_counter()
        plan = ot.sinkhorn(a, b, M, reg=0.1, numItermax=100)
        
        # Force the computation to finish (GPU backends run asynchronously)
        if hasattr(plan, 'cpu'):
            _ = plan.cpu().numpy()  # PyTorch (synchronizes CUDA)
        elif hasattr(plan, 'block_until_ready'):
            plan.block_until_ready()  # JAX
        else:
            _ = np.array(plan)  # NumPy/CuPy
            
        times.append(time.perf_counter() - start)
    
    return np.mean(times), np.std(times)

# Benchmark across available backends
n = 100
a_large = ot.unif(n)
b_large = ot.unif(n)  
M_large = np.random.rand(n, n)

print(f"Benchmarking Sinkhorn ({n}x{n} problem):")
mean_time, std_time = benchmark_sinkhorn(a_large, b_large, M_large, 'numpy')
print(f"NumPy: {mean_time:.4f} ± {std_time:.4f} seconds")

if torch.cuda.is_available():
    a_cuda = torch.from_numpy(a_large).cuda()
    b_cuda = torch.from_numpy(b_large).cuda()
    M_cuda = torch.from_numpy(M_large).cuda()
    mean_time, std_time = benchmark_sinkhorn(a_cuda, b_cuda, M_cuda, 'torch-cuda')
    print(f"PyTorch GPU: {mean_time:.4f} ± {std_time:.4f} seconds")

Key Benefits

Framework Flexibility

  • Seamless Integration: Use POT with any supported ML framework
  • Automatic Detection: No manual backend selection required
  • Consistent API: Same function calls work across all backends

Performance Optimization

  • GPU Acceleration: Automatic GPU usage when arrays are on GPU
  • JIT Compilation: JAX backend provides automatic compilation
  • Memory Efficiency: Backend-appropriate memory management

Automatic Differentiation

  • PyTorch: Full autograd support for deep learning
  • JAX: Forward and reverse mode AD with functional programming
  • TensorFlow: Integration with GradientTape and Keras

Production Deployment

  • Scalability: Choose appropriate backend for deployment needs
  • Hardware Support: CPU, GPU, TPU support depending on backend
  • Framework Ecosystem: Leverage existing ML infrastructure

The backend system makes POT a truly framework-agnostic optimal transport library, enabling efficient computation across the entire machine learning ecosystem while maintaining a unified API.
