A scikit-learn-compatible module for estimating prediction intervals using conformal prediction methods.
Advanced risk control methods for multi-label classification scenarios, enabling control of precision and recall metrics with finite-sample guarantees. MAPIE implements conformal risk control procedures for complex prediction tasks where traditional conformal prediction may not suffice.
Controls prediction risks in multi-label classification by providing finite-sample guarantees on precision or recall metrics. Implements conformal risk control (CRC), risk-controlling prediction sets (RCPS), and learn-then-test (LTT) methods.
class PrecisionRecallController:
    """Risk controller for multi-label classification with precision/recall guarantees.

    Calibrates a decision threshold for a base multi-label classifier so that
    the controlled metric (recall or precision) satisfies a distribution-free,
    finite-sample guarantee, via conformal risk control ("crc"),
    risk-controlling prediction sets ("rcps"), or learn-then-test ("ltt").

    Parameters
    ----------
    estimator : ClassifierMixin, optional
        Base multi-label classifier.
    metric_control : str, default="recall"
        Metric to control: "recall" or "precision".
    method : str, optional
        Risk control method: "crc", "rcps", or "ltt".
    n_jobs : int, optional
        Number of parallel jobs.
    random_state : int, optional
        Random seed.
    verbose : int, default=0
        Verbosity level.

    Attributes
    ----------
    valid_methods : List[str]
        Available risk control methods.
    lambdas : NDArray
        Candidate lambda threshold values.
    risks : ArrayLike
        Risk values for each observation and threshold.
    """

    def __init__(self, estimator=None, metric_control='recall', method=None,
                 n_jobs=None, random_state=None, verbose=0): ...

    def fit(self, X, y, conformalize_size=0.3):
        """Fit the risk controller with training and conformalization data.

        Parameters
        ----------
        X : ArrayLike
            Input features.
        y : ArrayLike
            Multi-label targets, shape (n_samples, n_labels).
        conformalize_size : float, default=0.3
            Fraction of the data held out for conformalization.

        Returns
        -------
        Self
        """

    def partial_fit(self, X, y, _refit=False):
        """Incrementally fit the risk controller.

        Parameters
        ----------
        X : ArrayLike
            Input features.
        y : ArrayLike
            Multi-label targets.
        _refit : bool, default=False
            Whether to refit the base estimator.

        Returns
        -------
        Self
        """

    def predict(self, X, alpha=None, delta=None, bound=None):
        """Predict with risk control guarantees.

        Parameters
        ----------
        X : ArrayLike
            Test features.
        alpha : float, optional
            Risk level, between 0 and 1.
        delta : float, optional
            Confidence level for the guarantee, between 0 and 1.
        bound : float, optional
            Bound on the controlled metric.

        Returns
        -------
        Union[NDArray, Tuple[NDArray, NDArray]]
            Predictions, or (predictions, bounds).
        """

    # Key attributes after fitting (the `risks` array of per-observation,
    # per-threshold risk values is also set by fitting).
    valid_methods: List[str]  # Available risk control methods
    lambdas: NDArray  # Lambda threshold values
risks: ArrayLike  # Risk values for each observation and threshold

from mapie.risk_control import PrecisionRecallController
from sklearn.ensemble import RandomForestClassifier
import numpy as np

# Multi-label classification data.
# `load_multilabel_data` is a placeholder helper — assumed defined elsewhere.
# y shape: (n_samples, n_labels) - binary matrix
X, y = load_multilabel_data()

# Create risk controller for recall
risk_controller = PrecisionRecallController(
    estimator=RandomForestClassifier(n_estimators=100),
    metric_control='recall',
    method='crc',  # Conformal Risk Control
    random_state=42
)

# Fit with automatic train/conformalization split
risk_controller.fit(X, y, conformalize_size=0.3)

# Predict with recall guarantee
# Guarantee: P(recall >= 0.8) >= 0.9
y_pred = risk_controller.predict(
    X_test,  # test features — assumed defined elsewhere
    alpha=0.2,  # Risk level: 1-0.8 = 0.2
    delta=0.1  # Confidence: 1-0.9 = 0.1
)
# Risk-Controlling Prediction Sets method
# RCPS example.
# NOTE(review): LogisticRegression is not imported in this snippet — it comes
# from sklearn.linear_model; confirm against the complete example.
risk_controller = PrecisionRecallController(
    estimator=LogisticRegression(),
    metric_control='precision',
    method='rcps',
    random_state=42
)

# Fit the controller (default conformalize_size)
risk_controller.fit(X, y)

# Predict with precision control; returns (predictions, bounds)
y_pred, bounds = risk_controller.predict(
    X_test,
    alpha=0.1,   # Allow 10% precision risk
    delta=0.05,  # 95% confidence
    bound=0.85   # Target precision >= 85%
)

from sklearn.multioutput import MultiOutputClassifier
from sklearn.svm import SVC

# LTT (Learn-Then-Test) method for adaptive thresholding.
risk_controller = PrecisionRecallController(
    # probability=True so the SVC exposes predict_proba — presumably required
    # for per-label scores; confirm against the mapie docs.
    estimator=MultiOutputClassifier(SVC(probability=True)),
    metric_control='recall',
    method='ltt',
    random_state=42
)

# Fit with larger conformalization set for LTT
risk_controller.fit(X, y, conformalize_size=0.5)
# Adaptive prediction with learned thresholds
y_pred = risk_controller.predict(X_test, alpha=0.1, delta=0.1)

Uses conformal prediction framework to provide distribution-free guarantees on risk metrics. Suitable for general multi-label scenarios.
method="crc"

Advantages:
Use cases:
Provides prediction sets that control the expected value of a risk function. More flexible than traditional conformal methods.
method="rcps"

Advantages:
Use cases:
Two-stage approach that first learns optimal thresholds, then applies statistical testing for guarantees.
method="ltt"

Advantages:
Use cases:
# Define custom risk function
def custom_risk_fn(y_true, y_pred):
    """Custom risk function for multi-label prediction.

    Computes a weighted-F1 risk: 1 minus the label-weighted average of the
    per-label F1 scores.

    Parameters
    ----------
    y_true : NDArray
        True multi-label targets.
    y_pred : NDArray
        Predicted multi-label outputs.

    Returns
    -------
    float
        Risk value (0 means perfect weighted F1).

    Notes
    -----
    Requires ``sklearn.metrics.f1_score`` to be imported by the caller.
    The hard-coded weights assume exactly 3 labels — adjust for your data.
    """
    # Example: weighted F1-score risk.
    # average=None returns one F1 score per label.
    f1_scores = f1_score(y_true, y_pred, average=None)
    weights = np.array([0.3, 0.5, 0.2])  # Label weights (length must match n_labels)
    return 1 - np.average(f1_scores, weights=weights)
# Use with controller (requires custom implementation)

# Analyze lambda thresholds and risks
# Inspect the fitted controller: candidate thresholds and estimated risks.
print(f"Available methods: {risk_controller.valid_methods}")
print(f"Lambda thresholds: {risk_controller.lambdas[:5]}")  # First 5
print(f"Risk shape: {risk_controller.risks.shape}")

# Plot risk vs threshold: average risk over observations (axis=0) at each lambda.
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 6))
plt.plot(risk_controller.lambdas, np.mean(risk_controller.risks, axis=0))
plt.xlabel('Lambda Threshold')
plt.ylabel('Average Risk')
plt.title('Risk vs Threshold')
plt.show()

from sklearn.metrics import multilabel_confusion_matrix, classification_report
# Evaluate risk-controlled predictions
y_pred = risk_controller.predict(X_test, alpha=0.1, delta=0.1)

# Multi-label metrics: one confusion matrix per label
mcm = multilabel_confusion_matrix(y_test, y_pred)
print("Multi-label Confusion Matrices:")
for i, cm in enumerate(mcm):
    print(f"Label {i}:")
    print(cm)

# Classification report (per-label precision/recall/F1;
# `label_names` is assumed defined elsewhere)
report = classification_report(y_test, y_pred, target_names=label_names)
print("Classification Report:")
print(report)
print(report)

from sklearn.utils.class_weight import compute_class_weight
# Compute per-label class weights for imbalanced multi-label data
class_weights = []
for i in range(y.shape[1]):
    weights = compute_class_weight('balanced',
                                   classes=np.unique(y[:, i]),
                                   y=y[:, i])
    # Assumes each label column is binary {0, 1}, so `weights` has two entries.
    class_weights.append({0: weights[0], 1: weights[1]})

# Use with base estimator (class_weight='balanced' applies the weighting internally)
base_estimator = MultiOutputClassifier(
    RandomForestClassifier(class_weight='balanced')
)
risk_controller = PrecisionRecallController(
    estimator=base_estimator,
    metric_control='recall',
    method='crc'
)

Additional utility functions for implementing custom risk control procedures:
# Risk computation functions (signature stubs for reference):
# empirical recall/precision risk at a given lambda threshold.
from mapie.control_risk.risks import compute_risk_recall, compute_risk_precision
def compute_risk_recall(y_true, y_pred, lambda_threshold): ...
def compute_risk_precision(y_true, y_pred, lambda_threshold): ...

# Learn-Then-Test procedures: testing over candidate lambda values.
from mapie.control_risk.ltt import ltt_procedure, find_lambda_control_star
def ltt_procedure(y_scores, y_true, lambda_values, alpha, delta): ...
def find_lambda_control_star(risk_values, lambda_values, alpha, delta): ...

# CRC/RCPS procedures: upper risk estimate and threshold selection.
from mapie.control_risk.crc_rcps import get_r_hat_plus, find_lambda_star
def get_r_hat_plus(y_scores, y_true, lambda_values, alpha): ...
def find_lambda_star(risk_estimates, alpha, delta): ...

# Statistical tests.
# NOTE: the "hoeffdding" spelling below matches the actual mapie API name.
from mapie.control_risk.p_values import compute_hoeffdding_bentkus_p_value
def compute_hoeffdding_bentkus_p_value(observed_risk, bound, n_samples): ...

The risk control methods provide finite-sample guarantees:
P(R(lambda_hat) <= alpha) >= 1 - delta

Where:
- R(lambda_hat) is the controlled risk (e.g. 1 - recall or 1 - precision) at the selected threshold
- alpha is the target risk level (between 0 and 1)
- delta is the confidence level of the guarantee (between 0 and 1)
These guarantees hold for any data distribution and any base classifier, making the methods truly distribution-free.
Install with Tessl CLI
npx tessl i tessl/pypi-mapie