# Assignment Operations

The `optax.assignment` submodule provides implementations of classical algorithms for the linear assignment problem: optimally pairing the elements of one set with the elements of another so that the total cost of the pairs is minimized.

## Capabilities

### Hungarian Algorithm

The Hungarian algorithm solves the linear assignment problem. Given a cost matrix, it finds the minimum-cost way to assign rows (elements of one set) to columns (elements of the other set).

```python { .api }
def hungarian_algorithm(cost_matrix):
  """
  The Hungarian algorithm for the linear assignment problem.

  Solves the problem of finding a minimum-cost assignment given a cost matrix.
  Given an n×m cost matrix, finds an assignment that minimizes the total cost
  while ensuring each row and column is assigned at most once.

  Args:
    cost_matrix: A matrix of costs (jax.Array with shape (n, m))

  Returns:
    A tuple (i, j) where i is an array of row indices and j is an array
    of column indices representing the optimal assignment.
    The cost of the assignment is cost_matrix[i, j].sum().
  """


def base_hungarian_algorithm(cost_matrix):
  """
  Base implementation of the Hungarian algorithm.

  Lower-level implementation that provides the core Hungarian algorithm
  functionality without additional conveniences.

  Args:
    cost_matrix: A matrix of costs (jax.Array with shape (n, m))

  Returns:
    A tuple (i, j) where i is an array of row indices and j is an array
    of column indices representing the optimal assignment.
  """
```
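
Because both functions return only index arrays, a tiny exhaustive solver is handy as a test oracle when checking results on small inputs. The sketch below is plain Python and is not part of Optax; the name `brute_force_assignment` is hypothetical. It enumerates every choice of min(n, m) rows and every ordered choice of that many distinct columns, so its runtime grows factorially and it is only suitable for very small matrices.

```python
import itertools
from typing import List, Sequence, Tuple


def brute_force_assignment(
    cost: Sequence[Sequence[float]],
) -> Tuple[List[int], List[int], float]:
  """Exhaustively solves a small linear assignment problem.

  Hypothetical helper, not an Optax API. Returns (rows, cols, total) where
  rows[k] is paired with cols[k], exactly min(n, m) pairs are used, and no
  row or column appears twice.
  """
  n = len(cost)
  m = len(cost[0]) if n else 0
  k = min(n, m)
  best = None
  # Pick which k rows take part, then try every ordered choice of k columns.
  # Even for k == 0 both generators yield one empty tuple, so best gets set.
  for rows in itertools.combinations(range(n), k):
    for cols in itertools.permutations(range(m), k):
      total = sum(cost[r][c] for r, c in zip(rows, cols))
      if best is None or total < best[2]:
        best = (list(rows), list(cols), total)
  return best


cost = [[8, 4, 7], [5, 2, 3], [9, 6, 7], [9, 4, 8]]
print(brute_force_assignment(cost))  # ([0, 1, 3], [0, 2, 1], 15)
```

For this 4 × 3 matrix the optimum is unique, so any correct solver should report a total cost of 15, leaving row 2 unassigned.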

### Usage Examples

```python
import optax
import jax.numpy as jnp

# Example 1: More rows than columns (4 × 3), so one row stays unassigned
cost_matrix = jnp.array([
    [8, 4, 7],
    [5, 2, 3],
    [9, 6, 7],
    [9, 4, 8],
])

# Find the optimal assignment
row_indices, col_indices = optax.assignment.hungarian_algorithm(cost_matrix)

# Calculate the total cost of the assignment (the optimum is 15 here)
total_cost = cost_matrix[row_indices, col_indices].sum()
print(f"Optimal assignment cost: {total_cost}")
print(f"Row assignments: {row_indices}")
print(f"Column assignments: {col_indices}")

# Example 2: A larger rectangular problem (5 × 4)
cost_matrix = jnp.array([
    [90, 80, 75, 70],
    [35, 85, 55, 65],
    [125, 95, 90, 95],
    [45, 110, 95, 115],
    [50, 100, 90, 100],
])

row_indices, col_indices = optax.assignment.hungarian_algorithm(cost_matrix)
total_cost = cost_matrix[row_indices, col_indices].sum()
print(f"Optimal assignment cost: {total_cost}")  # the optimum is 265 here
```

## Problem Description

The linear assignment problem can be stated formally as follows. Given a cost matrix C ∈ ℝⁿˣᵐ, solve the integer linear program below, where Xᵢⱼ = 1 if row i is assigned to column j and Xᵢⱼ = 0 otherwise:

- **Minimize**: ∑ᵢ ∑ⱼ Cᵢⱼ Xᵢⱼ
- **Subject to**:
  - Xᵢⱼ ∈ {0, 1} for all i, j
  - ∑ᵢ Xᵢⱼ ≤ 1 for all j (each column assigned at most once)
  - ∑ⱼ Xᵢⱼ ≤ 1 for all i (each row assigned at most once)
  - ∑ᵢ ∑ⱼ Xᵢⱼ = min(n, m) (maximum cardinality matching)
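
The formulation can be checked mechanically for any candidate answer: scatter the returned index pairs into the 0/1 matrix X, test each constraint, and evaluate the objective. The following is a minimal plain-Python sketch assuming a non-empty cost matrix; `assignment_to_matrix` and `assignment_cost` are hypothetical helper names, not Optax functions.

```python
from typing import List, Sequence


def assignment_to_matrix(rows: Sequence[int], cols: Sequence[int],
                         n: int, m: int) -> List[List[int]]:
  """Scatters index pairs into an n x m 0/1 matrix X (X[i][j] = 1 if paired)."""
  x = [[0] * m for _ in range(n)]
  for i, j in zip(rows, cols):
    x[i][j] = 1
  return x


def assignment_cost(cost: Sequence[Sequence[float]], rows: Sequence[int],
                    cols: Sequence[int]) -> float:
  """Validates the ILP constraints, then returns sum_ij C[i][j] * X[i][j]."""
  n, m = len(cost), len(cost[0])
  x = assignment_to_matrix(rows, cols, n, m)
  # sum_i X[i][j] <= 1 for every column j.
  if any(sum(x[i][j] for i in range(n)) > 1 for j in range(m)):
    raise ValueError("a column is assigned more than once")
  # sum_j X[i][j] <= 1 for every row i.
  if any(sum(row) > 1 for row in x):
    raise ValueError("a row is assigned more than once")
  # Maximum cardinality: exactly min(n, m) entries of X are set.
  if sum(map(sum, x)) != min(n, m):
    raise ValueError("the assignment does not contain min(n, m) pairs")
  return sum(cost[i][j] * x[i][j] for i in range(n) for j in range(m))


cost = [[8, 4, 7], [5, 2, 3], [9, 6, 7], [9, 4, 8]]
print(assignment_cost(cost, [0, 1, 3], [0, 2, 1]))  # 15
```

Passing only two pairs for this 4 × 3 matrix raises `ValueError`, because a maximum-cardinality assignment must use min(4, 3) = 3 pairs.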

For an n × n cost matrix, the Hungarian algorithm solves this problem in O(n³) time.
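
To show where the cubic bound comes from, here is a compact pure-Python sketch of the Hungarian method in its shortest-augmenting-path form with row and column potentials. This is a swapped-in textbook formulation, not a description of how Optax implements `hungarian_algorithm` (which operates on JAX arrays); the name `hungarian_sketch` is hypothetical, and the code assumes a list-of-lists cost matrix with finite entries.

```python
import math
from typing import List, Sequence, Tuple


def hungarian_sketch(
    cost: Sequence[Sequence[float]],
) -> Tuple[List[int], List[int]]:
  """Minimum-cost assignment via the Hungarian method with dual potentials.

  Hypothetical illustration, not Optax code. Returns (rows, cols) sorted by
  row index, pairing exactly min(n, m) rows and columns.
  """
  n = len(cost)
  m = len(cost[0]) if n else 0
  if n > m:
    # Solve the transposed problem so that rows are the smaller side.
    t_rows, t_cols = hungarian_sketch([list(col) for col in zip(*cost)])
    pairs = sorted(zip(t_cols, t_rows))
    return [i for i, _ in pairs], [j for _, j in pairs]

  inf = math.inf
  # 1-indexed internals; column 0 is a virtual column that starts each search.
  u = [0.0] * (n + 1)    # row potentials
  v = [0.0] * (m + 1)    # column potentials
  match = [0] * (m + 1)  # match[j] = row (1-indexed) assigned to column j
  way = [0] * (m + 1)    # previous column on the augmenting path

  for i in range(1, n + 1):
    match[0] = i
    j0 = 0
    min_slack = [inf] * (m + 1)
    used = [False] * (m + 1)
    # Dijkstra-like search for the cheapest augmenting path from row i.
    while True:
      used[j0] = True
      i0 = match[j0]
      delta, j1 = inf, 0
      for j in range(1, m + 1):
        if not used[j]:
          reduced = cost[i0 - 1][j - 1] - u[i0] - v[j]
          if reduced < min_slack[j]:
            min_slack[j], way[j] = reduced, j0
          if min_slack[j] < delta:
            delta, j1 = min_slack[j], j
      # Shift potentials so that at least one new edge becomes tight.
      for j in range(m + 1):
        if used[j]:
          u[match[j]] += delta
          v[j] -= delta
        else:
          min_slack[j] -= delta
      j0 = j1
      if match[j0] == 0:  # reached a free column
        break
    # Augment: shift matches along the path back to the virtual column.
    while j0:
      j1 = way[j0]
      match[j0] = match[j1]
      j0 = j1

  pairs = sorted((match[j] - 1, j - 1) for j in range(1, m + 1) if match[j])
  return [i for i, _ in pairs], [j for _, j in pairs]


print(hungarian_sketch([[8, 4, 7], [5, 2, 3], [9, 6, 7], [9, 4, 8]]))
# ([0, 1, 3], [0, 2, 1])
```

Each of the n outer iterations adds one pair to the matching. Its search marks one new column per step and stops at the first free column, so it takes at most i steps of O(m) work; summed over all rows this gives O(n²m) for n ≤ m, which is O(n³) when the matrix is square.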

## Import

```python
import optax.assignment
# or
from optax.assignment import hungarian_algorithm, base_hungarian_algorithm
```