# Risk Control

Advanced risk control methods for multi-label classification scenarios, enabling control of precision and recall metrics with finite-sample guarantees. MAPIE implements conformal risk control procedures for complex prediction tasks where traditional conformal prediction may not suffice.

## Capabilities

### Precision-Recall Controller

Controls prediction risks in multi-label classification by providing finite-sample guarantees on precision or recall metrics. Implements conformal risk control (CRC), risk-controlling prediction sets (RCPS), and learn-then-test (LTT) methods.

```python { .api }
class PrecisionRecallController:
    """
    Risk controller for multi-label classification with precision/recall guarantees.

    Parameters:
    - estimator: ClassifierMixin, base multi-label classifier
    - metric_control: str, metric to control ("recall", "precision") (default: "recall")
    - method: Optional[str], risk control method ("crc" or "rcps" for recall, "ltt" for precision)
    - n_jobs: Optional[int], number of parallel jobs
    - random_state: Optional[int], random seed
    - verbose: int, verbosity level (default: 0)
    """
    def __init__(self, estimator=None, metric_control='recall', method=None, n_jobs=None, random_state=None, verbose=0): ...

    def fit(self, X, y, conformalize_size=0.3):
        """
        Fit the risk controller with training and conformalization data.

        Parameters:
        - X: ArrayLike, input features
        - y: ArrayLike, multi-label targets (shape: n_samples x n_labels)
        - conformalize_size: float, fraction of data for conformalization (default: 0.3)

        Returns:
        Self
        """

    def partial_fit(self, X, y, _refit=False):
        """
        Incrementally fit the risk controller.

        Parameters:
        - X: ArrayLike, input features
        - y: ArrayLike, multi-label targets
        - _refit: bool, whether to refit the base estimator (default: False)

        Returns:
        Self
        """

    def predict(self, X, alpha=None, delta=None, bound=None):
        """
        Predict with risk control guarantees.

        Parameters:
        - X: ArrayLike, test features
        - alpha: Optional[float], target risk level (between 0 and 1), e.g. 0.2 targets recall >= 0.8
        - delta: Optional[float], tolerated probability that the guarantee fails (between 0 and 1); used by "rcps" and "ltt"
        - bound: Optional[str], concentration bound used by "rcps" ("hoeffding", "bernstein" or "wsr")

        Returns:
        Union[NDArray, Tuple[NDArray, NDArray]]: predictions if alpha is None,
        otherwise (predictions, prediction sets of shape n_samples x n_labels x n_alpha)
        """

    # Key attributes after fitting
    valid_methods: List[str]  # Available risk control methods
    lambdas: NDArray  # Lambda threshold values
    risks: ArrayLike  # Risk values for each observation and threshold
```
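To make the `lambdas` and `risks` attributes concrete, the following minimal sketch shows how a threshold λ turns predicted probabilities into a label set, and how the recall risk (1 - recall) of one sample is evaluated on a grid of thresholds. It is plain Python rather than MAPIE's implementation: `predict_set`, `recall_risk`, the strict `>` comparison, and the convention of zero risk for a sample without positive labels are assumptions made for illustration.

```python
from typing import List, Sequence


def predict_set(probas: Sequence[float], lam: float) -> List[int]:
    """Return binary indicators: label k is kept when its score exceeds lam."""
    return [1 if p > lam else 0 for p in probas]


def recall_risk(y_true: Sequence[int], probas: Sequence[float], lam: float) -> float:
    """Recall risk (1 - recall) of a single sample at threshold lam."""
    pred = predict_set(probas, lam)
    n_true = sum(y_true)
    if n_true == 0:
        # Assumption: with no positive label, nothing can be missed.
        return 0.0
    hits = sum(t * p for t, p in zip(y_true, pred))
    return 1.0 - hits / n_true


# Hypothetical sample with three labels, two of which are positive.
lambdas = [i / 10 for i in range(10)]  # 0.0, 0.1, ..., 0.9
y_true = [1, 0, 1]
probas = [0.85, 0.30, 0.45]

# One row of a risk matrix: one risk value per threshold.
risks_row = [recall_risk(y_true, probas, lam) for lam in lambdas]
print(risks_row)
```

Raising the threshold can only remove labels from the predicted set, so the recall risk never decreases along the grid. The CRC and RCPS selection rules rely on that monotonicity.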

## Usage Examples

### Basic Recall Control

```python
from mapie.risk_control import PrecisionRecallController
from sklearn.datasets import make_multilabel_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Multi-label classification data (synthetic, for illustration)
# y shape: (n_samples, n_labels) - binary matrix
X, y = make_multilabel_classification(n_samples=2000, n_classes=5, random_state=42)
X, X_test, y, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Create risk controller for recall
risk_controller = PrecisionRecallController(
    estimator=RandomForestClassifier(n_estimators=100),
    metric_control='recall',
    method='crc',  # Conformal Risk Control
    random_state=42
)

# Fit with automatic train/conformalization split
risk_controller.fit(X, y, conformalize_size=0.3)

# Predict with recall guarantee
# CRC controls recall in expectation: E[recall] >= 1 - alpha = 0.8
# (delta applies to RCPS and LTT only, so it is not passed here)
y_pred, y_pred_sets = risk_controller.predict(
    X_test,
    alpha=0.2  # Risk level: 1-0.8 = 0.2
)
```

### Recall Control with RCPS Method

```python
from mapie.risk_control import PrecisionRecallController
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier

# Risk-Controlling Prediction Sets method (available for recall control)
risk_controller = PrecisionRecallController(
    # LogisticRegression is single-output, so it is wrapped for multi-label targets
    estimator=MultiOutputClassifier(LogisticRegression()),
    metric_control='recall',
    method='rcps',
    random_state=42
)

# Fit the controller
risk_controller.fit(X, y)

# Predict with recall control: P(recall >= 0.9) >= 0.95
y_pred, y_pred_sets = risk_controller.predict(
    X_test,
    alpha=0.1,         # Allow 10% recall risk
    delta=0.05,        # 95% confidence
    bound='hoeffding'  # Concentration bound for the upper confidence bound on the risk
)
```

### Learn-Then-Test (LTT) Method

```python
from mapie.risk_control import PrecisionRecallController
from sklearn.multioutput import MultiOutputClassifier
from sklearn.svm import SVC

# LTT method for precision control
risk_controller = PrecisionRecallController(
    estimator=MultiOutputClassifier(SVC(probability=True)),
    metric_control='precision',
    method='ltt',
    random_state=42
)

# Fit with larger conformalization set for LTT
risk_controller.fit(X, y, conformalize_size=0.5)

# Predict with thresholds validated by LTT: P(precision >= 0.9) >= 0.9
y_pred, y_pred_sets = risk_controller.predict(X_test, alpha=0.1, delta=0.1)
```

## Risk Control Methods

### Conformal Risk Control (CRC)

Extends the conformal prediction framework to bound the expected value of a monotone loss: it selects the threshold whose finite-sample-adjusted empirical risk on the conformalization set stays below α. Available for recall control.

```python
method="crc"
```

**Advantages:**
- Distribution-free guarantees
- Works with any base classifier
- Finite-sample guarantee on the expected risk

**Use cases:**
- General multi-label classification
- When distributional assumptions cannot be made
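The CRC selection rule can be sketched in a few lines. This is an illustrative version under stated assumptions, not MAPIE's code: `crc_threshold` is a hypothetical helper, losses are assumed to lie in [0, 1] and to be non-decreasing in the threshold, and the grid is assumed sorted in increasing order. The adjusted risk `n/(n+1) * R_hat + 1/(n+1)` is the standard CRC finite-sample correction for losses bounded by 1.

```python
from typing import List, Optional, Sequence


def crc_threshold(
    risks: Sequence[Sequence[float]],
    lambdas: Sequence[float],
    alpha: float,
) -> float:
    """Largest threshold whose adjusted mean risk stays at or below alpha.

    risks[i][j] is the loss of conformalization sample i at lambdas[j].
    """
    n = len(risks)
    best: Optional[float] = None
    for j, lam in enumerate(lambdas):
        r_hat = sum(row[j] for row in risks) / n
        # Finite-sample correction: pretend one extra sample has the worst loss (1).
        r_hat_plus = (n / (n + 1)) * r_hat + 1 / (n + 1)
        if r_hat_plus > alpha:
            break  # risk is non-decreasing, so larger thresholds fail as well
        best = lam
    if best is None:
        raise ValueError("no threshold satisfies the target risk level")
    return best


# Hypothetical recall risks for 4 conformalization samples on a 4-point grid.
lambdas: List[float] = [0.0, 0.25, 0.5, 0.75]
risks = [
    [0.0, 0.0, 0.0, 0.5],
    [0.0, 0.0, 0.5, 0.5],
    [0.0, 0.0, 0.0, 1.0],
    [0.0, 0.0, 0.5, 1.0],
]

print(crc_threshold(risks, lambdas, alpha=0.5))
```

With only 4 samples the correction term 1/(n+1) is already 0.2, so no threshold can meet a target below that. A small conformalization set therefore limits how strict α can be.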

### Risk-Controlling Prediction Sets (RCPS)

Builds prediction sets whose risk stays below α with probability at least 1 - δ, by replacing the empirical risk with an upper confidence bound (Hoeffding, Bernstein, or WSR). Available for recall control.

```python
method="rcps"
```

**Advantages:**
- High-probability guarantee at confidence 1 - δ, rather than control in expectation only
- Handles any bounded loss that is monotone in the threshold
- Adaptive set sizes

**Use cases:**
- When a guarantee in expectation is not enough and an explicit confidence level is required
- Complex multi-label scenarios
- Large-scale applications
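As an illustration of the upper-confidence-bound idea, here is a minimal sketch using the Hoeffding bound, the simplest of the usual choices. `hoeffding_ucb` and `rcps_threshold` are hypothetical helpers, not MAPIE functions; the mean risks are assumed to be averages of losses in [0, 1] over `n` i.i.d. conformalization samples, non-decreasing along an increasing grid of thresholds.

```python
import math
from typing import Optional, Sequence


def hoeffding_ucb(r_hat: float, n: int, delta: float) -> float:
    """Upper confidence bound on the true risk, valid with probability 1 - delta."""
    return r_hat + math.sqrt(math.log(1.0 / delta) / (2.0 * n))


def rcps_threshold(
    mean_risks: Sequence[float],
    lambdas: Sequence[float],
    n: int,
    alpha: float,
    delta: float,
) -> Optional[float]:
    """Largest threshold such that the bound stays below alpha for it and all smaller ones."""
    best: Optional[float] = None
    for lam, r_hat in zip(lambdas, mean_risks):
        if hoeffding_ucb(r_hat, n, delta) > alpha:
            break  # stop at the first threshold whose bound exceeds alpha
        best = lam
    return best  # None when even the smallest threshold fails


# Hypothetical empirical recall risks measured on 500 conformalization samples.
lambdas = [0.2, 0.4, 0.6, 0.8]
mean_risks = [0.00, 0.02, 0.06, 0.15]

margin = hoeffding_ucb(0.0, n=500, delta=0.1)
selected = rcps_threshold(mean_risks, lambdas, n=500, alpha=0.1, delta=0.1)
print(f"margin={margin:.4f}, selected threshold={selected}")
```

In this example the threshold 0.6 has an empirical risk of 0.06, below the target of 0.1, yet it is rejected because the confidence margin of roughly 0.048 pushes its bound above α. That gap is the price of the high-probability guarantee, and it shrinks as the conformalization set grows.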
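
### Learn-Then-Test (LTT)

Treats threshold selection as a multiple hypothesis testing problem: for each candidate threshold it tests the hypothesis that the risk exceeds α, and it keeps the thresholds for which this hypothesis is rejected while controlling the family-wise error rate at δ. It does not require the risk to be monotone in the threshold, which is why it is the method used for precision control.

```python
method="ltt"
```

**Advantages:**
- No monotonicity requirement on the risk
- High-probability guarantee at confidence 1 - δ
- Returns a set of valid thresholds to choose from

**Use cases:**
- Precision control, or any risk that is not monotone in the threshold
- Sufficient conformalization data available
- Performance-critical applications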
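The testing step can be sketched as follows. This simplified version swaps in a plain Hoeffding p-value and a Bonferroni correction instead of the Hoeffding-Bentkus p-value referenced under Utility Functions; `hoeffding_p_value` and `ltt_valid_thresholds` are hypothetical names, and the mean risks are assumed to be averages of losses in [0, 1] over `n` i.i.d. conformalization samples.

```python
import math
from typing import List, Sequence


def hoeffding_p_value(r_hat: float, n: int, alpha: float) -> float:
    """p-value for the null hypothesis 'true risk > alpha' given empirical risk r_hat."""
    gap = max(alpha - r_hat, 0.0)
    return math.exp(-2.0 * n * gap * gap)


def ltt_valid_thresholds(
    mean_risks: Sequence[float],
    lambdas: Sequence[float],
    n: int,
    alpha: float,
    delta: float,
) -> List[float]:
    """Thresholds whose null hypothesis is rejected at level delta / len(lambdas)."""
    level = delta / len(lambdas)  # Bonferroni correction over the grid
    return [
        lam
        for lam, r_hat in zip(lambdas, mean_risks)
        if hoeffding_p_value(r_hat, n, alpha) <= level
    ]


# Hypothetical precision risks (1 - precision); note they are not monotone.
lambdas = [0.1, 0.3, 0.5, 0.7, 0.9]
mean_risks = [0.40, 0.22, 0.08, 0.12, 0.05]

valid = ltt_valid_thresholds(mean_risks, lambdas, n=1000, alpha=0.2, delta=0.1)
print(valid)
```

Any threshold in the returned list carries the guarantee. One reasonable choice, assumed here rather than taken from the library, is the smallest valid threshold (`min(valid)`), since it keeps the most labels and so tends to preserve recall.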

## Advanced Usage

### Custom Risk Functions

```python
import numpy as np
from sklearn.metrics import f1_score

# Define custom risk function
def custom_risk_fn(y_true, y_pred):
    """
    Custom risk function for multi-label prediction.

    Parameters:
    - y_true: NDArray, true multi-label targets
    - y_pred: NDArray, predicted multi-label outputs

    Returns:
    float: risk value
    """
    # Example: weighted F1-score risk
    f1_scores = f1_score(y_true, y_pred, average=None)
    weights = np.array([0.3, 0.5, 0.2])  # Label weights (one per label, so this assumes 3 labels)
    return 1 - np.average(f1_scores, weights=weights)

# PrecisionRecallController only controls precision or recall,
# so using this function requires a custom risk control implementation
```

### Analyzing Risk Control Performance

```python
import matplotlib.pyplot as plt
import numpy as np

# Analyze lambda thresholds and risks (risk_controller must already be fitted)
print(f"Available methods: {risk_controller.valid_methods}")
print(f"Lambda thresholds: {risk_controller.lambdas[:5]}")  # First 5
print(f"Risk shape: {risk_controller.risks.shape}")

# Plot risk vs threshold, averaging over conformalization samples
plt.figure(figsize=(10, 6))
plt.plot(risk_controller.lambdas, np.mean(risk_controller.risks, axis=0))
plt.xlabel('Lambda Threshold')
plt.ylabel('Average Risk')
plt.title('Risk vs Threshold')
plt.show()
```

### Multi-Label Evaluation

```python
from sklearn.metrics import multilabel_confusion_matrix, classification_report

# Evaluate risk-controlled predictions
# With alpha set, predict returns (predictions, prediction sets);
# the sets have shape (n_samples, n_labels, n_alpha)
_, y_pred_sets = risk_controller.predict(X_test, alpha=0.1, delta=0.1)
y_pred = y_pred_sets[:, :, 0].astype(int)  # Sets for the single alpha value

# Multi-label metrics
mcm = multilabel_confusion_matrix(y_test, y_pred)
print("Multi-label Confusion Matrices:")
for i, cm in enumerate(mcm):
    print(f"Label {i}:")
    print(cm)

# Classification report (placeholder label names)
label_names = [f"label_{i}" for i in range(y_test.shape[1])]
report = classification_report(y_test, y_pred, target_names=label_names)
print("Classification Report:")
print(report)
```

### Handling Class Imbalance

```python
import numpy as np
from mapie.risk_control import PrecisionRecallController
from sklearn.ensemble import RandomForestClassifier
from sklearn.multioutput import MultiOutputClassifier
from sklearn.utils.class_weight import compute_class_weight

# Compute class weights for imbalanced multi-label data, to inspect the imbalance
# (assumes every label column contains both 0 and 1)
class_weights = []
for i in range(y.shape[1]):
    weights = compute_class_weight('balanced',
                                   classes=np.unique(y[:, i]),
                                   y=y[:, i])
    class_weights.append({0: weights[0], 1: weights[1]})

# Use with base estimator: class_weight='balanced' applies the same
# balanced formula per label, computed on the data each forest is fitted on
base_estimator = MultiOutputClassifier(
    RandomForestClassifier(class_weight='balanced')
)

risk_controller = PrecisionRecallController(
    estimator=base_estimator,
    metric_control='recall',
    method='crc'
)
```

## Utility Functions

Additional utility functions for implementing custom risk control procedures:

```python { .api }
# Risk computation functions
from mapie.control_risk.risks import compute_risk_recall, compute_risk_precision

def compute_risk_recall(y_true, y_pred, lambda_threshold): ...
def compute_risk_precision(y_true, y_pred, lambda_threshold): ...

# Learn-Then-Test procedures
from mapie.control_risk.ltt import ltt_procedure, find_lambda_control_star

def ltt_procedure(y_scores, y_true, lambda_values, alpha, delta): ...
def find_lambda_control_star(risk_values, lambda_values, alpha, delta): ...

# CRC/RCPS procedures
from mapie.control_risk.crc_rcps import get_r_hat_plus, find_lambda_star

def get_r_hat_plus(y_scores, y_true, lambda_values, alpha): ...
def find_lambda_star(risk_estimates, alpha, delta): ...

# Statistical tests
from mapie.control_risk.p_values import compute_hoeffdding_bentkus_p_value

def compute_hoeffdding_bentkus_p_value(observed_risk, bound, n_samples): ...
```

## Theoretical Guarantees

The risk control methods provide finite-sample guarantees:

- **CRC**: E[Risk] ≤ α (control in expectation; δ is not used)
- **RCPS**: P(Risk ≤ α) ≥ 1 - δ
- **LTT**: P(Risk(λ) ≤ α for every selected threshold λ) ≥ 1 - δ

Where:
- α: desired risk level
- δ: confidence parameter (tolerated probability that the guarantee fails)
- n: conformalization set size, over which the probabilities and expectations above are taken

These guarantees hold for any base classifier and any data distribution, provided the conformalization and test samples are i.i.d. (exchangeability suffices for CRC) and the conformalization data was not used to train the base classifier. In that sense the methods are distribution-free.