# Metric-Learn

Python implementations of metric learning algorithms that are fully compatible with scikit-learn's API. Metric-learn provides efficient implementations of several popular supervised and weakly-supervised metric learning algorithms as part of the scikit-learn-contrib ecosystem.

## Package Information

- **Package Name**: metric-learn
- **Language**: Python
- **Installation**: `pip install metric-learn`
- **Dependencies**: Python 3.6+, numpy>=1.11.0, scipy>=0.17.0, scikit-learn>=0.21.3

## Core Imports

```python
import metric_learn
```

Common import pattern for specific algorithms:

```python
from metric_learn import LMNN, NCA, ITML, LSML
```

Import utility classes:

```python
from metric_learn import Constraints
```

## Basic Usage

```python
import numpy as np
from metric_learn import LMNN
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Load sample data
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Learn a metric with LMNN
lmnn = LMNN(n_neighbors=3, learn_rate=1e-6)
lmnn.fit(X_train, y_train)

# Transform data to the learned metric space
X_train_transformed = lmnn.transform(X_train)
X_test_transformed = lmnn.transform(X_test)

# Get the learned Mahalanobis matrix
mahalanobis_matrix = lmnn.get_mahalanobis_matrix()

# Compute distances between pairs of test points
pair_indices = np.array([[0, 1], [2, 3]])  # each row indexes one pair in X_test
distances = lmnn.pair_distance(X_test[pair_indices])  # input shape: (n_pairs, 2, n_features)
```

## Architecture

Metric-learn follows a hierarchical class structure based on scikit-learn patterns:

- **BaseMetricLearner**: Abstract base class defining the core metric learning interface
- **MahalanobisMixin**: Mixin for algorithms that learn Mahalanobis distance metrics
- **Classifier Mixins**: Support for pair/triplet/quadruplet classification tasks
- **Algorithm Classes**: Concrete implementations of specific metric learning algorithms

All algorithms implement the scikit-learn API with `fit()`, `transform()`, and specialized methods for computing distances and similarities between data points.
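
To make the link between `transform()` and the distance methods concrete, here is a minimal NumPy sketch. It assumes a learned linear map `L` (a random stand-in here, not the output of any metric-learn estimator) and checks that the Mahalanobis distance under `M = L^T L` equals the Euclidean distance between the mapped points.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical learned linear map L: 3 input features -> 2 output dimensions.
L = rng.normal(size=(2, 3))

# The matching Mahalanobis matrix is M = L^T L (positive semi-definite).
M = L.T @ L

x, y = rng.normal(size=3), rng.normal(size=3)

# Distance computed directly from M: sqrt((x - y)^T M (x - y)).
diff = x - y
d_mahalanobis = np.sqrt(diff @ M @ diff)

# The same distance, computed as a Euclidean distance after mapping through L.
d_euclidean = np.linalg.norm(L @ x - L @ y)
```

This equivalence is why transformed data can be passed to any Euclidean-based estimator, such as a nearest-neighbors classifier.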

## Capabilities

### Supervised Metric Learning Algorithms

Core supervised algorithms that learn from labeled training data to optimize distance metrics for classification tasks.

```python { .api }
class LMNN(MahalanobisMixin, TransformerMixin):
    def __init__(self, init='auto', n_neighbors=3, min_iter=50, max_iter=1000, learn_rate=1e-7, regularization=0.5, convergence_tol=0.001, verbose=False, preprocessor=None, n_components=None, random_state=None): ...
    def fit(self, X, y): ...

class NCA(MahalanobisMixin, TransformerMixin):
    def __init__(self, init='auto', n_components=None, max_iter=100, tol=None, verbose=False, preprocessor=None, random_state=None): ...
    def fit(self, X, y): ...

class LFDA(MahalanobisMixin, TransformerMixin):
    def __init__(self, n_components=None, k=None, embedding_type='weighted', preprocessor=None): ...
    def fit(self, X, y): ...
```

[Supervised Algorithms](./supervised-algorithms.md)

### Weakly-Supervised Learning Algorithms

Algorithms that learn from constraints (pairs, triplets, quadruplets) rather than explicit class labels.

```python { .api }
class ITML(MahalanobisMixin, TransformerMixin):
    def __init__(self, gamma=1.0, max_iter=1000, tol=1e-3, prior='identity', verbose=False, preprocessor=None, random_state=None): ...
    def fit(self, pairs, y): ...

class LSML(MahalanobisMixin, TransformerMixin):
    def __init__(self, tol=1e-3, max_iter=1000, prior='identity', verbose=False, preprocessor=None, random_state=None): ...
    def fit(self, quadruplets, weights=None): ...

class SDML(MahalanobisMixin, TransformerMixin):
    def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True, preprocessor=None, random_state=None): ...
    def fit(self, pairs, y): ...

class RCA(MahalanobisMixin, TransformerMixin):
    def __init__(self, n_components=None, preprocessor=None): ...
    def fit(self, X, chunks): ...

class SCML(MahalanobisMixin, TransformerMixin):
    def __init__(self, beta=1e-5, basis='triplet_diffs', n_basis=None, gamma=5e-3, max_iter=10000, output_iter=500, batch_size=10, verbose=False, preprocessor=None, random_state=None): ...
    def fit(self, triplets): ...
```

[Weakly-Supervised Algorithms](./weakly-supervised-algorithms.md)
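
As a rough illustration of what pair, triplet, and quadruplet constraints look like as arrays, the sketch below builds each kind from toy data with NumPy indexing. The data and index choices are made up for illustration; the ordering conventions in the comments are the usual ones for these constraint types and should be checked against the detailed algorithm docs.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(6, 4))  # toy data: 6 points, 4 features

# Pairs: shape (n_pairs, 2, n_features), plus one label per pair
# (+1 = similar, -1 = dissimilar).
pairs = X[np.array([[0, 1], [2, 5]])]
pair_labels = np.array([1, -1])

# Triplets: shape (n_triplets, 3, n_features), ordered so that the first
# point should end up closer to the second than to the third.
triplets = X[np.array([[0, 1, 5], [2, 3, 4]])]

# Quadruplets: shape (n_quadruplets, 4, n_features), ordered so that the
# first two points should be closer to each other than the last two.
quadruplets = X[np.array([[0, 1, 2, 5]])]
```

Triplets and quadruplets carry their supervision in the ordering of points, which is why `SCML.fit` takes no separate label argument.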

### Supervised Variants of Weakly-Supervised Algorithms

Supervised versions that automatically generate constraints from class labels, combining the convenience of supervised learning with constraint-based optimization.

```python { .api }
class ITML_Supervised(MahalanobisMixin, TransformerMixin):
    def __init__(self, gamma=1.0, max_iter=1000, tol=1e-3, n_constraints=None, prior='identity', verbose=False, preprocessor=None, random_state=None): ...
    def fit(self, X, y): ...

class LSML_Supervised(MahalanobisMixin, TransformerMixin):
    def __init__(self, tol=1e-3, max_iter=1000, prior='identity', n_constraints=None, verbose=False, preprocessor=None, random_state=None): ...
    def fit(self, X, y): ...

class SDML_Supervised(MahalanobisMixin, TransformerMixin):
    def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True, n_constraints=None, preprocessor=None, random_state=None): ...
    def fit(self, X, y): ...

class RCA_Supervised(MahalanobisMixin, TransformerMixin):
    def __init__(self, n_components=None, n_chunks=100, chunk_size=2, preprocessor=None, random_state=None): ...
    def fit(self, X, y): ...

class MMC_Supervised(MahalanobisMixin, TransformerMixin):
    def __init__(self, init='identity', max_iter=100, max_proj=10000, convergence_threshold=1e-3, n_constraints=None, diagonal=False, diagonal_c=1.0, verbose=False, preprocessor=None, random_state=None): ...
    def fit(self, X, y): ...

class SCML_Supervised(MahalanobisMixin, TransformerMixin):
    def __init__(self, beta=1e-5, basis='lda', n_basis=None, gamma=5e-3, max_iter=10000, output_iter=500, batch_size=10, verbose=False, preprocessor=None, random_state=None): ...
    def fit(self, X, y): ...
```
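
The idea of deriving constraints from class labels can be sketched in a few lines of plain Python. The helper `pairs_from_labels` below is hypothetical and enumerates every pair exhaustively; the supervised variants above instead take an `n_constraints` parameter, so treat this as a simplified picture rather than the library's sampling logic.

```python
from itertools import combinations


def pairs_from_labels(labels):
    """Split all index pairs into same-class and different-class pairs.

    Hypothetical helper for illustration only.
    """
    positive, negative = [], []
    for i, j in combinations(range(len(labels)), 2):
        if labels[i] == labels[j]:
            positive.append((i, j))  # same class -> "similar" constraint
        else:
            negative.append((i, j))  # different class -> "dissimilar" constraint
    return positive, negative


labels = [0, 0, 1, 1, 1]
positive, negative = pairs_from_labels(labels)
```

Exhaustive enumeration grows quadratically with the number of samples, which is one reason a capped number of constraints is useful in practice.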

### Specialized Algorithms

Algorithms designed for specific use cases like clustering and kernel regression.

```python { .api }
class MLKR(MahalanobisMixin, TransformerMixin):
    def __init__(self, init='auto', alpha=0.1, max_iter=1000, preprocessor=None, random_state=None): ...
    def fit(self, X, y): ...

class MMC(MahalanobisMixin, TransformerMixin):
    def __init__(self, init='identity', max_iter=100, max_proj=10000, convergence_threshold=1e-3, diagonal=False, diagonal_c=1.0, verbose=False, preprocessor=None, random_state=None): ...
    def fit(self, pairs, y): ...

class Covariance(MahalanobisMixin, TransformerMixin):
    def __init__(self, preprocessor=None): ...
    def fit(self, X, y=None): ...
```

[Specialized Algorithms](./specialized-algorithms.md)
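
The `Covariance` baseline involves no optimization: the Mahalanobis matrix is derived from the data covariance. The NumPy sketch below shows the underlying idea, assuming an invertible covariance matrix and using a Cholesky factor as one possible choice of linear map (the library's own factorization may differ). After the transform, the features are whitened.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data with correlated features.
A = np.array([[2.0, 0.0], [1.5, 0.5]])
X = rng.normal(size=(500, 2)) @ A.T

# Mahalanobis matrix taken as the inverse of the sample covariance.
cov = np.cov(X, rowvar=False)
M = np.linalg.inv(cov)

# Factor M = L^T L. np.linalg.cholesky returns a lower-triangular C with
# M = C C^T, so L = C^T satisfies L^T L = M.
L = np.linalg.cholesky(M).T

# Mapping the data through L removes the correlations.
X_transformed = X @ L.T
cov_transformed = np.cov(X_transformed, rowvar=False)
```

Here `cov_transformed` is the identity matrix up to floating-point error, which makes this baseline a useful point of comparison for learned metrics.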

### Utility Classes and Constraints

Helper classes for generating constraints and working with metric learning data.

```python { .api }
class Constraints:
    def __init__(self, partial_labels): ...
    def positive_negative_pairs(self, n_constraints, same_length=False, random_state=None): ...
    def chunks(self, n_chunks=100, chunk_size=2, random_state=None): ...
    def generate_knntriplets(self, X, k_genuine, k_impostor): ...
```

[Utilities](./utilities.md)
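
To give a feel for what k-nearest-neighbor triplet generation involves, here is a standalone NumPy sketch. The function `knn_triplets` is a hypothetical re-implementation of the general idea (nearest same-class "genuine" neighbors combined with nearest other-class "impostors"); it is not the library's code, and details such as tie handling and output ordering may differ from `Constraints.generate_knntriplets`.

```python
import numpy as np


def knn_triplets(X, y, k_genuine, k_impostor):
    """Build (anchor, genuine, impostor) index triplets from labels.

    Hypothetical helper: for each point, combine its k_genuine nearest
    same-class neighbors with its k_impostor nearest other-class neighbors.
    """
    X, y = np.asarray(X, dtype=float), np.asarray(y)
    indices = np.arange(len(X))
    # Pairwise squared Euclidean distances, shape (n_samples, n_samples).
    dists = ((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1)
    triplets = []
    for i in indices:
        same = np.flatnonzero((y == y[i]) & (indices != i))
        other = np.flatnonzero(y != y[i])
        genuine = same[np.argsort(dists[i, same])][:k_genuine]
        impostors = other[np.argsort(dists[i, other])][:k_impostor]
        triplets.extend((i, g, m) for g in genuine for m in impostors)
    return np.array(triplets)


X = np.array([[0.0], [0.1], [0.3], [1.0], [1.2], [1.5]])
y = np.array([0, 0, 0, 1, 1, 1])
triplets = knn_triplets(X, y, k_genuine=1, k_impostor=1)
```

Each row of `triplets` holds indices into `X`; indexing with `X[triplets]` yields an array of shape `(n_triplets, 3, n_features)`.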

### Base Classes and Mixins

Core abstract classes and mixins that define the metric learning API.

```python { .api }
class BaseMetricLearner(BaseEstimator):
    def __init__(self, preprocessor=None): ...
    def pair_score(self, pairs): ...
    def pair_distance(self, pairs): ...
    def get_metric(self): ...

class MahalanobisMixin(BaseMetricLearner):
    def transform(self, X): ...
    def get_mahalanobis_matrix(self): ...
```

[Base Classes](./base-classes.md)
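
The following toy class mimics the method names listed above to show how they might relate to one another. `ToyMahalanobisLearner` is hypothetical and takes a fixed linear map instead of learning one; in particular, defining the score as the negated distance is an assumption about how a distance-based learner would rank pairs, so consult the base class documentation for the exact contract.

```python
import numpy as np


class ToyMahalanobisLearner:
    """Hypothetical stand-in with a fixed (not learned) linear map."""

    def __init__(self, components):
        # components plays the role of the learned linear map L.
        self.components_ = np.asarray(components, dtype=float)

    def transform(self, X):
        # Map points into the metric space: x -> L x.
        return np.asarray(X, dtype=float) @ self.components_.T

    def pair_distance(self, pairs):
        # pairs has shape (n_pairs, 2, n_features).
        pairs = np.asarray(pairs, dtype=float)
        diffs = self.transform(pairs[:, 0]) - self.transform(pairs[:, 1])
        return np.linalg.norm(diffs, axis=1)

    def pair_score(self, pairs):
        # Assumed convention: higher score means more similar.
        return -self.pair_distance(pairs)

    def get_metric(self):
        # Return a plain two-argument function usable as a custom metric.
        def metric(u, v):
            return float(self.pair_distance([[u, v]])[0])
        return metric


learner = ToyMahalanobisLearner(components=[[2.0, 0.0], [0.0, 1.0]])
pairs = [[[0.0, 0.0], [1.0, 0.0]], [[0.0, 0.0], [0.0, 3.0]]]
distances = learner.pair_distance(pairs)
scores = learner.pair_score(pairs)
metric = learner.get_metric()
```

A callable like `metric` is the kind of object that can be handed to estimators accepting a user-defined distance function.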