
# Risk Control


Advanced risk control methods for multi-label classification scenarios, enabling control of precision and recall metrics with finite-sample guarantees. MAPIE implements conformal risk control procedures for complex prediction tasks where traditional conformal prediction may not suffice.

## Capabilities

### Precision-Recall Controller

Controls prediction risks in multi-label classification by providing finite-sample guarantees on precision or recall metrics. Implements conformal risk control (CRC), risk-controlling prediction sets (RCPS), and learn-then-test (LTT) methods. CRC and RCPS are available when controlling recall; LTT is used when controlling precision.

```python { .api }
class PrecisionRecallController:
    """
    Risk controller for multi-label classification with precision/recall guarantees.

    Parameters:
    - estimator: Optional[ClassifierMixin], base multi-label classifier, expected to be already fitted (default: None)
    - metric_control: str, metric to control ("recall", "precision") (default: "recall")
    - method: Optional[str], risk control method ("crc" or "rcps" for recall, "ltt" for precision)
    - n_jobs: Optional[int], number of parallel jobs
    - random_state: Optional[int], random seed
    - verbose: int, verbosity level (default: 0)
    """
    def __init__(self, estimator=None, metric_control='recall', method=None, n_jobs=None, random_state=None, verbose=0): ...

    def fit(self, X, y, conformalize_size=0.3):
        """
        Fit the risk controller. With a fitted estimator, X and y serve as
        conformalization data; with estimator=None, they are split into
        training and conformalization parts.

        Parameters:
        - X: ArrayLike, input features
        - y: ArrayLike, multi-label targets (shape: n_samples x n_labels)
        - conformalize_size: float, fraction of data held out for conformalization when estimator is None (default: 0.3)

        Returns:
        Self
        """

    def partial_fit(self, X, y, _refit=False):
        """
        Incrementally fit the risk controller.

        Parameters:
        - X: ArrayLike, input features
        - y: ArrayLike, multi-label targets
        - _refit: bool, whether to refit the base estimator (default: False)

        Returns:
        Self
        """

    def predict(self, X, alpha=None, delta=None, bound=None):
        """
        Predict with risk control guarantees.

        Parameters:
        - X: ArrayLike, test features
        - alpha: Optional[float], risk level (between 0 and 1)
        - delta: Optional[float], confidence level for the guarantee (between 0 and 1); required for RCPS and LTT, not used by CRC
        - bound: Optional[str], method for the upper confidence bound on the risk with RCPS ("hoeffding", "bernstein", "wsr")

        Returns:
        Union[NDArray, Tuple[NDArray, NDArray]]: predictions if alpha is None,
        otherwise (predictions, prediction sets of shape n_samples x n_labels x n_alpha)
        """

    # Key attributes after fitting
    valid_methods: List[str]  # Available risk control methods
    lambdas: NDArray  # Lambda threshold values
    risks: ArrayLike  # Risk values for each observation and threshold
```

## Usage Examples

### Basic Recall Control

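```python
from mapie.risk_control import PrecisionRecallController
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
import numpy as np

# Multi-label classification data (load_multilabel_data is a placeholder for your own loader)
# y shape: (n_samples, n_labels) - binary matrix
X, y = load_multilabel_data()

# Split into training, conformalization, and test sets
X_train, X_rest, y_train, y_rest = train_test_split(X, y, test_size=0.5, random_state=42)
X_conf, X_test, y_conf, y_test = train_test_split(X_rest, y_rest, test_size=0.4, random_state=42)

# Create risk controller for recall around a fitted base classifier
risk_controller = PrecisionRecallController(
    estimator=RandomForestClassifier(n_estimators=100, random_state=42).fit(X_train, y_train),
    metric_control='recall',
    method='crc',  # Conformal Risk Control
    random_state=42
)

# Fit on the held-out conformalization set
risk_controller.fit(X_conf, y_conf)

# Predict with recall guarantee
# CRC controls the risk in expectation: E[recall] >= 1 - alpha = 0.8
# (delta is only used by RCPS and LTT)
# With alpha set, predict returns (point predictions, risk-controlled label sets)
y_pred, y_pred_sets = risk_controller.predict(
    X_test,
    alpha=0.2  # Risk level: 1 - 0.8 = 0.2
)
```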

### Recall Control with RCPS Method

```python
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier

# Risk-Controlling Prediction Sets method (available when controlling recall)
# X_train/X_conf/X_test: training, conformalization, and test splits of X, y
risk_controller = PrecisionRecallController(
    estimator=MultiOutputClassifier(LogisticRegression()).fit(X_train, y_train),
    metric_control='recall',
    method='rcps',
    random_state=42
)

# Fit the controller on the conformalization set
risk_controller.fit(X_conf, y_conf)

# Predict with recall control: P(recall >= 0.9) >= 0.95
y_pred, y_pred_sets = risk_controller.predict(
    X_test,
    alpha=0.1,  # Allow 10% recall risk
    delta=0.05,  # 95% confidence
    bound='hoeffding'  # Upper confidence bound used by RCPS
)
```

### Learn-Then-Test (LTT) Method

```python
from sklearn.multioutput import MultiOutputClassifier
from sklearn.svm import SVC

# LTT method for adaptive thresholding (used when controlling precision)
# X_train/X_conf/X_test: training, conformalization, and test splits of X, y
risk_controller = PrecisionRecallController(
    estimator=MultiOutputClassifier(SVC(probability=True)).fit(X_train, y_train),
    metric_control='precision',
    method='ltt',
    random_state=42
)

# Fit on the conformalization set (LTT benefits from a larger one)
risk_controller.fit(X_conf, y_conf)

# Adaptive prediction with learned thresholds
y_pred, y_pred_sets = risk_controller.predict(X_test, alpha=0.1, delta=0.1)
```

## Risk Control Methods

### Conformal Risk Control (CRC)

Uses the conformal prediction framework to provide distribution-free guarantees on risk metrics. The guarantee bounds the expected risk, so no confidence parameter `delta` is needed. Suitable for general multi-label scenarios when controlling recall.

```python
method="crc"
```

**Advantages:**
- Distribution-free guarantees
- Works with any base classifier
- Finite-sample guarantee on the expected risk

**Use cases:**
- General multi-label classification
- When distributional assumptions cannot be made
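
The CRC selection rule can be sketched in a few lines of NumPy. This is a minimal illustration on synthetic scores, not MAPIE's internal implementation: the helper names `recall_losses` and `crc_threshold` are hypothetical, and the rule assumes a per-sample loss (here 1 - recall) that is bounded by 1 and never decreases as the threshold grows.

```python
import numpy as np

def recall_losses(scores, y_true, lambdas):
    """Per-sample loss 1 - recall for every threshold; shape (n_samples, n_lambdas)."""
    # Predicted label sets per threshold: shape (n_samples, n_labels, n_lambdas)
    pred_sets = scores[:, :, None] > lambdas[None, None, :]
    true_pos = (pred_sets & (y_true[:, :, None] == 1)).sum(axis=1)
    n_pos = y_true.sum(axis=1, keepdims=True)
    # Convention: recall is 1 for samples without any positive label
    recall = np.divide(true_pos, n_pos, out=np.ones(true_pos.shape), where=n_pos > 0)
    return 1.0 - recall

def crc_threshold(losses, lambdas, alpha, loss_max=1.0):
    """Largest threshold whose CRC-adjusted empirical risk is at most alpha."""
    n = losses.shape[0]
    adjusted = (n / (n + 1)) * losses.mean(axis=0) + loss_max / (n + 1)
    valid = np.flatnonzero(adjusted <= alpha)
    if valid.size == 0:
        raise ValueError("No threshold satisfies the requested risk level")
    return lambdas[valid[-1]]

# Synthetic conformalization data: positives score higher than negatives
rng = np.random.default_rng(0)
y_conf = rng.integers(0, 2, size=(500, 4))
scores = np.clip(0.6 * y_conf + rng.uniform(0, 0.5, size=(500, 4)), 0, 1)
lambdas = np.linspace(0.0, 0.99, 100)

losses = recall_losses(scores, y_conf, lambdas)
lam_hat = crc_threshold(losses, lambdas, alpha=0.2)
print(f"Selected threshold: {lam_hat:.2f}")
```

Because the per-sample loss only grows with the threshold, the valid thresholds form a prefix of the grid, so taking the last valid index is enough.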

### Risk-Controlling Prediction Sets (RCPS)

Provides prediction sets whose risk (the expected value of a loss function) stays below α with probability at least 1 - δ over the conformalization data. Requires `delta` at prediction time and is available when controlling recall.

```python
method="rcps"
```

**Advantages:**
- Controls expected risk rather than worst-case, with a high-probability guarantee
- Can handle complex loss functions
- Adaptive set sizes

**Use cases:**
- When average risk control is sufficient
- Complex multi-label scenarios
- Large-scale applications
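
The idea behind RCPS can be sketched as follows: compute an upper confidence bound on the risk for each threshold and keep the largest threshold whose bound, together with the bounds of all smaller thresholds, stays below α. The sketch below uses the simple Hoeffding bound and synthetic losses; the helper names `hoeffding_ucb` and `rcps_threshold` are hypothetical, and the code is an illustration under these assumptions rather than MAPIE's implementation.

```python
import numpy as np

def hoeffding_ucb(mean_risk, n, delta):
    """Upper confidence bound on a risk in [0, 1], valid with probability >= 1 - delta."""
    return mean_risk + np.sqrt(np.log(1.0 / delta) / (2.0 * n))

def rcps_threshold(losses, lambdas, alpha, delta):
    """Largest threshold such that the bound is <= alpha for it and all smaller ones."""
    n = losses.shape[0]
    ucb = hoeffding_ucb(losses.mean(axis=0), n, delta)
    ok = ucb <= alpha
    if not ok[0]:
        raise ValueError("Even the smallest threshold is not certified")
    # Index of the first violation (or the grid size if there is none)
    first_bad = int(np.argmin(ok)) if not ok.all() else ok.size
    return lambdas[first_bad - 1]

# Synthetic 0/1 losses that switch on once the threshold passes a random cut point
rng = np.random.default_rng(0)
n = 2000
lambdas = np.linspace(0.0, 0.99, 100)
cut = rng.uniform(0.3, 0.9, size=n)
losses = (lambdas[None, :] >= cut[:, None]).astype(float)

lam_hat = rcps_threshold(losses, lambdas, alpha=0.2, delta=0.1)
print(f"Selected threshold: {lam_hat:.2f}")
```

The Hoeffding margin shrinks as the conformalization set grows, so larger sets allow thresholds closer to the one that exactly meets the risk level.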

### Learn-Then-Test (LTT)

Two-stage approach: a model is learned first, then each candidate threshold is tested statistically, and only thresholds whose risk is certified to be at most α (with confidence 1 - δ) are kept. Used when controlling precision, where the risk is not monotone in the threshold.

```python
method="ltt"
```

**Advantages:**
- Adaptive to data characteristics
- Good empirical performance
- Flexible threshold learning

**Use cases:**
- When threshold adaptation is important
- Sufficient conformalization data available
- Performance-critical applications
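
The testing stage can be sketched as a multiple-testing problem: for each candidate threshold, compute a p-value for the null hypothesis "risk > α" and keep the thresholds that are rejected after a multiplicity correction. The sketch below uses plain Hoeffding p-values with a Bonferroni correction on synthetic losses; the helper names `hoeffding_p_values` and `ltt_valid_thresholds` are hypothetical, and the p-value used here is a simpler choice than the Hoeffding-Bentkus variant referenced under Utility Functions.

```python
import numpy as np

def hoeffding_p_values(mean_risk, n, alpha):
    """p-values for H0: risk > alpha, from Hoeffding's inequality (losses in [0, 1])."""
    gap = np.clip(alpha - mean_risk, 0.0, None)
    return np.exp(-2.0 * n * gap ** 2)

def ltt_valid_thresholds(losses, lambdas, alpha, delta):
    """Thresholds certified at level delta after a Bonferroni correction."""
    n, n_lambdas = losses.shape
    p_values = hoeffding_p_values(losses.mean(axis=0), n, alpha)
    return lambdas[p_values <= delta / n_lambdas]

# Synthetic 0/1 losses: the true risk falls from 0.40 to about 0.06 as the threshold grows
rng = np.random.default_rng(0)
n = 2000
lambdas = np.linspace(0.0, 0.98, 50)
true_risk = 0.4 - 0.35 * lambdas
losses = (rng.random((n, lambdas.size)) < true_risk[None, :]).astype(float)

valid = ltt_valid_thresholds(losses, lambdas, alpha=0.2, delta=0.1)
print(f"{valid.size} certified thresholds, smallest: {valid.min():.2f}")
```

Unlike the CRC and RCPS rules, nothing here relies on the risk being monotone in the threshold, which is why this style of procedure suits precision control.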

## Advanced Usage

### Custom Risk Functions

```python
import numpy as np
from sklearn.metrics import f1_score

# Define custom risk function
def custom_risk_fn(y_true, y_pred):
    """
    Custom risk function for multi-label prediction.

    Parameters:
    - y_true: NDArray, true multi-label targets
    - y_pred: NDArray, predicted multi-label outputs

    Returns:
    float: risk value
    """
    # Example: weighted F1-score risk
    f1_scores = f1_score(y_true, y_pred, average=None)
    weights = np.array([0.3, 0.5, 0.2])  # Label weights (assumes exactly 3 labels)
    return 1 - np.average(f1_scores, weights=weights)

# PrecisionRecallController only controls recall or precision, so using this
# function requires a custom risk-control implementation
```

### Analyzing Risk Control Performance

```python
import matplotlib.pyplot as plt
import numpy as np

# Analyze lambda thresholds and risks
print(f"Available methods: {risk_controller.valid_methods}")
print(f"Lambda thresholds: {risk_controller.lambdas[:5]}")  # First 5
print(f"Risk shape: {risk_controller.risks.shape}")

# Plot risk vs threshold
plt.figure(figsize=(10, 6))
plt.plot(risk_controller.lambdas, np.mean(risk_controller.risks, axis=0))
plt.xlabel('Lambda Threshold')
plt.ylabel('Average Risk')
plt.title('Risk vs Threshold')
plt.show()
```

### Multi-Label Evaluation

```python
from sklearn.metrics import multilabel_confusion_matrix, classification_report

# Evaluate risk-controlled predictions
# With alpha set, predict returns (point predictions, label sets of shape n_samples x n_labels x n_alpha)
_, y_pred_sets = risk_controller.predict(X_test, alpha=0.1, delta=0.1)
y_pred = y_pred_sets[:, :, 0].astype(int)  # Label sets for the single alpha

# Multi-label metrics
mcm = multilabel_confusion_matrix(y_test, y_pred)
print("Multi-label Confusion Matrices:")
for i, cm in enumerate(mcm):
    print(f"Label {i}:")
    print(cm)

# Classification report (label_names: list of label names, one per column of y)
report = classification_report(y_test, y_pred, target_names=label_names)
print("Classification Report:")
print(report)
```
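
To check a guarantee empirically, it helps to compute the controlled metric directly on the predicted label sets. The following is a minimal sketch with a hypothetical helper, `set_recall_precision`, and a small hand-made example; it assumes the label sets are given as a binary matrix of shape (n_samples, n_labels), such as one alpha slice of the sets returned by `predict`.

```python
import numpy as np

def set_recall_precision(y_true, y_sets):
    """Sample-averaged recall and precision of binary label sets."""
    y_true = np.asarray(y_true, dtype=bool)
    y_sets = np.asarray(y_sets, dtype=bool)
    tp = (y_true & y_sets).sum(axis=1)
    n_true = y_true.sum(axis=1)
    n_pred = y_sets.sum(axis=1)
    # Convention: a metric with an empty denominator counts as 1
    recall = np.divide(tp, n_true, out=np.ones(len(tp)), where=n_true > 0)
    precision = np.divide(tp, n_pred, out=np.ones(len(tp)), where=n_pred > 0)
    return recall.mean(), precision.mean()

y_true = np.array([[1, 0, 1], [0, 1, 0], [1, 1, 0], [0, 0, 0]])
y_sets = np.array([[1, 0, 0], [0, 1, 1], [1, 1, 0], [0, 0, 0]])

recall, precision = set_recall_precision(y_true, y_sets)
print(f"recall={recall:.3f}, precision={precision:.3f}")
```

Comparing the averaged recall (or precision) on a test set with 1 - α gives a quick sanity check of the chosen risk level.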

### Handling Class Imbalance

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.multioutput import MultiOutputClassifier
from sklearn.utils.class_weight import compute_class_weight

# Compute class weights for imbalanced multi-label data
# (for inspection; class_weight='balanced' below applies the same weighting per label)
# X_train, y_train: training split of X, y
class_weights = []
for i in range(y_train.shape[1]):
    classes = np.unique(y_train[:, i])
    weights = compute_class_weight('balanced', classes=classes, y=y_train[:, i])
    class_weights.append(dict(zip(classes, weights)))

# Use with base estimator, fitted on the training split
base_estimator = MultiOutputClassifier(
    RandomForestClassifier(class_weight='balanced')
).fit(X_train, y_train)

risk_controller = PrecisionRecallController(
    estimator=base_estimator,
    metric_control='recall',
    method='crc'
)
```

## Utility Functions

Additional utility functions for implementing custom risk control procedures:

```python { .api }
# Risk computation functions
from mapie.control_risk.risks import compute_risk_recall, compute_risk_precision

def compute_risk_recall(y_true, y_pred, lambda_threshold): ...
def compute_risk_precision(y_true, y_pred, lambda_threshold): ...

# Learn-Then-Test procedures
from mapie.control_risk.ltt import ltt_procedure, find_lambda_control_star

def ltt_procedure(y_scores, y_true, lambda_values, alpha, delta): ...
def find_lambda_control_star(risk_values, lambda_values, alpha, delta): ...

# CRC/RCPS procedures
from mapie.control_risk.crc_rcps import get_r_hat_plus, find_lambda_star

def get_r_hat_plus(y_scores, y_true, lambda_values, alpha): ...
def find_lambda_star(risk_estimates, alpha, delta): ...

# Statistical tests
from mapie.control_risk.p_values import compute_hoeffdding_bentkus_p_value

def compute_hoeffdding_bentkus_p_value(observed_risk, bound, n_samples): ...
```

## Theoretical Guarantees

The risk control methods provide finite-sample guarantees:

- **CRC**: E[Risk] ≤ α (the guarantee holds in expectation; δ is not used)
- **RCPS**: P(Risk ≤ α) ≥ 1 - δ, obtained from an upper confidence bound whose margin shrinks as O(√(log(1/δ)/n))
- **LTT**: P(Risk ≤ α) ≥ 1 - δ, simultaneously for every threshold the procedure retains

Where:
- Risk: 1 - recall or 1 - precision, depending on `metric_control`
- α: desired risk level
- δ: confidence parameter
- n: conformalization set size

These guarantees hold for any base classifier and require no assumption on the form of the data distribution, provided the conformalization and test data are drawn i.i.d. from the same distribution (exchangeability suffices for CRC).