# Base Classes and Mixins

Core abstract classes and mixins that define the metric learning API. Understanding these classes is essential for using metric-learn algorithms effectively and for implementing custom metric learning algorithms.

## Capabilities

### BaseMetricLearner

Abstract base class that defines the core interface for all metric learning algorithms. All metric learning algorithms in the package inherit from this class.

```python { .api }
class BaseMetricLearner(BaseEstimator):
    def __init__(self, preprocessor=None):
        """
        Base constructor for metric learners.

        Parameters:
        - preprocessor: array-like or callable, preprocessor to get data from indices
        """

    def pair_score(self, pairs):
        """
        Compute similarity score between pairs.

        Parameters:
        - pairs: array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2),
                 3D array of pairs or 2D array of indices

        Returns:
        - scores: ndarray, shape=(n_pairs,), similarity scores (higher = more similar)
        """

    def pair_distance(self, pairs):
        """
        Compute distance between pairs.

        Parameters:
        - pairs: array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2),
                 3D array of pairs or 2D array of indices

        Returns:
        - distances: ndarray, shape=(n_pairs,), distances between pairs
        """

    def get_metric(self):
        """
        Get metric function for use with scikit-learn algorithms.

        Returns:
        - metric: callable, function that computes distance between two 1D arrays
        """

    def score_pairs(self, pairs):
        """
        Legacy method for computing scores between pairs.

        .. deprecated:: 0.7.0
            Use pair_distance or pair_score instead.
        """
```
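
The two accepted shapes for `pairs` can be illustrated without the library. The following NumPy sketch assumes the preprocessor is a plain array and mimics how a 2D array of indices could be expanded into a 3D array of points. `resolve_pairs` is a hypothetical helper written for this illustration, not a metric-learn function.

```python
import numpy as np

def resolve_pairs(pairs, preprocessor=None):
    """Return pairs as a 3D array of shape (n_pairs, 2, n_features).

    Hypothetical helper for illustration: a 2D integer array is treated as
    indices into `preprocessor`; a 3D array is returned unchanged.
    """
    pairs = np.asarray(pairs)
    if pairs.ndim == 3:
        return pairs
    if pairs.ndim == 2 and preprocessor is not None:
        # Fancy indexing turns (n_pairs, 2) indices into (n_pairs, 2, n_features)
        return np.asarray(preprocessor)[pairs]
    raise ValueError("2D index pairs need a preprocessor to look points up")

X = np.arange(12, dtype=float).reshape(4, 3)   # 4 points, 3 features
index_pairs = np.array([[0, 1], [2, 3]])
point_pairs = resolve_pairs(index_pairs, preprocessor=X)
print(point_pairs.shape)  # (2, 2, 3)
```

Passing index pairs keeps constraint arrays small when the same points appear in many pairs, at the cost of requiring the learner to hold a reference to the data.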

### MahalanobisMixin

Mixin class for algorithms that learn Mahalanobis distance metrics. Inherits from BaseMetricLearner and adds functionality specific to Mahalanobis metrics.

```python { .api }
class MahalanobisMixin(BaseMetricLearner):
    def transform(self, X):
        """
        Apply the learned linear transformation to data.

        Parameters:
        - X: array-like, shape=(n_samples, n_features), data to transform

        Returns:
        - X_transformed: ndarray, shape=(n_samples, n_components), transformed data
        """

    def get_mahalanobis_matrix(self):
        """
        Get the learned Mahalanobis matrix.

        Returns:
        - M: ndarray, shape=(n_features, n_features), Mahalanobis matrix
        """

    def pair_distance(self, pairs):
        """
        Compute Mahalanobis distance between pairs.

        Parameters:
        - pairs: array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2)

        Returns:
        - distances: ndarray, shape=(n_pairs,), Mahalanobis distances
        """

    def pair_score(self, pairs):
        """
        Compute similarity score (negative distance) between pairs.

        Parameters:
        - pairs: array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2)

        Returns:
        - scores: ndarray, shape=(n_pairs,), similarity scores
        """
```

**Attributes:**

```python { .api }
components_: ndarray, shape=(n_components, n_features)
    """The learned linear transformation matrix L such that M = L.T @ L"""
```
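
The relation `M = L.T @ L` means a Mahalanobis distance under `M` is the Euclidean distance after mapping points through `L`. The NumPy sketch below checks this on random data; the matrix `L` here is a random stand-in for `components_`, not the output of a fitted learner.

```python
import numpy as np

rng = np.random.default_rng(0)
L = rng.normal(size=(2, 4))      # stand-in for components_: (n_components, n_features)
M = L.T @ L                      # Mahalanobis matrix, shape (n_features, n_features)

x, y = rng.normal(size=4), rng.normal(size=4)
diff = x - y

# Distance written with M: sqrt((x - y)^T M (x - y))
d_mahalanobis = np.sqrt(diff @ M @ diff)

# Same distance as a Euclidean norm after the linear map x -> L x
d_transformed = np.linalg.norm(L @ x - L @ y)

print(np.isclose(d_mahalanobis, d_transformed))  # True
```

When `n_components < n_features`, as in this sketch, `M` is rank-deficient, so distinct points can end up at distance zero from each other.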

### Classification Mixins

Mixins that add classification capabilities for constraint-based learning scenarios.

#### PairsClassifierMixin

Adds binary classification capabilities for pair constraints.

```python { .api }
class _PairsClassifierMixin:
    def predict(self, pairs):
        """
        Predict similarity/dissimilarity for pairs.

        Parameters:
        - pairs: array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2)

        Returns:
        - predictions: ndarray, shape=(n_pairs,), predicted labels (+1 or -1)
        """

    def decision_function(self, pairs):
        """
        Compute decision function values for pairs.

        Parameters:
        - pairs: array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2)

        Returns:
        - decision_scores: ndarray, shape=(n_pairs,), decision function values
        """

    def score(self, pairs, y):
        """
        Compute accuracy score for pair predictions.

        Parameters:
        - pairs: array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2)
        - y: array-like, shape=(n_pairs,), true labels

        Returns:
        - accuracy: float, classification accuracy
        """

    def set_threshold(self, threshold):
        """
        Set classification threshold.

        Parameters:
        - threshold: float, decision threshold for classification
        """

    def calibrate_threshold(self, pairs_valid, y_valid, strategy='accuracy'):
        """
        Calibrate classification threshold using validation data.

        Parameters:
        - pairs_valid: array-like, validation pairs
        - y_valid: array-like, validation labels
        - strategy: str, calibration strategy ('accuracy', 'f_beta', 'max_tpr', or 'max_tnr')
        """
```
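
Threshold-based pair classification can be sketched in a few lines of NumPy. The sketch assumes that a pair is labelled similar (+1) when its distance is at most the threshold, and that calibration for accuracy simply tries each observed distance as a candidate. `predict_pairs` and `calibrate_by_accuracy` are hypothetical helpers; the library's own calibration may differ in detail.

```python
import numpy as np

def predict_pairs(distances, threshold):
    """Label a pair +1 (similar) when its distance is at most `threshold`, else -1."""
    return np.where(np.asarray(distances) <= threshold, 1, -1)

def calibrate_by_accuracy(distances, y_true):
    """Pick the candidate threshold that maximizes accuracy on validation pairs."""
    distances = np.asarray(distances, dtype=float)
    y_true = np.asarray(y_true)
    candidates = np.unique(distances)   # sorted observed distances
    accuracies = [np.mean(predict_pairs(distances, t) == y_true) for t in candidates]
    return candidates[int(np.argmax(accuracies))]

distances = np.array([0.2, 0.5, 1.5, 2.0])   # toy validation distances
labels = np.array([1, 1, -1, -1])            # +1 similar, -1 dissimilar
threshold = calibrate_by_accuracy(distances, labels)
predictions = predict_pairs(distances, threshold)
print(threshold, predictions)  # 0.5 [ 1  1 -1 -1]
```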

#### TripletsClassifierMixin

Adds classification capabilities for triplet constraints.

```python { .api }
class _TripletsClassifierMixin:
    def predict(self, triplets):
        """
        Predict triplet constraint satisfaction.

        Parameters:
        - triplets: array-like, shape=(n_triplets, 3, n_features) or (n_triplets, 3)

        Returns:
        - predictions: ndarray, shape=(n_triplets,), predicted constraint satisfaction
        """

    def decision_function(self, triplets):
        """
        Compute decision function for triplets.

        Parameters:
        - triplets: array-like, shape=(n_triplets, 3, n_features) or (n_triplets, 3)

        Returns:
        - decision_scores: ndarray, shape=(n_triplets,), decision scores
        """

    def score(self, triplets, y):
        """
        Compute accuracy for triplet predictions.

        Parameters:
        - triplets: array-like, triplet constraints
        - y: array-like, true constraint labels

        Returns:
        - accuracy: float, classification accuracy
        """
```
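
A triplet decision value can be illustrated as a difference of two distances. The sketch below assumes each triplet is ordered as (anchor, point expected to be near, point expected to be far) and that a positive value means the ordering is respected. `triplet_decision` is a hypothetical helper, and the identity matrix stands in for a learned transformation.

```python
import numpy as np

def triplet_decision(triplets, L):
    """Return d(anchor, far) - d(anchor, near) for each triplet.

    `triplets` has shape (n_triplets, 3, n_features), ordered as
    (anchor, point expected to be near, point expected to be far).
    `L` plays the role of a learned linear map (identity = Euclidean).
    """
    embedded = np.asarray(triplets, dtype=float) @ L.T
    d_near = np.linalg.norm(embedded[:, 0] - embedded[:, 1], axis=1)
    d_far = np.linalg.norm(embedded[:, 0] - embedded[:, 2], axis=1)
    return d_far - d_near

triplets = np.array([
    [[0.0, 0.0], [1.0, 0.0], [3.0, 0.0]],   # ordering respected
    [[0.0, 0.0], [4.0, 0.0], [1.0, 0.0]],   # ordering violated
])
scores = triplet_decision(triplets, np.eye(2))
predictions = np.where(scores > 0, 1, -1)
accuracy = np.mean(predictions == 1)        # fraction of respected triplets
print(scores, predictions, accuracy)        # [ 2. -3.] [ 1 -1] 0.5
```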

#### QuadrupletsClassifierMixin

Adds classification capabilities for quadruplet constraints.

```python { .api }
class _QuadrupletsClassifierMixin:
    def predict(self, quadruplets):
        """
        Predict quadruplet constraint satisfaction.

        Parameters:
        - quadruplets: array-like, shape=(n_quadruplets, 4, n_features) or (n_quadruplets, 4)

        Returns:
        - predictions: ndarray, shape=(n_quadruplets,), predicted constraint satisfaction
        """

    def decision_function(self, quadruplets):
        """
        Compute decision function for quadruplets.

        Parameters:
        - quadruplets: array-like, shape=(n_quadruplets, 4, n_features) or (n_quadruplets, 4)

        Returns:
        - decision_scores: ndarray, shape=(n_quadruplets,), decision scores
        """

    def score(self, quadruplets, y):
        """
        Compute accuracy for quadruplet predictions.

        Parameters:
        - quadruplets: array-like, quadruplet constraints
        - y: array-like, true constraint labels

        Returns:
        - accuracy: float, classification accuracy
        """
```
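
Quadruplet constraints compare two pairs rather than two points against an anchor. The sketch below assumes a quadruplet (a, b, c, d) states that a and b should be closer than c and d, and uses the plain Euclidean distance as a stand-in for a learned metric. `quadruplet_decision` is a hypothetical helper.

```python
import numpy as np

def quadruplet_decision(quadruplets):
    """Return d(c, d) - d(a, b) for quadruplets shaped (n, 4, n_features)."""
    q = np.asarray(quadruplets, dtype=float)
    d_first = np.linalg.norm(q[:, 0] - q[:, 1], axis=1)    # d(a, b)
    d_second = np.linalg.norm(q[:, 2] - q[:, 3], axis=1)   # d(c, d)
    return d_second - d_first

quadruplets = np.array([
    [[0.0], [1.0], [0.0], [5.0]],   # first pair closer: constraint holds
    [[0.0], [6.0], [0.0], [2.0]],   # first pair farther: constraint violated
])
scores = quadruplet_decision(quadruplets)
predictions = np.where(scores > 0, 1, -1)
print(scores, predictions)  # [ 4. -4.] [ 1 -1]
```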

## Understanding the Class Hierarchy

The metric-learn package uses a clean inheritance hierarchy:

```text
# Base class for all metric learners
BaseMetricLearner (abstract)
│
└── MahalanobisMixin (mixin for Mahalanobis metrics)
    │
    ├── LMNN, NCA, LFDA (supervised algorithms)
    ├── MLKR (supervised regression algorithm)
    ├── ITML, LSML, SDML, RCA, SCML, MMC (weakly-supervised algorithms)
    └── Covariance (baseline algorithm)

# Classification mixins can be combined with base classes
_PairsClassifierMixin
_TripletsClassifierMixin
_QuadrupletsClassifierMixin
```
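
How a classification mixin combines with a base class can be shown with a toy example. The classes below are hypothetical stand-ins, not metric-learn classes: the mixin supplies `predict` and relies on the host class for `pair_score`.

```python
class Base:
    """Stand-in for a base learner that knows how to score pairs."""

    def pair_score(self, pairs):
        # Toy score: negative absolute difference of two scalars
        return [-abs(a - b) for a, b in pairs]


class ThresholdClassifierMixin:
    """Stand-in for a classification mixin; relies on pair_score from the host class."""

    threshold = -1.0

    def predict(self, pairs):
        return [1 if s >= self.threshold else -1 for s in self.pair_score(pairs)]


class ToyPairsLearner(ThresholdClassifierMixin, Base):
    """Concrete class obtained by combining the mixin with the base."""


learner = ToyPairsLearner()
print(learner.predict([(0.0, 0.5), (0.0, 3.0)]))            # [1, -1]
print([cls.__name__ for cls in ToyPairsLearner.__mro__])
```

Listing the mixin first places it ahead of the base class in the method resolution order, so its methods take precedence if both define the same name.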

## Working with Base Classes

### Understanding the Metric Interface

All algorithms provide a consistent interface for computing distances and similarities:

```python
from metric_learn import LMNN, ITML, Constraints
from sklearn.datasets import make_classification
import numpy as np

# Generate sample data; n_informative=3 is needed so that 3 classes with the
# default 2 clusters per class are feasible
X, y = make_classification(n_samples=100, n_features=5, n_informative=3,
                           n_classes=3, random_state=42)

# Train a supervised algorithm
lmnn = LMNN(n_neighbors=3)
lmnn.fit(X, y)

# Generate pair constraints for ITML. positive_negative_pairs returns four
# index arrays: (a, b) are similar pairs and (c, d) are dissimilar pairs.
constraints = Constraints(y)
a, b, c, d = constraints.positive_negative_pairs(n_constraints=100)
pos_pairs = np.column_stack([a, b])
neg_pairs = np.column_stack([c, d])
pairs = np.vstack([pos_pairs, neg_pairs])
pair_labels = np.hstack([np.ones(len(pos_pairs)), -np.ones(len(neg_pairs))])
itml = ITML(preprocessor=X)
itml.fit(pairs, pair_labels)

# Both algorithms provide the same interface. The test pairs are a 3D array of
# points because LMNN was built without a preprocessor and cannot resolve indices.
test_pairs = X[np.array([[0, 1], [2, 10], [5, 20]])]

for name, algo in [('LMNN', lmnn), ('ITML', itml)]:
    # Compute distances
    distances = algo.pair_distance(test_pairs)

    # Compute similarity scores
    scores = algo.pair_score(test_pairs)

    # Get metric function for scikit-learn
    metric_func = algo.get_metric()

    print(f"{name}: distances={distances[:2]}, scores={scores[:2]}")
```
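
The callable returned by `get_metric()` takes two 1D arrays and returns a distance. The NumPy sketch below builds a function of that shape from a fixed linear map and uses it for a brute-force nearest-neighbor lookup. `make_metric` is a hypothetical helper, and the diagonal matrix stands in for a learned transformation.

```python
import numpy as np

def make_metric(L):
    """Build a two-argument distance function from a linear map L (hypothetical helper)."""
    L = np.asarray(L, dtype=float)

    def metric(u, v):
        # Euclidean distance after mapping the difference of two 1D points through L
        return float(np.linalg.norm(L @ (np.asarray(u) - np.asarray(v))))

    return metric

metric = make_metric(np.diag([2.0, 1.0]))   # first feature counts double
X = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.5]])
query = np.array([0.1, 0.0])

distances = np.array([metric(query, x) for x in X])
nearest = int(np.argmin(distances))
print(distances, nearest)
```

A callable of this shape is what estimators such as scikit-learn's `KNeighborsClassifier` accept through their `metric` parameter.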

### Using the Transform Interface

Algorithms that inherit from MahalanobisMixin provide data transformation:

```python
from metric_learn import LMNN, NCA
from sklearn.datasets import load_iris
import numpy as np

X, y = load_iris(return_X_y=True)

# Train algorithms
lmnn = LMNN(n_neighbors=3)
lmnn.fit(X, y)

nca = NCA(max_iter=100)
nca.fit(X, y)

# All Mahalanobis-based algorithms support transform
for name, algo in [('LMNN', lmnn), ('NCA', nca)]:
    # Transform data to learned metric space
    X_transformed = algo.transform(X)

    # Get the learned Mahalanobis matrix
    M = algo.get_mahalanobis_matrix()

    # Get linear transformation components
    L = algo.components_

    print(f"{name}: transformed shape={X_transformed.shape}, M shape={M.shape}")
    print(f"  Verification: M = L.T @ L = {np.allclose(M, L.T @ L)}")
```

### Custom Metric Learning Algorithm

Understanding the base classes enables implementing custom algorithms:

```python
from metric_learn.base_metric import MahalanobisMixin
from sklearn.base import TransformerMixin
from sklearn.datasets import load_iris
import numpy as np

class CustomMetricLearner(MahalanobisMixin, TransformerMixin):
    """Example custom metric learning algorithm."""

    def __init__(self, alpha=1.0, preprocessor=None):
        super().__init__(preprocessor=preprocessor)
        self.alpha = alpha

    def fit(self, X, y):
        """Implement your metric learning algorithm here."""
        # Example: simple covariance-based metric with regularization
        X = self._prepare_inputs(X, y, type_of_inputs='classic')[0]

        # Compute regularized covariance
        cov = np.cov(X.T) + self.alpha * np.eye(X.shape[1])

        # Eigendecompose so that components_ = L satisfies L.T @ L = inv(cov),
        # the classical Mahalanobis matrix; L has shape (n_components, n_features)
        eigenvals, eigenvecs = np.linalg.eigh(cov)
        inv_sqrt = 1.0 / np.sqrt(np.maximum(eigenvals, 1e-8))
        self.components_ = np.diag(inv_sqrt) @ eigenvecs.T

        return self

# Usage
X, y = load_iris(return_X_y=True)
custom_learner = CustomMetricLearner(alpha=0.1)
custom_learner.fit(X, y)
X_transformed = custom_learner.transform(X)
print("Custom algorithm trained successfully!")
```