Python implementations of metric learning algorithms
npx @tessl/cli install tessl/pypi-metric-learn@0.7.00
# Metric-Learn

Python implementations of metric learning algorithms that are fully compatible with scikit-learn's API. Metric-learn provides efficient implementations of several popular supervised and weakly-supervised metric learning algorithms as part of the scikit-learn-contrib ecosystem.

## Package Information

- **Package Name**: metric-learn
- **Language**: Python
- **Installation**: `pip install metric-learn`
- **Dependencies**: Python 3.6+, numpy>=1.11.0, scipy>=0.17.0, scikit-learn>=0.21.3

## Core Imports

```python
import metric_learn
```

Common import pattern for specific algorithms:

```python
from metric_learn import LMNN, NCA, ITML, LSML
```

Import utility classes:

```python
from metric_learn import Constraints
```

## Basic Usage

```python
import numpy as np
from metric_learn import LMNN
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Load sample data
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Learn a metric with LMNN
lmnn = LMNN(n_neighbors=3, learn_rate=1e-6)
lmnn.fit(X_train, y_train)

# Transform data to the learned metric space
X_train_transformed = lmnn.transform(X_train)
X_test_transformed = lmnn.transform(X_test)

# Get the learned Mahalanobis matrix
mahalanobis_matrix = lmnn.get_mahalanobis_matrix()

# Compute distances between pairs of test points.
# Each row holds the indices of the two points in one pair; indexing with a
# NumPy array yields an array of shape (n_pairs, 2, n_features).
pair_indices = np.array([[0, 1], [2, 3]])
distances = lmnn.pair_distance(X_test[pair_indices])
```

## Architecture

Metric-learn follows a hierarchical class structure based on scikit-learn patterns:

- **BaseMetricLearner**: Abstract base class defining the core metric learning interface
- **MahalanobisMixin**: Mixin for algorithms that learn Mahalanobis distance metrics
- **Classifier Mixins**: Support for pair/triplet/quadruplet classification tasks
- **Algorithm Classes**: Concrete implementations of specific metric learning algorithms

All algorithms implement the scikit-learn API with `fit()`, `transform()`, and specialized methods for computing distances and similarities between data points.

## Capabilities

### Supervised Metric Learning Algorithms

Core supervised algorithms that learn from labeled training data to optimize distance metrics for classification tasks.

```python { .api }
class LMNN(MahalanobisMixin, TransformerMixin):
    def __init__(self, init='auto', n_neighbors=3, min_iter=50, max_iter=1000, learn_rate=1e-7, regularization=0.5, convergence_tol=0.001, verbose=False, preprocessor=None, n_components=None, random_state=None): ...
    def fit(self, X, y): ...

class NCA(MahalanobisMixin, TransformerMixin):
    def __init__(self, init='auto', n_components=None, max_iter=100, tol=None, verbose=False, preprocessor=None, random_state=None): ...
    def fit(self, X, y): ...

class LFDA(MahalanobisMixin, TransformerMixin):
    def __init__(self, n_components=None, k=None, embedding_type='weighted', preprocessor=None): ...
    def fit(self, X, y): ...
```

[Supervised Algorithms](./supervised-algorithms.md)

### Weakly-Supervised Learning Algorithms

Algorithms that learn from constraints (pairs, triplets, quadruplets, or chunks) rather than explicit class labels.

```python { .api }
class ITML(MahalanobisMixin, TransformerMixin):
    def __init__(self, gamma=1.0, max_iter=1000, tol=1e-3, prior='identity', verbose=False, preprocessor=None, random_state=None): ...
    def fit(self, pairs, y): ...

class LSML(MahalanobisMixin, TransformerMixin):
    def __init__(self, tol=1e-3, max_iter=1000, verbose=False, preprocessor=None, random_state=None): ...
    def fit(self, quadruplets, weights=None): ...

class SDML(MahalanobisMixin, TransformerMixin):
    def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True, preprocessor=None, random_state=None): ...
    def fit(self, pairs, y): ...

class RCA(MahalanobisMixin, TransformerMixin):
    def __init__(self, n_components=None, preprocessor=None): ...
    def fit(self, X, chunks): ...

class SCML(MahalanobisMixin, TransformerMixin):
    def __init__(self, beta=1e-5, basis='triplet_diffs', n_basis=None, gamma=5e-3, max_iter=10000, output_iter=500, batch_size=10, verbose=False, preprocessor=None, random_state=None): ...
    def fit(self, triplets): ...
```

[Weakly-Supervised Algorithms](./weakly-supervised-algorithms.md)

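The weakly-supervised learners consume tuples of points instead of one label per point. The NumPy-only sketch below shows, assuming tuples are passed as 3D arrays, what pairs, triplets, and quadruplets might look like for a toy dataset. The index arrays and the label convention (+1 for similar pairs, -1 for dissimilar pairs) are assumptions made for illustration, so check the linked page for what each estimator expects.

```python
import numpy as np

rng = np.random.default_rng(42)
X = rng.normal(size=(6, 4))  # 6 toy points with 4 features each

# Pairs: each row indexes two points; y marks similar (+1) or dissimilar (-1).
pair_idx = np.array([[0, 1], [2, 5], [3, 4]])
pairs = X[pair_idx]              # shape (3, 2, 4)
y_pairs = np.array([1, -1, 1])

# Triplets: (anchor, positive, negative); the anchor should end up closer
# to the positive than to the negative.
triplet_idx = np.array([[0, 1, 5], [3, 4, 2]])
triplets = X[triplet_idx]        # shape (2, 3, 4)

# Quadruplets: (a, b, c, d), meaning a and b should be closer than c and d.
quad_idx = np.array([[0, 1, 2, 5], [3, 4, 0, 5]])
quadruplets = X[quad_idx]        # shape (2, 4, 4)
```

Indexing the feature matrix with an integer array of shape `(n_tuples, t)` is a compact way to build the `(n_tuples, t, n_features)` arrays used above.
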
### Supervised Variants of Weakly-Supervised Algorithms

Supervised versions that automatically generate constraints from class labels, combining the convenience of supervised learning with constraint-based optimization.

```python { .api }
class ITML_Supervised(MahalanobisMixin, TransformerMixin):
    def __init__(self, gamma=1.0, max_iter=1000, tol=1e-3, n_constraints=None, prior='identity', verbose=False, preprocessor=None, random_state=None): ...
    def fit(self, X, y): ...

class LSML_Supervised(MahalanobisMixin, TransformerMixin):
    def __init__(self, tol=1e-3, max_iter=1000, prior='identity', n_constraints=None, verbose=False, preprocessor=None, random_state=None): ...
    def fit(self, X, y): ...

class SDML_Supervised(MahalanobisMixin, TransformerMixin):
    def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True, n_constraints=None, preprocessor=None, random_state=None): ...
    def fit(self, X, y): ...

class RCA_Supervised(MahalanobisMixin, TransformerMixin):
    def __init__(self, n_components=None, n_chunks=100, chunk_size=2, preprocessor=None, random_state=None): ...
    def fit(self, X, y): ...

class MMC_Supervised(MahalanobisMixin, TransformerMixin):
    def __init__(self, init='identity', max_iter=100, max_proj=10000, convergence_threshold=1e-3, n_constraints=None, diagonal=False, diagonal_c=1.0, verbose=False, preprocessor=None, random_state=None): ...
    def fit(self, X, y): ...

class SCML_Supervised(MahalanobisMixin, TransformerMixin):
    def __init__(self, beta=1e-5, basis='lda', n_basis=None, gamma=5e-3, max_iter=10000, output_iter=500, batch_size=10, verbose=False, preprocessor=None, random_state=None): ...
    def fit(self, X, y): ...
```

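As a rough illustration of what generating constraints from class labels can mean, the sketch below samples similar and dissimilar index pairs with NumPy. `sample_pairs_from_labels` is a hypothetical helper written for this example; it is not metric-learn's implementation, which may sample differently.

```python
import numpy as np


def sample_pairs_from_labels(y, n_constraints, random_state=None):
    """Sample n_constraints similar and n_constraints dissimilar index pairs.

    Returns (pair_indices, pair_labels), where pair_labels is +1 for pairs
    with equal class labels and -1 otherwise. Assumes y contains at least two
    classes and at least one class with two or more samples; otherwise the
    sampling loop would never finish.
    """
    rng = np.random.default_rng(random_state)
    y = np.asarray(y)
    n = len(y)
    positives, negatives = [], []
    # Rejection sampling: draw random index pairs until both quotas are met.
    # Duplicate pairs are allowed in this simplified version.
    while len(positives) < n_constraints or len(negatives) < n_constraints:
        i, j = rng.choice(n, size=2, replace=False)
        if y[i] == y[j] and len(positives) < n_constraints:
            positives.append((i, j))
        elif y[i] != y[j] and len(negatives) < n_constraints:
            negatives.append((i, j))
    pair_indices = np.array(positives + negatives)
    pair_labels = np.array([1] * len(positives) + [-1] * len(negatives))
    return pair_indices, pair_labels


y = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2])
pair_indices, pair_labels = sample_pairs_from_labels(y, n_constraints=4, random_state=0)
```

The `n_constraints` argument here mirrors the parameter of the same name in the signatures above only loosely: in this sketch it is the number of pairs drawn per kind.
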
### Specialized Algorithms

Algorithms designed for specific use cases like clustering and kernel regression.

```python { .api }
class MLKR(MahalanobisMixin, TransformerMixin):
    def __init__(self, init='auto', alpha=0.1, max_iter=1000, preprocessor=None, random_state=None): ...
    def fit(self, X, y): ...

class MMC(MahalanobisMixin, TransformerMixin):
    def __init__(self, init='identity', max_iter=100, max_proj=10000, convergence_threshold=1e-3, diagonal=False, diagonal_c=1.0, verbose=False, preprocessor=None, random_state=None): ...
    def fit(self, pairs, y): ...

class Covariance(MahalanobisMixin, TransformerMixin):
    def __init__(self, preprocessor=None): ...
    def fit(self, X, y=None): ...
```

[Specialized Algorithms](./specialized-algorithms.md)

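`Covariance` takes no labels (`fit(X, y=None)`), which suggests it acts as an unsupervised baseline. A common reading, assumed here, is that the metric is the inverse of the data covariance matrix. The NumPy sketch below builds such a matrix and checks that the implied linear map whitens the data; it illustrates the idea only and is not taken from the library's source.

```python
import numpy as np

rng = np.random.default_rng(0)

# Correlated toy data: 200 samples, 3 features.
A = np.array([[2.0, 0.0, 0.0],
              [1.0, 1.0, 0.0],
              [0.5, -1.0, 0.5]])
X = rng.normal(size=(200, 3)) @ A.T

cov = np.cov(X, rowvar=False)
M = np.linalg.pinv(cov)  # assumed baseline Mahalanobis matrix

# Factor M = L.T @ L via an eigendecomposition (M is symmetric PSD).
eigvals, eigvecs = np.linalg.eigh(M)
L = (eigvecs * np.sqrt(np.clip(eigvals, 0.0, None))).T

# Mapping the data with L should give approximately identity covariance.
Z = X @ L.T
cov_Z = np.cov(Z, rowvar=False)
```

With this choice of `M`, distances are measured in units of the data's own spread along each direction, which is the classical Mahalanobis distance.
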
### Utility Classes and Constraints

Helper classes for generating constraints and working with metric learning data.

```python { .api }
class Constraints:
    def __init__(self, partial_labels): ...
    def positive_negative_pairs(self, n_constraints, same_length=False, random_state=None): ...
    def chunks(self, n_chunks=100, chunk_size=2, random_state=None): ...
    def generate_knntriplets(self, X, k_genuine, k_impostor): ...
```

[Utilities](./utilities.md)

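The `generate_knntriplets(X, k_genuine, k_impostor)` signature suggests triplets built from nearest neighbors. One plausible construction, sketched below with NumPy, combines each point with its `k_genuine` nearest same-class neighbors and its `k_impostor` nearest other-class neighbors. `knn_triplets` is a hypothetical stand-in; the library method may differ in details such as tie handling or classes with too few members.

```python
import numpy as np


def knn_triplets(X, y, k_genuine, k_impostor):
    """Return index triplets (anchor, genuine, impostor) as an (n, 3) array."""
    X, y = np.asarray(X, dtype=float), np.asarray(y)
    # Pairwise squared Euclidean distances between all points.
    diffs = X[:, None, :] - X[None, :, :]
    dist = np.einsum("ijk,ijk->ij", diffs, diffs)
    indices = np.arange(len(X))
    triplets = []
    for i in indices:
        same = np.flatnonzero((y == y[i]) & (indices != i))
        other = np.flatnonzero(y != y[i])
        # Nearest neighbors of point i within each group.
        genuine = same[np.argsort(dist[i, same])[:k_genuine]]
        impostors = other[np.argsort(dist[i, other])[:k_impostor]]
        triplets.extend((i, g, m) for g in genuine for m in impostors)
    return np.array(triplets)


X = np.array([[0.0], [0.1], [0.2], [5.0], [5.1], [5.2]])
y = np.array([0, 0, 0, 1, 1, 1])
triplets = knn_triplets(X, y, k_genuine=1, k_impostor=2)
```

Each anchor contributes up to `k_genuine * k_impostor` triplets, so the total grows quickly with either parameter.
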
### Base Classes and Mixins

Core abstract classes and mixins that define the metric learning API.

```python { .api }
class BaseMetricLearner(BaseEstimator):
    def __init__(self, preprocessor=None): ...
    def pair_score(self, pairs): ...
    def pair_distance(self, pairs): ...
    def get_metric(self): ...

class MahalanobisMixin(BaseMetricLearner):
    def transform(self, X): ...
    def get_mahalanobis_matrix(self): ...
```

[Base Classes](./base-classes.md)