Python implementations of metric learning algorithms
npx @tessl/cli install tessl/pypi-metric-learn@0.7.0

Python implementations of metric learning algorithms that are fully compatible with scikit-learn's API. Metric-learn provides efficient implementations of several popular supervised and weakly-supervised metric learning algorithms as part of the scikit-learn-contrib ecosystem.
pip install metric-learn

import metric_learn

Common import pattern for specific algorithms:
from metric_learn import LMNN, NCA, ITML, LSML

Import utility classes:
from metric_learn import Constraints

from metric_learn import LMNN
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
# Load sample data
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# Learn a metric with LMNN
lmnn = LMNN(n_neighbors=3, learn_rate=1e-6)
lmnn.fit(X_train, y_train)
# Transform data to the learned metric space
X_train_transformed = lmnn.transform(X_train)
X_test_transformed = lmnn.transform(X_test)
# Get the learned Mahalanobis matrix
mahalanobis_matrix = lmnn.get_mahalanobis_matrix()
# Compute distances between pairs
pairs = [[0, 1], [2, 3]] # indices of pairs
distances = lmnn.pair_distance(X_test[pairs])

Metric-learn follows a hierarchical class structure based on scikit-learn patterns:
All algorithms implement the scikit-learn API with fit(), transform(), and specialized methods for computing distances and similarities between data points.
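Because every learner follows scikit-learn conventions, a learned metric composes directly with standard tooling such as Pipeline and cross_val_score. A minimal sketch (the choice of NCA and the hyperparameter values are illustrative, not prescribed by the library):

from metric_learn import NCA
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline

X, y = load_iris(return_X_y=True)

# Chain a metric learner with a classifier; the learner acts as a transformer.
knn_pipeline = Pipeline([
    ('nca', NCA(max_iter=100, random_state=42)),
    ('knn', KNeighborsClassifier(n_neighbors=3)),
])

# Evaluate the whole pipeline with ordinary scikit-learn cross-validation.
scores = cross_val_score(knn_pipeline, X, y, cv=5)
print(scores.mean())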
Core supervised algorithms that learn from labeled training data to optimize distance metrics for classification tasks.
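For example (a sketch using NCA; the parameter values are arbitrary), a supervised learner is fit on plain (X, y) data and then used to project points into the learned metric space:

from metric_learn import NCA
from sklearn.datasets import load_iris

X, y = load_iris(return_X_y=True)

# Fit on labeled data, exactly like a scikit-learn estimator.
nca = NCA(n_components=2, max_iter=100, random_state=42)
nca.fit(X, y)

# Embed the data into the learned 2-dimensional metric space.
X_embedded = nca.transform(X)
print(X_embedded.shape)  # (150, 2)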
class LMNN(MahalanobisMixin, TransformerMixin):
def __init__(self, init='auto', n_neighbors=3, min_iter=50, max_iter=1000, learn_rate=1e-7, regularization=0.5, convergence_tol=0.001, verbose=False, preprocessor=None, n_components=None, random_state=None): ...
def fit(self, X, y): ...
class NCA(MahalanobisMixin, TransformerMixin):
def __init__(self, init='auto', n_components=None, max_iter=100, tol=None, verbose=False, preprocessor=None, random_state=None): ...
def fit(self, X, y): ...
class LFDA(MahalanobisMixin, TransformerMixin):
def __init__(self, n_components=None, k=None, embedding_type='weighted', preprocessor=None): ...
def fit(self, X, y): ...

Algorithms that learn from constraints (pairs, triplets, quadruplets) rather than explicit class labels.
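For instance, pair learners such as ITML take an array of point pairs together with +1/-1 labels marking similar and dissimilar pairs. A minimal sketch (the toy coordinates below are only illustrative):

import numpy as np
from metric_learn import ITML

# Each row holds two points; y_pairs[i] = +1 marks a similar pair, -1 a dissimilar one.
pairs = np.array([
    [[1.2, 7.5], [1.3, 1.5]],
    [[6.4, 2.6], [6.2, 9.7]],
    [[1.3, 4.5], [3.2, 4.6]],
    [[6.2, 5.5], [5.4, 5.4]],
])
y_pairs = np.array([-1, -1, 1, 1])

itml = ITML(random_state=42)
itml.fit(pairs, y_pairs)

# Distances between the same pairs under the learned metric.
print(itml.pair_distance(pairs))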
class ITML(MahalanobisMixin, TransformerMixin):
def __init__(self, gamma=1.0, max_iter=1000, tol=1e-3, prior='identity', verbose=False, preprocessor=None, random_state=None): ...
def fit(self, pairs, y): ...
class LSML(MahalanobisMixin, TransformerMixin):
def __init__(self, tol=1e-3, max_iter=1000, prior='identity', verbose=False, preprocessor=None, random_state=None): ...
def fit(self, quadruplets, weights=None): ...
class SDML(MahalanobisMixin, TransformerMixin):
def __init__(self, balance_param=0.5, sparsity_param=0.01, prior='identity', verbose=False, preprocessor=None, random_state=None): ...
def fit(self, pairs, y): ...
class RCA(MahalanobisMixin, TransformerMixin):
def __init__(self, n_components=None, preprocessor=None): ...
def fit(self, X, chunks): ...
class SCML(MahalanobisMixin, TransformerMixin):
def __init__(self, beta=1e-5, basis='triplet_diffs', n_basis=None, gamma=5e-3, max_iter=10000, output_iter=500, batch_size=10, verbose=False, preprocessor=None, random_state=None): ...
def fit(self, triplets): ...

Supervised versions that automatically generate constraints from class labels, combining the convenience of supervised learning with constraint-based optimization.
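For example (a sketch with ITML_Supervised; the n_constraints and random_state values are arbitrary), the supervised wrapper samples the pairs internally so the call reduces to fit(X, y):

from metric_learn import ITML_Supervised
from sklearn.datasets import load_iris

X, y = load_iris(return_X_y=True)

# Pairs of similar/dissimilar points are drawn automatically from the labels.
itml = ITML_Supervised(n_constraints=200, random_state=42)
itml.fit(X, y)

X_transformed = itml.transform(X)
print(X_transformed.shape)  # (150, 4)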
class ITML_Supervised(MahalanobisMixin, TransformerMixin):
def __init__(self, gamma=1.0, max_iter=1000, tol=1e-3, n_constraints=None, prior='identity', verbose=False, preprocessor=None, random_state=None): ...
def fit(self, X, y): ...
class LSML_Supervised(MahalanobisMixin, TransformerMixin):
def __init__(self, tol=1e-3, max_iter=1000, prior='identity', n_constraints=None, verbose=False, preprocessor=None, random_state=None): ...
def fit(self, X, y): ...
class SDML_Supervised(MahalanobisMixin, TransformerMixin):
def __init__(self, balance_param=0.5, sparsity_param=0.01, prior='identity', n_constraints=None, verbose=False, preprocessor=None, random_state=None): ...
def fit(self, X, y): ...
class RCA_Supervised(MahalanobisMixin, TransformerMixin):
def __init__(self, n_components=None, n_chunks=100, chunk_size=2, preprocessor=None, random_state=None): ...
def fit(self, X, y): ...
class MMC_Supervised(MahalanobisMixin, TransformerMixin):
def __init__(self, init='identity', max_iter=100, max_proj=10000, convergence_threshold=1e-3, n_constraints=None, diagonal=False, diagonal_c=1.0, verbose=False, preprocessor=None, random_state=None): ...
def fit(self, X, y): ...
class SCML_Supervised(MahalanobisMixin, TransformerMixin):
def __init__(self, beta=1e-5, basis='lda', n_basis=None, gamma=5e-3, max_iter=10000, output_iter=500, batch_size=10, verbose=False, preprocessor=None, random_state=None): ...
def fit(self, X, y): ...

Algorithms designed for specific use cases like clustering and kernel regression.
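As a sketch of the kernel-regression case, MLKR is fit on a continuous target (the synthetic data and parameter choices here are only illustrative):

import numpy as np
from metric_learn import MLKR

rng = np.random.RandomState(42)
X = rng.rand(100, 5)
# Continuous target: MLKR tunes the metric to improve kernel regression of y.
y = 2.0 * X[:, 0] + 0.1 * rng.randn(100)

mlkr = MLKR(n_components=2, random_state=42)
mlkr.fit(X, y)

# Embedding optimized for predicting y from nearby points.
X_embedded = mlkr.transform(X)
print(X_embedded.shape)  # (100, 2)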
class MLKR(MahalanobisMixin, TransformerMixin):
def __init__(self, init='auto', n_components=None, max_iter=1000, tol=None, verbose=False, preprocessor=None, random_state=None): ...
def fit(self, X, y): ...
class MMC(MahalanobisMixin, TransformerMixin):
def __init__(self, init='identity', max_iter=100, max_proj=10000, convergence_threshold=1e-3, diagonal=False, diagonal_c=1.0, verbose=False, preprocessor=None, random_state=None): ...
def fit(self, pairs, y): ...
class Covariance(MahalanobisMixin, TransformerMixin):
def __init__(self, preprocessor=None): ...
def fit(self, X, y=None): ...

Helper classes for generating constraints and working with metric learning data.
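For example (a sketch; the number of constraints is arbitrary), Constraints turns class labels into index arrays of similar and dissimilar pairs that constraint-based learners or custom pipelines can consume:

from metric_learn import Constraints
from sklearn.datasets import load_iris

X, y = load_iris(return_X_y=True)

# Build a constraint generator from labels (use -1 to mark unlabeled points).
constraints = Constraints(y)

# a[i], b[i] index a same-class pair; c[i], d[i] index a different-class pair.
a, b, c, d = constraints.positive_negative_pairs(200, random_state=42)
print(len(a), len(c))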
class Constraints:
def __init__(self, partial_labels): ...
def positive_negative_pairs(self, n_constraints, same_length=False, random_state=None): ...
def chunks(self, n_chunks=100, chunk_size=2, random_state=None): ...
def generate_knntriplets(self, X, k_genuine, k_impostor): ...

Core abstract classes and mixins that define the metric learning API.
class BaseMetricLearner(BaseEstimator):
def __init__(self, preprocessor=None): ...
def pair_score(self, pairs): ...
def pair_distance(self, pairs): ...
def get_metric(self): ...
class MahalanobisMixin(BaseMetricLearner):
def transform(self, X): ...
def get_mahalanobis_matrix(self): ...
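A sketch tying these pieces together (using NCA on iris; the specific estimator and parameters are illustrative): pair_distance works on an array of point pairs, get_metric() yields a plain callable usable by other scikit-learn estimators, and get_mahalanobis_matrix() exposes the matrix M of the learned distance d(u, v) = sqrt((u - v)^T M (u - v)).

import numpy as np
from metric_learn import NCA
from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)

nca = NCA(max_iter=100, random_state=42)
nca.fit(X, y)

# pair_distance expects an array of shape (n_pairs, 2, n_features).
pairs = np.array([[X[0], X[1]], [X[0], X[100]]])
print(nca.pair_distance(pairs))

# get_metric() returns a callable d(u, v) on two 1D vectors, so the learned
# metric can be plugged into estimators that accept a custom metric.
knn = KNeighborsClassifier(n_neighbors=3, metric=nca.get_metric())
knn.fit(X, y)
print(knn.score(X, y))

# The learned Mahalanobis matrix M (shape: n_features x n_features).
print(nca.get_mahalanobis_matrix().shape)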