or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

base-classes.md index.md specialized-algorithms.md supervised-algorithms.md utilities.md weakly-supervised-algorithms.md

docs/base-classes.md

0

# Base Classes and Mixins

1

2

Core abstract classes and mixins that define the metric learning API. Understanding these classes is essential for using metric-learn algorithms effectively and for implementing custom metric learning algorithms.

3

4

## Capabilities

5

6

### BaseMetricLearner

7

8

Abstract base class that defines the core interface for all metric learning algorithms. All metric learning algorithms in the package inherit from this class.

9

10

```python { .api }

11

class BaseMetricLearner(BaseEstimator):

12

def __init__(self, preprocessor=None):

13

"""

14

Base constructor for metric learners.

15

16

Parameters:

17

- preprocessor: array-like or callable, preprocessor to get data from indices

18

"""

19

20

def pair_score(self, pairs):

21

"""

22

Compute similarity score between pairs.

23

24

Parameters:

25

- pairs: array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2),

26

3D array of pairs or 2D array of indices

27

28

Returns:

29

- scores: ndarray, shape=(n_pairs,), similarity scores (higher = more similar)

30

"""

31

32

def pair_distance(self, pairs):

33

"""

34

Compute distance between pairs.

35

36

Parameters:

37

- pairs: array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2),

38

3D array of pairs or 2D array of indices

39

40

Returns:

41

- distances: ndarray, shape=(n_pairs,), distances between pairs

42

"""

43

44

def get_metric(self):

45

"""

46

Get metric function for use with scikit-learn algorithms.

47

48

Returns:

49

- metric: callable, function that computes distance between two 1D arrays

50

"""

51

52

def score_pairs(self, pairs):

53

"""

54

Legacy method for computing scores between pairs.

55

56

.. deprecated:: 0.7.0

57

Use pair_distance or pair_score instead.

58

"""

59

```

60

61

### MahalanobisMixin

62

63

Mixin class for algorithms that learn Mahalanobis distance metrics. Inherits from BaseMetricLearner and adds functionality specific to Mahalanobis metrics.

64

65

```python { .api }

66

class MahalanobisMixin(BaseMetricLearner):

67

def transform(self, X):

68

"""

69

Apply the learned linear transformation to data.

70

71

Parameters:

72

- X: array-like, shape=(n_samples, n_features), data to transform

73

74

Returns:

75

- X_transformed: ndarray, shape=(n_samples, n_components), transformed data

76

"""

77

78

def get_mahalanobis_matrix(self):

79

"""

80

Get the learned Mahalanobis matrix.

81

82

Returns:

83

- M: ndarray, shape=(n_features, n_features), Mahalanobis matrix

84

"""

85

86

def pair_distance(self, pairs):

87

"""

88

Compute Mahalanobis distance between pairs.

89

90

Parameters:

91

- pairs: array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2)

92

93

Returns:

94

- distances: ndarray, shape=(n_pairs,), Mahalanobis distances

95

"""

96

97

def pair_score(self, pairs):

98

"""

99

Compute similarity score (negative distance) between pairs.

100

101

Parameters:

102

- pairs: array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2)

103

104

Returns:

105

- scores: ndarray, shape=(n_pairs,), similarity scores

106

"""

107

```

108

109

**Attributes:**

110

111

```python { .api }

112

components_: ndarray, shape=(n_components, n_features)

113

"""The learned linear transformation matrix L such that M = L.T @ L"""

114

```

115

116

### Classification Mixins

117

118

Mixins that add classification capabilities for constraint-based learning scenarios.

119

120

#### PairsClassifierMixin

121

122

Adds binary classification capabilities for pair constraints.

123

124

```python { .api }

125

class _PairsClassifierMixin:

126

def predict(self, pairs):

127

"""

128

Predict similarity/dissimilarity for pairs.

129

130

Parameters:

131

- pairs: array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2)

132

133

Returns:

134

- predictions: ndarray, shape=(n_pairs,), predicted labels (+1 or -1)

135

"""

136

137

def decision_function(self, pairs):

138

"""

139

Compute decision function values for pairs.

140

141

Parameters:

142

- pairs: array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2)

143

144

Returns:

145

- decision_scores: ndarray, shape=(n_pairs,), decision function values

146

"""

147

148

def score(self, pairs, y):

149

"""

150

Compute accuracy score for pair predictions.

151

152

Parameters:

153

- pairs: array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2)

154

- y: array-like, shape=(n_pairs,), true labels

155

156

Returns:

157

- accuracy: float, classification accuracy

158

"""

159

160

def set_threshold(self, threshold):

161

"""

162

Set classification threshold.

163

164

Parameters:

165

- threshold: float, decision threshold for classification

166

"""

167

168

def calibrate_threshold(self, pairs_valid, y_valid, strategy='accuracy'):

169

"""

170

Calibrate classification threshold using validation data.

171

172

Parameters:

173

- pairs_valid: array-like, validation pairs

174

- y_valid: array-like, validation labels

175

- strategy: str, calibration strategy ('accuracy', 'f_beta', 'max_tpr' or 'max_tnr')

176

"""

177

```

178

179

#### TripletsClassifierMixin

180

181

Adds classification capabilities for triplet constraints.

182

183

```python { .api }

184

class _TripletsClassifierMixin:

185

def predict(self, triplets):

186

"""

187

Predict triplet constraint satisfaction.

188

189

Parameters:

190

- triplets: array-like, shape=(n_triplets, 3, n_features) or (n_triplets, 3)

191

192

Returns:

193

- predictions: ndarray, shape=(n_triplets,), predicted constraint satisfaction

194

"""

195

196

def decision_function(self, triplets):

197

"""

198

Compute decision function for triplets.

199

200

Parameters:

201

- triplets: array-like, shape=(n_triplets, 3, n_features) or (n_triplets, 3)

202

203

Returns:

204

- decision_scores: ndarray, shape=(n_triplets,), decision scores

205

"""

206

207

def score(self, triplets, y):

208

"""

209

Compute accuracy for triplet predictions.

210

211

Parameters:

212

- triplets: array-like, triplet constraints

213

- y: array-like, true constraint labels

214

215

Returns:

216

- accuracy: float, classification accuracy

217

"""

218

```

219

220

#### QuadrupletsClassifierMixin

221

222

Adds classification capabilities for quadruplet constraints.

223

224

```python { .api }

225

class _QuadrupletsClassifierMixin:

226

def predict(self, quadruplets):

227

"""

228

Predict quadruplet constraint satisfaction.

229

230

Parameters:

231

- quadruplets: array-like, shape=(n_quadruplets, 4, n_features) or (n_quadruplets, 4)

232

233

Returns:

234

- predictions: ndarray, shape=(n_quadruplets,), predicted constraint satisfaction

235

"""

236

237

def decision_function(self, quadruplets):

238

"""

239

Compute decision function for quadruplets.

240

241

Parameters:

242

- quadruplets: array-like, shape=(n_quadruplets, 4, n_features) or (n_quadruplets, 4)

243

244

Returns:

245

- decision_scores: ndarray, shape=(n_quadruplets,), decision scores

246

"""

247

248

def score(self, quadruplets, y):

249

"""

250

Compute accuracy for quadruplet predictions.

251

252

Parameters:

253

- quadruplets: array-like, quadruplet constraints

254

- y: array-like, true constraint labels

255

256

Returns:

257

- accuracy: float, classification accuracy

258

"""

259

```

260

261

## Understanding the Class Hierarchy

262

263

The metric-learn package uses a clean inheritance hierarchy:

264

265

```text

266

# Base class for all metric learners

267

BaseMetricLearner (abstract)

268

269

├── MahalanobisMixin (concrete mixin)

270

│ │

271

│ ├── LMNN, NCA, LFDA (supervised algorithms)

272

│ ├── ITML, LSML, SDML, RCA, SCML (weakly-supervised algorithms)

273

│ ├── MMC (clustering algorithm)

274

│ └── Covariance (baseline algorithm)

275

276

└── MLKR (regression algorithm; in metric-learn it also inherits from MahalanobisMixin)

277

278

# Classification mixins can be combined with base classes

279

_PairsClassifierMixin

280

_TripletsClassifierMixin

281

_QuadrupletsClassifierMixin

282

```

283

284

## Working with Base Classes

285

286

### Understanding the Metric Interface

287

288

All algorithms provide a consistent interface for computing distances and similarities:

289

290

```python

291

from metric_learn import LMNN, ITML

292

from sklearn.datasets import make_classification

293

import numpy as np

294

295

# Generate sample data

296

X, y = make_classification(n_samples=100, n_features=5, n_classes=3, random_state=42)

297

298

# Train different algorithms

299

lmnn = LMNN(n_neighbors=3)

300

lmnn.fit(X, y)

301

302

# Generate pairs and constraints for ITML

303

from metric_learn import Constraints

304

constraints = Constraints(y)

305

pos_pairs, neg_pairs = constraints.positive_negative_pairs(n_constraints=100)

306

pairs = np.vstack([pos_pairs, neg_pairs])

307

pair_labels = np.hstack([np.ones(len(pos_pairs)), -np.ones(len(neg_pairs))])

308

itml = ITML(preprocessor=X)

309

itml.fit(pairs, pair_labels)

310

311

# Both algorithms provide the same interface

312

test_pairs = [(0, 1), (2, 10), (5, 20)]

313

314

for name, algo in [('LMNN', lmnn), ('ITML', itml)]:

315

# Compute distances

316

distances = algo.pair_distance(test_pairs)

317

318

# Compute similarity scores

319

scores = algo.pair_score(test_pairs)

320

321

# Get metric function for scikit-learn

322

metric_func = algo.get_metric()

323

324

print(f"{name}: distances={distances[:2]}, scores={scores[:2]}")

325

```

326

327

### Using the Transform Interface

328

329

Algorithms that inherit from MahalanobisMixin provide data transformation:

330

331

```python

332

from metric_learn import LMNN, NCA, ITML

333

from sklearn.datasets import load_iris

334

335

X, y = load_iris(return_X_y=True)

336

337

# Train algorithms

338

lmnn = LMNN(n_neighbors=3)

339

lmnn.fit(X, y)

340

341

nca = NCA(max_iter=100)

342

nca.fit(X, y)

343

344

# All Mahalanobis-based algorithms support transform

345

for name, algo in [('LMNN', lmnn), ('NCA', nca)]:

346

# Transform data to learned metric space

347

X_transformed = algo.transform(X)

348

349

# Get the learned Mahalanobis matrix

350

M = algo.get_mahalanobis_matrix()

351

352

# Get linear transformation components

353

L = algo.components_

354

355

print(f"{name}: transformed shape={X_transformed.shape}, M shape={M.shape}")

356

print(f" Verification: M = L.T @ L = {np.allclose(M, L.T @ L)}")

357

```

358

359

### Custom Metric Learning Algorithm

360

361

Understanding the base classes enables implementing custom algorithms:

362

363

```python

364

from metric_learn.base_metric import MahalanobisMixin

365

from sklearn.base import TransformerMixin

366

import numpy as np

367

368

class CustomMetricLearner(MahalanobisMixin, TransformerMixin):

369

"""Example custom metric learning algorithm."""

370

371

def __init__(self, alpha=1.0, preprocessor=None):

372

super().__init__(preprocessor=preprocessor)

373

self.alpha = alpha

374

375

def fit(self, X, y):

376

"""Implement your metric learning algorithm here."""

377

# Example: simple covariance-based metric with regularization

378

X = self._prepare_inputs(X, y, type_of_inputs='classic')[0]

379

380

# Compute regularized covariance

381

cov = np.cov(X.T) + self.alpha * np.eye(X.shape[1])

382

383

# Use matrix decomposition for components_

384

eigenvals, eigenvecs = np.linalg.eigh(cov)

385

self.components_ = eigenvecs @ np.diag(np.sqrt(np.maximum(eigenvals, 1e-8)))

386

387

return self

388

389

# Usage

390

custom_learner = CustomMetricLearner(alpha=0.1)

391

custom_learner.fit(X, y)

392

X_transformed = custom_learner.transform(X)

393

print("Custom algorithm trained successfully!")

394

```