tessl/pypi-optax

A gradient processing and optimization library in JAX


Assignment Operations

Assignment algorithms for solving linear assignment problems. The assignment submodule provides implementations of classical algorithms for optimally pairing elements from two sets.

Capabilities

Hungarian Algorithm

The Hungarian algorithm solves the linear assignment problem: it finds the minimum-cost way to assign elements of one set to elements of another.

def hungarian_algorithm(cost_matrix):
    """
    The Hungarian algorithm for the linear assignment problem.
    
    Given an n×m cost matrix, finds an assignment that minimizes the total cost,
    where each row and each column is used at most once and exactly min(n, m)
    row-column pairs are matched.
    
    Args:
        cost_matrix: A matrix of costs (jax.Array with shape (n, m))
    
    Returns:
        A tuple (i, j) where i is an array of row indices and j is an array 
        of column indices representing the optimal assignment.
        The cost of the assignment is cost_matrix[i, j].sum().
    """

def base_hungarian_algorithm(cost_matrix):
    """
    Base implementation of the Hungarian algorithm.
    
    Lower-level implementation that provides the core Hungarian algorithm
    functionality without additional conveniences.
    
    Args:
        cost_matrix: A matrix of costs (jax.Array with shape (n, m))
    
    Returns:
        A tuple (i, j) where i is an array of row indices and j is an array 
        of column indices representing the optimal assignment.
    """
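Both functions return the assignment as two index arrays rather than as the 0/1 matrix X of the integer-program formulation of the problem. The sketch below shows how the two representations relate. It is written in plain Python so it runs without JAX; `indices_to_indicator`, `is_feasible` and `assignment_cost` are hypothetical helpers for illustration, not part of optax. The pairing used is the optimal one for the 4×3 matrix in Example 1.

```python
def indices_to_indicator(rows, cols, n, m):
    """Build the n x m 0/1 matrix X with X[i][j] = 1 for each returned pair (i, j)."""
    X = [[0] * m for _ in range(n)]
    for i, j in zip(rows, cols):
        X[i][j] = 1
    return X


def is_feasible(X):
    """Each row and column used at most once, and exactly min(n, m) pairs in total."""
    n, m = len(X), len(X[0])
    rows_ok = all(sum(row) <= 1 for row in X)
    cols_ok = all(sum(X[i][j] for i in range(n)) <= 1 for j in range(m))
    return rows_ok and cols_ok and sum(map(sum, X)) == min(n, m)


def assignment_cost(cost, rows, cols):
    """Pure-Python counterpart of cost_matrix[i, j].sum()."""
    return sum(cost[i][j] for i, j in zip(rows, cols))


# Optimal pairing for the 4x3 matrix in Example 1: rows 0, 1, 3 -> columns 0, 2, 1.
cost = [[8, 4, 7], [5, 2, 3], [9, 6, 7], [9, 4, 8]]
rows, cols = [0, 1, 3], [0, 2, 1]
X = indices_to_indicator(rows, cols, len(cost), len(cost[0]))
print(is_feasible(X), assignment_cost(cost, rows, cols))  # True 15
```

Row 2 has no 1 in X. With more rows than columns, one row is necessarily left unassigned.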

Usage Examples

import optax
import jax.numpy as jnp

# Example 1: Basic assignment problem (4 rows, 3 columns, so one row is left unassigned)
cost_matrix = jnp.array([
    [8, 4, 7],
    [5, 2, 3],
    [9, 6, 7],
    [9, 4, 8],
])

# Find optimal assignment
row_indices, col_indices = optax.assignment.hungarian_algorithm(cost_matrix)

# Calculate total cost (the optimum for this matrix is 15)
total_cost = cost_matrix[row_indices, col_indices].sum()
print(f"Optimal assignment cost: {total_cost}")
print(f"Row assignments: {row_indices}")
print(f"Column assignments: {col_indices}")

# Example 2: Larger assignment problem (5 rows, 4 columns)
cost_matrix = jnp.array([
    [90, 80, 75, 70],
    [35, 85, 55, 65],
    [125, 95, 90, 95],
    [45, 110, 95, 115],
    [50, 100, 90, 100],
])

row_indices, col_indices = optax.assignment.hungarian_algorithm(cost_matrix)
total_cost = cost_matrix[row_indices, col_indices].sum()
print(f"Optimal assignment cost: {total_cost}")

Problem Description

The linear assignment problem can be formally stated as:

Given a cost matrix C ∈ ℝⁿˣᵐ, solve the integer linear program:

  • Minimize: ∑ᵢ ∑ⱼ Cᵢⱼ Xᵢⱼ
  • Subject to:
    • Xᵢⱼ ∈ {0, 1} for all i, j
    • ∑ᵢ Xᵢⱼ ≤ 1 for all j (each column assigned at most once)
    • ∑ⱼ Xᵢⱼ ≤ 1 for all i (each row assigned at most once)
    • ∑ᵢ ∑ⱼ Xᵢⱼ = min(n, m) (maximum cardinality matching)
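For very small matrices, the formulation can be checked directly by enumerating every feasible X. In the sketch below, each feasible X is represented as a set of (row, column) pairs with no repeated row, no repeated column, and exactly min(n, m) entries. The result is returned in the same (rows, columns) layout that `hungarian_algorithm` uses. `brute_force_assignment` is a hypothetical reference helper, not part of optax. It runs in factorial time, so it is only useful for sanity-checking tiny inputs.

```python
import math
from itertools import permutations


def brute_force_assignment(cost):
    """Solve the assignment ILP for a small n x m cost matrix by exhaustive search.

    Returns (rows, cols, total) with the pairs sorted by row index.
    """
    n, m = len(cost), len(cost[0])
    if n <= m:
        # Every row is matched: perm[i] is the distinct column given to row i.
        candidates = (list(enumerate(perm)) for perm in permutations(range(m), n))
    else:
        # Every column is matched: perm[j] is the distinct row given to column j.
        candidates = ([(r, j) for j, r in enumerate(perm)]
                      for perm in permutations(range(n), m))
    best_total, best_pairs = math.inf, []
    for pairs in candidates:
        # Each candidate is one feasible X: X[i][j] = 1 exactly for (i, j) in pairs.
        total = sum(cost[i][j] for i, j in pairs)
        if total < best_total:
            best_total, best_pairs = total, sorted(pairs)
    return [i for i, _ in best_pairs], [j for _, j in best_pairs], best_total
```

For the 4×3 matrix from Example 1, this returns rows `[0, 1, 3]`, columns `[0, 2, 1]` and a total of 15.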

The classical Hungarian algorithm solves this problem in O(n³) time for an n×n cost matrix. For a rectangular n×m matrix with n ≤ m, the bound is O(n²m).
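To show where this bound comes from, here is a minimal pure-Python sketch of the classical potential-based (shortest augmenting path) form of the Hungarian algorithm. It is not optax's implementation, which works on JAX arrays. The names `hungarian_min_cost` and `assign` are hypothetical. The sketch assumes n ≤ m and handles taller matrices by transposing. Each of the n rows is inserted by one augmentation. That augmentation visits at most as many columns as are already matched, plus one free column, and each visit scans all m columns, which gives O(n²m) overall.

```python
import math


def hungarian_min_cost(cost):
    """Minimum-cost assignment for an n x m cost matrix (list of lists) with n <= m.

    Returns (rows, cols) sorted by row; every row gets a distinct column.
    """
    n, m = len(cost), len(cost[0])
    assert n <= m, "transpose the matrix first when there are more rows than columns"
    u = [0.0] * (n + 1)   # row potentials (1-based; index 0 unused)
    v = [0.0] * (m + 1)   # column potentials (index 0 is a virtual column)
    p = [0] * (m + 1)     # p[j] = 1-based row matched to column j, 0 if free
    way = [0] * (m + 1)   # predecessor column on the alternating path
    for i in range(1, n + 1):
        # Grow a shortest alternating path from row i until a free column is reached.
        p[0] = i
        j0 = 0
        minv = [math.inf] * (m + 1)
        used = [False] * (m + 1)
        while True:  # at most i iterations, each scanning m columns
            used[j0] = True
            i0, delta, j1 = p[j0], math.inf, 0
            for j in range(1, m + 1):
                if not used[j]:
                    reduced = cost[i0 - 1][j - 1] - u[i0] - v[j]
                    if reduced < minv[j]:
                        minv[j], way[j] = reduced, j0
                    if minv[j] < delta:
                        delta, j1 = minv[j], j
            # Shift potentials so that at least one new edge becomes tight.
            for j in range(m + 1):
                if used[j]:
                    u[p[j]] += delta
                    v[j] -= delta
                else:
                    minv[j] -= delta
            j0 = j1
            if p[j0] == 0:
                break
        # Flip the matching along the alternating path back to virtual column 0.
        while j0:
            j1 = way[j0]
            p[j0] = p[j1]
            j0 = j1
    pairs = sorted((p[j] - 1, j - 1) for j in range(1, m + 1) if p[j])
    return [r for r, _ in pairs], [c for _, c in pairs]


def assign(cost):
    """Handle any n x m matrix by transposing when n > m."""
    if len(cost) <= len(cost[0]):
        return hungarian_min_cost(cost)
    transposed = [list(col) for col in zip(*cost)]
    t_rows, t_cols = hungarian_min_cost(transposed)  # t_rows index original columns
    pairs = sorted(zip(t_cols, t_rows))
    return [r for r, _ in pairs], [c for _, c in pairs]
```

For the 4×3 matrix from Example 1, `assign` returns rows `[0, 1, 3]` and columns `[0, 2, 1]`, for a total cost of 15. The potentials `u` and `v` keep every reduced cost non-negative, so each augmentation follows a shortest path in reduced costs. This is what keeps the matching optimal after every row is added.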

Import

import optax.assignment
# or
from optax.assignment import hungarian_algorithm, base_hungarian_algorithm

Install with Tessl CLI

npx tessl i tessl/pypi-optax
