Python Optimal Transport Library providing solvers for optimization problems related to Optimal Transport for signal, image processing and machine learning
—
Quality
Pending
Does it follow best practices?
Impact
Pending
No eval scenarios have been run
The ot.backend module provides a unified interface for multi-framework computation, enabling POT to work seamlessly with NumPy, PyTorch, JAX, TensorFlow, and CuPy. This backend system allows for automatic differentiation, GPU acceleration, and integration with deep learning frameworks.
def ot.backend.get_backend(*args):
"""
Get appropriate backend for input arrays.
Automatically detects the backend type from input arrays and returns
the corresponding backend instance. Enables framework-agnostic code
that works across NumPy, PyTorch, JAX, etc.
Parameters:
- args: sequence of arrays
Input arrays to determine backend from. The function analyzes
the type of these arrays to select the appropriate backend.
Returns:
- backend: Backend instance
Backend object providing unified array operations interface.
Example:
import numpy as np
import torch
# NumPy arrays -> NumPy backend
a_np = np.array([1, 2, 3])
nx = ot.backend.get_backend(a_np)
print(type(nx)) # <class 'ot.backend.NumpyBackend'>
# PyTorch tensors -> PyTorch backend
a_torch = torch.tensor([1, 2, 3])
nx = ot.backend.get_backend(a_torch)
print(type(nx)) # <class 'ot.backend.TorchBackend'>
"""
def ot.backend.get_backend_list():
"""
List all available backends in the current environment.
Returns list of backend names that are available based on installed
packages and environment configuration.
Returns:
- backends: list of str
Names of available backends: 'numpy', 'torch', 'jax', 'tf', 'cupy'
Example:
available = ot.backend.get_backend_list()
print("Available backends:", available)
# Output: ['numpy', 'torch', 'jax'] (depending on installation)
"""
def ot.backend.get_available_backend_implementations():
"""
Get detailed information about available backend implementations.
Returns:
- implementations: dict
Dictionary mapping backend names to their implementation details,
versions, and availability status.
"""
def ot.backend.to_numpy(*args):
"""
Convert arrays to numpy format regardless of input backend.
Universal converter that transforms arrays from any supported backend
(PyTorch tensors, JAX arrays, etc.) to NumPy arrays.
Parameters:
- args: sequence of arrays
Arrays in any supported backend format.
Returns:
- numpy_arrays: tuple of numpy.ndarray
Arrays converted to NumPy format. Returns single array if only
one input, otherwise returns tuple.
Example:
import torch
a_torch = torch.tensor([1.0, 2.0, 3.0])
b_torch = torch.tensor([4.0, 5.0, 6.0])
a_np, b_np = ot.backend.to_numpy(a_torch, b_torch)
print(type(a_np)) # <class 'numpy.ndarray'>
"""class ot.backend.Backend:
"""
Base backend class defining the unified interface for array operations.
All backend implementations inherit from this class and implement
the required methods for array manipulation, linear algebra, and
optimization operations needed by POT algorithms.
"""
def zeros(self, shape, type_as=None):
"""Create array filled with zeros."""
def ones(self, shape, type_as=None):
"""Create array filled with ones."""
def full(self, shape, fill_value, type_as=None):
"""Create array filled with specified value."""
def eye(self, N, M=None, type_as=None):
"""Create identity matrix."""
def sum(self, a, axis=None, keepdims=False):
"""Sum array elements along specified axis."""
def cumsum(self, a, axis=None):
"""Cumulative sum along axis."""
def max(self, a, axis=None, keepdims=False):
"""Maximum values along axis."""
def min(self, a, axis=None, keepdims=False):
"""Minimum values along axis."""
def dot(self, a, b):
"""Matrix multiplication."""
def norm(self, a, axis=None, keepdims=False):
"""Compute norm of array."""
def exp(self, a):
"""Element-wise exponential."""
def log(self, a):
"""Element-wise natural logarithm."""
def sqrt(self, a):
"""Element-wise square root."""
def sort(self, a, axis=-1):
"""Sort array along axis."""
def argsort(self, a, axis=-1):
"""Indices that would sort array."""
def searchsorted(self, a, v, side='left'):
"""Find indices for sorted insertion."""
def flip(self, a, axis=None):
"""Reverse array along axis."""
def clip(self, a, a_min, a_max):
"""Clip array values to range."""
def repeat(self, a, repeats, axis=None):
"""Repeat array elements."""
def take_along_axis(self, a, indices, axis):
"""Take values along axis using indices."""
def concatenate(self, arrays, axis=0):
"""Join arrays along axis."""
def zero_pad(self, a, pad_width):
"""Pad array with zeros."""
def argmax(self, a, axis=None):
"""Indices of maximum values."""
def argmin(self, a, axis=None):
"""Indices of minimum values."""
def mean(self, a, axis=None):
"""Mean along axis."""
def std(self, a, axis=None):
"""Standard deviation along axis."""
def linspace(self, start, stop, num):
"""Create evenly spaced numbers."""
def meshgrid(self, a, b):
"""Create coordinate matrices."""
def diag(self, a, k=0):
"""Extract or create diagonal."""
def unique(self, a):
"""Find unique elements."""
def logsumexp(self, a, axis=None):
"""Log of sum of exponentials."""
def stack(self, arrays, axis=0):
"""Join arrays along new axis."""
def reshape(self, a, shape):
"""Change array shape."""
def seed(self, seed=None):
"""Set random seed."""
def rand(self, *args, **kwargs):
"""Random values in [0,1)."""
def randn(self, *args, **kwargs):
"""Random values from standard normal."""
def coo_matrix(self, S, shape=None):
"""Create sparse COO matrix."""
def issparse(self, a):
"""Check if array is sparse."""
def tocsr(self, a):
"""Convert to CSR sparse format."""
def eliminate_zeros(self, a):
"""Remove explicit zeros from sparse matrix."""
def todense(self, a):
"""Convert sparse to dense."""
def where(self, condition, x=None, y=None):
"""Select elements based on condition."""
def copy(self, a):
"""Create copy of array."""
def allclose(self, a, b, rtol=1e-05, atol=1e-08):
"""Test if arrays are element-wise equal within tolerance."""
def dtype_device(self, a):
"""Get dtype and device info."""
def assert_same_dtype_device(self, a, b):
"""Assert arrays have same dtype and device."""
def squeeze(self, a, axis=None):
"""Remove single-dimensional entries."""
def bitsize(self, type_as):
"""Get bit size of data type."""
def device_type(self, type_as):
"""Get device type (cpu, cuda, etc.)."""
def _bench(self, callable, *args, **kwargs):
"""Benchmark function execution."""
def solve(self, a, b):
"""Solve linear system ax = b."""
def trace(self, a):
"""Sum along diagonals."""
def inv(self, a):
"""Matrix inverse."""
def sqrtm(self, a):
"""Matrix square root."""
def isfinite(self, a):
"""Test for finite elements."""
def array_equal(self, a, b):
"""Test if arrays are equal."""
def is_floating_point(self, a):
"""Test if array has floating point dtype."""class ot.backend.NumpyBackend(Backend):
"""
NumPy backend implementation.
Provides NumPy-based implementations of all backend operations.
This is the default backend and reference implementation.
Features:
- CPU computation using NumPy
- Full scipy sparse matrix support
- Mature and stable implementation
- No automatic differentiation
"""
def __init__(self):
"""Initialize NumPy backend."""
self.name = 'numpy'
self.__name__ = 'numpy'
# All methods delegate to NumPy equivalents
def zeros(self, shape, type_as=None):
return np.zeros(shape, dtype=self._get_dtype(type_as))
def ones(self, shape, type_as=None):
return np.ones(shape, dtype=self._get_dtype(type_as))
# ... (implements all Backend methods using NumPy)

class ot.backend.TorchBackend(Backend):
"""
PyTorch backend implementation.
Enables GPU acceleration and automatic differentiation for POT algorithms
when using PyTorch tensors as input.
Features:
- GPU computation via CUDA
- Automatic differentiation support
- Integration with PyTorch ecosystem
- Batch operations for deep learning
Note: Requires PyTorch installation. Some operations may have
different numerical behavior compared to NumPy.
"""
def __init__(self):
"""Initialize PyTorch backend."""
self.name = 'torch'
self.__name__ = 'torch'
def zeros(self, shape, type_as=None):
if type_as is not None:
return torch.zeros(shape, dtype=type_as.dtype, device=type_as.device)
return torch.zeros(shape)
def dot(self, a, b):
return torch.mm(a, b)
def solve(self, a, b):
return torch.linalg.solve(a, b)
# Gradient computation support
def requires_grad(self, a, requires_grad=True):
"""Enable/disable gradient computation for tensor."""
a.requires_grad_(requires_grad)
def detach(self, a):
"""Detach tensor from computation graph."""
return a.detach()

class ot.backend.JaxBackend(Backend):
"""
JAX backend implementation.
Provides JAX-based computation with automatic differentiation,
just-in-time compilation, and GPU/TPU support.
Features:
- JIT compilation for performance
- Automatic differentiation (forward/reverse mode)
- GPU/TPU acceleration
- Functional programming paradigm
- NumPy-compatible API
Note: Requires JAX installation. Arrays are immutable.
"""
def __init__(self):
"""Initialize JAX backend."""
self.name = 'jax'
self.__name__ = 'jax'
def zeros(self, shape, type_as=None):
return jnp.zeros(shape, dtype=self._get_dtype(type_as))
def dot(self, a, b):
return jnp.dot(a, b)
def solve(self, a, b):
return jnp.linalg.solve(a, b)
# JAX-specific features
def jit(self, fun):
"""Apply JIT compilation to function."""
return jax.jit(fun)
def grad(self, fun):
"""Compute gradient of function."""
return jax.grad(fun)

class ot.backend.TensorflowBackend(Backend):
"""
TensorFlow backend implementation.
Enables integration with TensorFlow ecosystem including Keras models
and TensorFlow Probability.
Features:
- GPU acceleration
- Automatic differentiation via GradientTape
- Integration with Keras/TensorFlow ecosystem
- Graph execution and eager mode support
Note: Requires TensorFlow installation.
"""
def __init__(self):
"""Initialize TensorFlow backend."""
self.name = 'tensorflow'
self.__name__ = 'tensorflow'
def zeros(self, shape, type_as=None):
return tf.zeros(shape, dtype=self._get_dtype(type_as))
def dot(self, a, b):
return tf.linalg.matmul(a, b)
def solve(self, a, b):
return tf.linalg.solve(a, b)

class ot.backend.CupyBackend(Backend):
"""
CuPy backend implementation.
Provides NumPy-compatible GPU acceleration using CuPy for CUDA GPUs.
Features:
- NumPy-compatible API on GPU
- High performance GPU kernels
- Sparse matrix support on GPU
- Memory pool management
Note: Requires CuPy installation and NVIDIA GPU.
"""
def __init__(self):
"""Initialize CuPy backend."""
self.name = 'cupy'
self.__name__ = 'cupy'
def zeros(self, shape, type_as=None):
return cp.zeros(shape, dtype=self._get_dtype(type_as))
def dot(self, a, b):
return cp.dot(a, b)
def solve(self, a, b):
return cp.linalg.solve(a, b)

# Backend configuration constants
DISABLE_TORCH_KEY = "POT_BACKEND_DISABLE_PYTORCH"
DISABLE_JAX_KEY = "POT_BACKEND_DISABLE_JAX"
DISABLE_CUPY_KEY = "POT_BACKEND_DISABLE_CUPY"
DISABLE_TF_KEY = "POT_BACKEND_DISABLE_TENSORFLOW"
"""
Environment variables to disable specific backends:
export POT_BACKEND_DISABLE_PYTORCH=1 # Disable PyTorch backend
export POT_BACKEND_DISABLE_JAX=1 # Disable JAX backend
export POT_BACKEND_DISABLE_CUPY=1 # Disable CuPy backend
export POT_BACKEND_DISABLE_TENSORFLOW=1 # Disable TensorFlow backend
This is useful for:
- Avoiding import errors when packages are not installed
- Forcing use of specific backends
- Debugging backend-related issues
"""def ot.backend._register_backend_implementation(backend_impl):
"""
Register a new backend implementation.
Internal function for adding custom backend implementations
to the POT backend system.
Parameters:
- backend_impl: Backend class
Backend implementation inheriting from Backend base class.
"""
def ot.backend._get_backend_instance(backend_impl):
"""
Get singleton instance of backend implementation.
Parameters:
- backend_impl: str or class
Backend identifier or implementation class.
Returns:
- backend: Backend instance
"""
def ot.backend._check_args_backend(backend_impl, args):
"""
Check if arrays are compatible with specified backend.
Parameters:
- backend_impl: Backend class
- args: sequence of arrays
Arrays to check compatibility for.
Returns:
- compatible: bool
Whether arrays are compatible with backend.
"""import ot
import numpy as np
# NumPy arrays - automatic backend selection
a_np = np.array([0.5, 0.5])
b_np = np.array([0.3, 0.7])
M_np = np.array([[1.0, 2.0], [2.0, 1.0]])
# POT automatically uses NumPy backend
plan_np = ot.sinkhorn(a_np, b_np, M_np, reg=0.1)
print(f"NumPy result type: {type(plan_np)}")
# Check which backend was used
nx = ot.backend.get_backend(a_np, b_np, M_np)
print(f"Backend used: {nx.__name__}")import torch
# Check if CUDA is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
# Create PyTorch tensors on GPU
a_torch = torch.tensor([0.5, 0.5], device=device, dtype=torch.float64)
b_torch = torch.tensor([0.3, 0.7], device=device, dtype=torch.float64)
M_torch = torch.tensor([[1.0, 2.0], [2.0, 1.0]], device=device, dtype=torch.float64)
# Enable gradients for optimization
a_torch.requires_grad_(True)
M_torch.requires_grad_(True)
# POT automatically uses PyTorch backend
plan_torch = ot.sinkhorn(a_torch, b_torch, M_torch, reg=0.1)
print(f"PyTorch result type: {type(plan_torch)}")
print(f"Result device: {plan_torch.device}")
# Compute gradients
loss = torch.sum(plan_torch * M_torch)
loss.backward()
print(f"Gradient w.r.t. a: {a_torch.grad}")
print(f"Gradient w.r.t. M: {M_torch.grad}")import jax.numpy as jnp
import jax
# Create JAX arrays
a_jax = jnp.array([0.5, 0.5])
b_jax = jnp.array([0.3, 0.7])
M_jax = jnp.array([[1.0, 2.0], [2.0, 1.0]])
# Define function for JIT compilation
def compute_sinkhorn(a, b, M):
return ot.sinkhorn(a, b, M, reg=0.1)
# JIT compile the function
jit_sinkhorn = jax.jit(compute_sinkhorn)
# First call compiles, subsequent calls are fast
plan_jax = jit_sinkhorn(a_jax, b_jax, M_jax)
print(f"JAX result type: {type(plan_jax)}")
# Compute gradients
grad_fn = jax.grad(lambda a, b, M: jnp.sum(ot.sinkhorn(a, b, M, reg=0.1)), argnums=0)
gradient = grad_fn(a_jax, b_jax, M_jax)
print(f"Gradient w.r.t. a: {gradient}")# Convert between different backends
a_np = np.array([0.5, 0.5])
# Convert to different backends
if torch.cuda.is_available():
a_torch = torch.from_numpy(a_np).cuda()
print(f"Converted to PyTorch GPU: {a_torch.device}")
# Universal conversion back to NumPy
a_converted = ot.backend.to_numpy(a_torch)
print(f"Converted back to NumPy: {type(a_converted)}")# Compare results across backends
backends_to_test = []
# NumPy (always available)
backends_to_test.append(('numpy', a_np, b_np, M_np))
# PyTorch (if available)
try:
backends_to_test.append(('torch', a_torch.cpu(), b_torch.cpu(), M_torch.cpu()))
except:
pass
# JAX (if available)
try:
backends_to_test.append(('jax', a_jax, b_jax, M_jax))
except:
pass
results = {}
for backend_name, a, b, M in backends_to_test:
plan = ot.sinkhorn(a, b, M, reg=0.1)
# Convert to numpy for comparison
plan_np = ot.backend.to_numpy(plan)
results[backend_name] = plan_np
print(f"{backend_name} backend - sum: {np.sum(plan_np):.6f}")
# Check consistency across backends
if len(results) > 1:
backend_names = list(results.keys())
for i in range(len(backend_names)):
for j in range(i+1, len(backend_names)):
name1, name2 = backend_names[i], backend_names[j]
diff = np.max(np.abs(results[name1] - results[name2]))
print(f"Max difference {name1} vs {name2}: {diff:.2e}")# Example of extending backend system (advanced usage)
class CustomBackend(ot.backend.Backend):
"""Custom backend example."""
def __init__(self):
self.name = 'custom'
def zeros(self, shape, type_as=None):
# Custom implementation
return np.zeros(shape) # Simplified example
# Implement other required methods...
# Register custom backend (internal API)
# ot.backend._register_backend_implementation(CustomBackend)

import time
def benchmark_sinkhorn(a, b, M, backend_name, n_runs=10):
"""Benchmark Sinkhorn algorithm on different backends."""
times = []
for _ in range(n_runs):
start = time.time()
plan = ot.sinkhorn(a, b, M, reg=0.1, numItermax=100)
# Ensure computation is complete (important for GPU)
if hasattr(plan, 'cpu'):
_ = plan.cpu().numpy() # PyTorch
elif hasattr(plan, 'block_until_ready'):
plan.block_until_ready() # JAX
else:
_ = np.array(plan) # NumPy/CuPy
times.append(time.time() - start)
return np.mean(times), np.std(times)
# Benchmark across available backends
n = 100
a_large = ot.unif(n)
b_large = ot.unif(n)
M_large = np.random.rand(n, n)
print(f"Benchmarking Sinkhorn ({n}x{n} problem):")
mean_time, std_time = benchmark_sinkhorn(a_large, b_large, M_large, 'numpy')
print(f"NumPy: {mean_time:.4f} ± {std_time:.4f} seconds")
if torch.cuda.is_available():
a_cuda = torch.from_numpy(a_large).cuda()
b_cuda = torch.from_numpy(b_large).cuda()
M_cuda = torch.from_numpy(M_large).cuda()
mean_time, std_time = benchmark_sinkhorn(a_cuda, b_cuda, M_cuda, 'torch-cuda')
print(f"PyTorch GPU: {mean_time:.4f} ± {std_time:.4f} seconds")The backend system makes POT a truly framework-agnostic optimal transport library, enabling efficient computation across the entire machine learning ecosystem while maintaining a unified API.
Install with Tessl CLI
npx tessl i tessl/pypi-potdocs