tessl/pypi-mapie

A scikit-learn-compatible module for estimating prediction intervals using conformal prediction methods.

Risk Control

Advanced risk control methods for multi-label classification scenarios, enabling control of precision and recall metrics with finite-sample guarantees. MAPIE implements conformal risk control procedures for complex prediction tasks where traditional conformal prediction may not suffice.

Capabilities

Precision-Recall Controller

Controls prediction risks in multi-label classification by providing finite-sample guarantees on precision or recall metrics. Implements conformal risk control (CRC), risk-controlling prediction sets (RCPS), and learn-then-test (LTT) methods.

class PrecisionRecallController:
    """
    Risk controller for multi-label classification with precision/recall guarantees.

    Parameters:
    - estimator: ClassifierMixin, base multi-label classifier
    - metric_control: str, metric to control ("recall", "precision") (default: "recall")
    - method: Optional[str], risk control method ("crc", "rcps", "ltt")
    - n_jobs: Optional[int], number of parallel jobs
    - random_state: Optional[int], random seed
    - verbose: int, verbosity level (default: 0)
    """
    def __init__(self, estimator=None, metric_control='recall', method=None, n_jobs=None, random_state=None, verbose=0): ...

    def fit(self, X, y, conformalize_size=0.3):
        """
        Fit the risk controller with training and conformalization data.

        Parameters:
        - X: ArrayLike, input features
        - y: ArrayLike, multi-label targets (shape: n_samples x n_labels)
        - conformalize_size: float, fraction of data for conformalization (default: 0.3)

        Returns:
        Self
        """

    def partial_fit(self, X, y, _refit=False):
        """
        Incrementally fit the risk controller.

        Parameters:
        - X: ArrayLike, input features
        - y: ArrayLike, multi-label targets
        - _refit: bool, whether to refit the base estimator (default: False)

        Returns:
        Self
        """

    def predict(self, X, alpha=None, delta=None, bound=None):
        """
        Predict with risk control guarantees.

        Parameters:
        - X: ArrayLike, test features
        - alpha: Optional[float], risk level (between 0 and 1)
        - delta: Optional[float], confidence level for the guarantee (between 0 and 1)
    - bound: Optional[str], concentration bound used by RCPS ("hoeffding", "bernstein", "wsr")

        Returns:
        Union[NDArray, Tuple[NDArray, NDArray]]: predictions or (predictions, bounds)
        """

    # Key attributes after fitting
    valid_methods: List[str]  # Available risk control methods
    lambdas: NDArray  # Lambda threshold values
    risks: ArrayLike  # Risk values for each observation and threshold

Usage Examples

Basic Recall Control

from mapie.risk_control import PrecisionRecallController
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_multilabel_classification
from sklearn.model_selection import train_test_split

# Multi-label classification data
# y shape: (n_samples, n_labels) - binary indicator matrix
X, y = make_multilabel_classification(n_samples=1000, n_classes=5, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Create risk controller for recall
risk_controller = PrecisionRecallController(
    estimator=RandomForestClassifier(n_estimators=100),
    metric_control='recall',
    method='crc',  # Conformal Risk Control
    random_state=42
)

# Fit with automatic train/conformalization split
risk_controller.fit(X_train, y_train, conformalize_size=0.3)

# Predict with recall controlled in expectation: E[recall] >= 0.8
y_pred = risk_controller.predict(
    X_test,
    alpha=0.2,   # Risk level: 1 - 0.8 = 0.2
    delta=0.1    # Confidence level (used by rcps/ltt)
)

Recall Control with RCPS Method

# Risk-Controlling Prediction Sets method
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier

risk_controller = PrecisionRecallController(
    estimator=MultiOutputClassifier(LogisticRegression()),
    metric_control='recall',
    method='rcps',
    random_state=42
)

# Fit the controller
risk_controller.fit(X, y)

# Predict with a high-probability recall guarantee
# Guarantee: P(recall >= 0.9) >= 0.95
y_pred, bounds = risk_controller.predict(
    X_test,
    alpha=0.1,         # Allow 10% recall risk
    delta=0.05,        # 95% confidence
    bound='hoeffding'  # Concentration bound used by RCPS
)

Learn-Then-Test (LTT) Method

from sklearn.multioutput import MultiOutputClassifier
from sklearn.svm import SVC

# LTT: the method used for precision control
risk_controller = PrecisionRecallController(
    estimator=MultiOutputClassifier(SVC(probability=True)),
    metric_control='precision',
    method='ltt',
    random_state=42
)

# Fit with larger conformalization set for LTT
risk_controller.fit(X, y, conformalize_size=0.5)

# Adaptive prediction with learned thresholds
y_pred = risk_controller.predict(X_test, alpha=0.1, delta=0.1)

Risk Control Methods

Conformal Risk Control (CRC)

Extends the conformal prediction framework to provide distribution-free, finite-sample control of the expected risk. Suitable for general multi-label scenarios.

method="crc"

Advantages:

  • Distribution-free guarantees
  • Works with any base classifier
  • Theoretical finite-sample coverage

Use cases:

  • General multi-label classification
  • When distributional assumptions cannot be made
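
The CRC selection rule can be sketched in a few lines: on the conformalization set, inflate the empirical risk by the worst-case contribution of one extra sample, then keep the least conservative threshold that still clears α. This is an illustrative sketch of the idea, not MAPIE's internal implementation; `crc_lambda_hat` and its inputs are hypothetical names.

```python
import numpy as np

def crc_lambda_hat(risks, lambdas, alpha, max_loss=1.0):
    """Smallest lambda whose inflated empirical risk is <= alpha.

    risks: (n_samples, n_lambdas) per-sample losses in [0, max_loss],
    assumed non-increasing in lambda (prediction sets grow with lambda).
    """
    n = risks.shape[0]
    r_bar = risks.mean(axis=0)
    # CRC inflation: require (n * R_hat(lambda) + B) / (n + 1) <= alpha
    adjusted = (n * r_bar + max_loss) / (n + 1)
    ok = np.where(adjusted <= alpha)[0]
    return lambdas[ok[0]] if ok.size else None

# Toy conformalization losses with mean risk 0.5, 0.2, 0.0 per threshold
risks = np.zeros((10, 3))
risks[:5, 0] = 1.0
risks[:2, 1] = 1.0
lambdas = np.array([0.1, 0.2, 0.3])
print(crc_lambda_hat(risks, lambdas, alpha=0.3))  # prints 0.2
```

Note how the `+ max_loss` term makes the rule slightly conservative: with only 10 conformalization points, even a zero empirical risk is inflated to 1/11, so very small α values are unreachable until more data is collected.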

Risk-Controlling Prediction Sets (RCPS)

Provides prediction sets whose risk stays below the target level with high probability (1 − δ), by comparing an upper confidence bound on the empirical risk against the target.

method="rcps"

Advantages:

  • High-probability (1 − δ) guarantee, not just control in expectation
  • Can handle complex loss functions
  • Adaptive set sizes

Use cases:

  • When average risk control is sufficient
  • Complex multi-label scenarios
  • Large-scale applications
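
A minimal sketch of the RCPS selection step, assuming a risk that is monotone in the threshold: compute a one-sided upper confidence bound on the risk at each candidate λ and keep the first one that clears α. The Hoeffding bound is used here for simplicity, and the function names are hypothetical; the library also supports tighter bounds.

```python
import numpy as np

def hoeffding_ucb(losses, delta):
    # One-sided Hoeffding upper confidence bound on the mean risk,
    # valid for losses bounded in [0, 1].
    n = len(losses)
    return losses.mean() + np.sqrt(np.log(1.0 / delta) / (2 * n))

def rcps_lambda_hat(risks, lambdas, alpha, delta):
    # Scan from least to most conservative threshold; keep the first
    # lambda whose upper bound clears alpha (monotone risk assumed).
    for j, lam in enumerate(lambdas):
        if hoeffding_ucb(risks[:, j], delta) <= alpha:
            return lam
    return None

# Toy losses with mean risks 0.30, 0.10, 0.02 over n=200 samples
risks = np.zeros((200, 3))
risks[:60, 0] = 1.0
risks[:20, 1] = 1.0
risks[:4, 2] = 1.0
print(rcps_lambda_hat(risks, [0.1, 0.2, 0.3], alpha=0.2, delta=0.05))  # prints 0.2
```

The √(log(1/δ)/2n) slack is what separates RCPS from a naive empirical-risk rule: shrinking δ or the conformalization size n pushes the bound up and forces a more conservative threshold.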

Learn-Then-Test (LTT)

Two-stage approach that treats each candidate threshold as a statistical hypothesis: it first computes p-values for "the risk at this threshold exceeds α", then keeps the thresholds that survive a multiple-testing correction.

method="ltt"

Advantages:

  • Adaptive to data characteristics
  • Good empirical performance
  • Flexible threshold learning

Use cases:

  • Precision control, where the risk is not monotone in the threshold
  • When threshold adaptation is important
  • Sufficient conformalization data available
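
The LTT idea can be sketched as multiple hypothesis testing: each candidate threshold gets a p-value for the null "risk exceeds α", and a Bonferroni correction controls the family-wise error over the grid. A simple Hoeffding p-value is used below purely for illustration (MAPIE's utilities implement a tighter Hoeffding-Bentkus p-value); the function names are hypothetical.

```python
import numpy as np

def hoeffding_p_value(r_hat, alpha, n):
    # P-value for H0: risk(lambda) > alpha, from Hoeffding's inequality.
    if r_hat >= alpha:
        return 1.0
    return float(np.exp(-2 * n * (alpha - r_hat) ** 2))

def ltt_valid_lambdas(risks, lambdas, alpha, delta):
    # Bonferroni over the candidate grid: keep every lambda whose
    # p-value falls below delta / len(lambdas).
    n, n_lambdas = risks.shape
    p_values = [hoeffding_p_value(risks[:, j].mean(), alpha, n)
                for j in range(n_lambdas)]
    return [lam for lam, p in zip(lambdas, p_values)
            if p <= delta / n_lambdas]

# Toy losses with mean risks 0.25, 0.10, 0.05 over n=500 samples
risks = np.zeros((500, 3))
risks[:125, 0] = 1.0
risks[:50, 1] = 1.0
risks[:25, 2] = 1.0
print(ltt_valid_lambdas(risks, [0.1, 0.2, 0.3], alpha=0.2, delta=0.1))  # prints [0.2, 0.3]
```

Because every surviving threshold is individually certified, LTT returns a set of valid λ values rather than a single one, and needs no monotonicity assumption on the risk.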

Advanced Usage

Custom Risk Functions

# Define custom risk function
import numpy as np
from sklearn.metrics import f1_score

def custom_risk_fn(y_true, y_pred):
    """
    Custom risk function for multi-label prediction.

    Parameters:
    - y_true: NDArray, true multi-label targets
    - y_pred: NDArray, predicted multi-label outputs

    Returns:
    float: risk value
    """
    # Example: weighted F1-score risk
    f1_scores = f1_score(y_true, y_pred, average=None)
    weights = np.array([0.3, 0.5, 0.2])  # One weight per label
    return 1 - np.average(f1_scores, weights=weights)

# Use with controller (requires custom implementation)

Analyzing Risk Control Performance

# Analyze lambda thresholds and risks
print(f"Available methods: {risk_controller.valid_methods}")
print(f"Lambda thresholds: {risk_controller.lambdas[:5]}")  # First 5
print(f"Risk shape: {risk_controller.risks.shape}")

# Plot risk vs threshold
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 6))
plt.plot(risk_controller.lambdas, np.mean(risk_controller.risks, axis=0))
plt.xlabel('Lambda Threshold')
plt.ylabel('Average Risk')
plt.title('Risk vs Threshold')
plt.show()

Multi-Label Evaluation

from sklearn.metrics import multilabel_confusion_matrix, classification_report

# Evaluate risk-controlled predictions
y_pred = risk_controller.predict(X_test, alpha=0.1, delta=0.1)

# Multi-label metrics
mcm = multilabel_confusion_matrix(y_test, y_pred)
print("Multi-label Confusion Matrices:")
for i, cm in enumerate(mcm):
    print(f"Label {i}:")
    print(cm)

# Classification report
label_names = [f"label_{i}" for i in range(y_test.shape[1])]
report = classification_report(y_test, y_pred, target_names=label_names)
print("Classification Report:")
print(report)

Handling Class Imbalance

from sklearn.utils.class_weight import compute_class_weight

# Compute per-label class weights for imbalanced multi-label data
# (shown for illustration; pass the dicts to estimators that accept them)
class_weights = []
for i in range(y.shape[1]):
    weights = compute_class_weight('balanced',
                                   classes=np.unique(y[:, i]),
                                   y=y[:, i])
    class_weights.append({0: weights[0], 1: weights[1]})

# Or simply use the built-in 'balanced' mode with the base estimator
base_estimator = MultiOutputClassifier(
    RandomForestClassifier(class_weight='balanced')
)

risk_controller = PrecisionRecallController(
    estimator=base_estimator,
    metric_control='recall',
    method='crc'
)

Utility Functions

Additional utility functions for implementing custom risk control procedures:

# Risk computation functions
from mapie.control_risk.risks import compute_risk_recall, compute_risk_precision

def compute_risk_recall(y_true, y_pred, lambda_threshold): ...
def compute_risk_precision(y_true, y_pred, lambda_threshold): ...

# Learn-Then-Test procedures
from mapie.control_risk.ltt import ltt_procedure, find_lambda_control_star

def ltt_procedure(y_scores, y_true, lambda_values, alpha, delta): ...
def find_lambda_control_star(risk_values, lambda_values, alpha, delta): ...

# CRC/RCPS procedures
from mapie.control_risk.crc_rcps import get_r_hat_plus, find_lambda_star

def get_r_hat_plus(y_scores, y_true, lambda_values, alpha): ...
def find_lambda_star(risk_estimates, alpha, delta): ...

# Statistical tests
from mapie.control_risk.p_values import compute_hoeffdding_bentkus_p_value

def compute_hoeffdding_bentkus_p_value(observed_risk, bound, n_samples): ...

Theoretical Guarantees

The risk control methods provide finite-sample guarantees:

  • CRC: E[Risk] ≤ α (control in expectation; no confidence parameter needed)
  • RCPS: P(Risk ≤ α) ≥ 1 − δ, via upper confidence bounds on the empirical risk
  • LTT: P(Risk ≤ α) ≥ 1 − δ, via p-values with family-wise error control over candidate thresholds

Where:

  • α: desired risk level
  • δ: confidence parameter
  • n: conformalization set size

These guarantees hold for any data distribution and any base classifier, making the methods truly distribution-free.
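
The in-expectation CRC guarantee can be checked with a small Monte Carlo simulation. This is a sketch under simplified assumptions: Bernoulli losses coupled so they are monotone in the threshold, a known true risk curve, and the inflated-risk selection rule from the CRC literature.

```python
import numpy as np

rng = np.random.default_rng(0)
alpha, n_cal, n_trials = 0.2, 100, 200
lambdas = np.linspace(0.0, 1.0, 21)
# True risk curve, non-increasing in lambda (sets grow with lambda)
true_risk = 0.5 - 0.5 * lambdas

test_risks = []
for _ in range(n_trials):
    # Coupled Bernoulli losses: non-increasing in lambda per sample
    u = rng.random(n_cal)
    cal_losses = (u[:, None] < true_risk).astype(float)
    # CRC selection: smallest lambda with inflated empirical risk <= alpha
    adjusted = (n_cal * cal_losses.mean(axis=0) + 1.0) / (n_cal + 1)
    j = int(np.argmax(adjusted <= alpha))
    test_risks.append(true_risk[j])

print(np.mean(test_risks) <= alpha)  # prints True: average risk is controlled
```

Averaged over repeated conformalization draws, the risk at the selected threshold stays below α, matching the in-expectation guarantee; individual draws can still land slightly above it.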

Install with Tessl CLI

npx tessl i tessl/pypi-mapie
