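# Postprocessing

Postprocessing techniques adjust the outputs of an already trained model so that they satisfy a fairness constraint, without retraining the model. The methods here choose group-specific decision thresholds on the model's scores, so they work with any classifier that exposes `predict_proba`, `decision_function`, or `predict`. Because the thresholds depend on group membership, the sensitive features must also be available at prediction time.
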
## Capabilities
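
### ThresholdOptimizer

`ThresholdOptimizer` chooses a separate decision threshold for each group defined by the sensitive features, so that the resulting predictions satisfy the chosen fairness constraint while maximizing the chosen objective on the data passed to `fit`. To meet the constraint exactly, it may randomize between two thresholds for a group, which makes `predict` stochastic. The underlying model is never retrained.
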
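The core mechanic can be sketched in plain Python. This is a minimal illustration under simple assumptions, not fairlearn's implementation: the helper `apply_group_thresholds`, the scores, and the per-group threshold values are all hypothetical, and each group gets one fixed threshold with no randomization.

```python
def apply_group_thresholds(scores, groups, thresholds):
    """Return 0/1 decisions, comparing each score to its own group's threshold."""
    return [int(score > thresholds[group]) for score, group in zip(scores, groups)]


# Hypothetical model scores for two groups of three samples each
scores = [0.9, 0.4, 0.6, 0.3, 0.55, 0.2]
groups = ["a", "a", "a", "b", "b", "b"]

# Hypothetical thresholds: group "b" gets a lower cut-off than group "a"
thresholds = {"a": 0.5, "b": 0.25}

decisions = apply_group_thresholds(scores, groups, thresholds)
print(decisions)  # [1, 0, 1, 1, 1, 0]
```

With these illustrative thresholds each group has two of its three members selected, even though the two groups' score distributions differ.
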
```python { .api }
class ThresholdOptimizer:
    def __init__(self, *, estimator=None, constraints="demographic_parity",
                 objective="accuracy_score", grid_size=1000,
                 flip=False, prefit=False, predict_method="auto"):
        """
        Optimize group-specific decision thresholds to satisfy a fairness constraint.

        Parameters:
        - estimator: scikit-learn compatible classifier whose output is postprocessed.
          Required; it is cloned and fitted in fit() unless prefit=True.
        - constraints: str, fairness constraint to satisfy. Options:
          "demographic_parity" (alias "selection_rate_parity"),
          "true_positive_rate_parity", "false_positive_rate_parity",
          "true_negative_rate_parity", "false_negative_rate_parity",
          "equalized_odds"
        - objective: str, quantity to maximize subject to the constraint. Options:
          "accuracy_score", "balanced_accuracy_score" (all constraints);
          "selection_rate", "true_positive_rate", "true_negative_rate"
          (all constraints except "equalized_odds")
        - grid_size: int, number of grid points in [0, 1] at which the constrained
          rate (for example, the selection rate) is evaluated
        - flip: bool, whether a group's decision may be flipped (predicting 1 below
          the threshold) when that improves the objective
        - prefit: bool, whether estimator is already fitted
        - predict_method: str, estimator method used to obtain scores:
          "auto" (the first of predict_proba, decision_function, predict that
          exists), "predict_proba", "decision_function", or "predict"
        """

    def fit(self, X, y, *, sensitive_features, **kwargs):
        """
        Fit the threshold optimizer.

        Parameters:
        - X: array-like, feature matrix
        - y: array-like, binary target values
        - sensitive_features: array-like, Series, or DataFrame of sensitive feature
          values; multiple columns are combined into intersectional groups
        - **kwargs: keyword arguments (for example sample_weight) passed to
          estimator.fit() when prefit=False

        Returns:
            self
        """

    def predict(self, X, *, sensitive_features, random_state=None):
        """
        Predict using the fitted group-specific thresholds.

        Parameters:
        - X: array-like, feature matrix
        - sensitive_features: sensitive feature values for the rows of X
        - random_state: int or RandomState, seeds the randomized choice between
          thresholds so that predictions are reproducible

        Returns:
            array of 0/1 predictions
        """

    @property
    def estimator_(self):
        """The fitted underlying estimator (the prefit estimator or a fitted clone)."""

    @property
    def interpolated_thresholder_(self):
        """The fitted InterpolatedThresholder holding the per-group thresholds."""
```

#### Usage Example

```python
from fairlearn.postprocessing import ThresholdOptimizer
from sklearn.linear_model import LogisticRegression

# Assumes X_train, y_train, X_test and the matching sensitive feature
# arrays (sensitive_features_train, sensitive_features_test) already exist.

# Train a base model
base_model = LogisticRegression()
base_model.fit(X_train, y_train)

# Create a threshold optimizer for demographic parity
threshold_optimizer = ThresholdOptimizer(
    estimator=base_model,
    constraints="demographic_parity",
    objective="accuracy_score",
    prefit=True,  # the model is already trained
    predict_method="predict_proba",  # threshold the positive-class probability
)

# Fit the thresholds (the base model itself is not refitted)
threshold_optimizer.fit(
    X_train, y_train,
    sensitive_features=sensitive_features_train,
)

# Make postprocessed predictions; sensitive features are required here too
fair_predictions = threshold_optimizer.predict(
    X_test,
    sensitive_features=sensitive_features_test,
    random_state=0,  # predictions can be randomized, so seed for reproducibility
)
```

### Plotting ThresholdOptimizer Results

`plot_threshold_optimizer` visualizes the solution found by a fitted `ThresholdOptimizer`. For `"demographic_parity"` it draws each group's selection-rate/error trade-off curve; for `"equalized_odds"` it draws each group's ROC curve. In both cases it marks the chosen solution.

```python { .api }
def plot_threshold_optimizer(threshold_optimizer, *, ax=None, show_plot=True):
    """
    Plot the trade-off curves and the chosen solution of a fitted ThresholdOptimizer.

    Parameters:
    - threshold_optimizer: fitted ThresholdOptimizer object
    - ax: matplotlib Axes, optional axes to draw on (defaults to the current axes)
    - show_plot: bool, whether to call plt.show() after drawing

    Returns:
        None; the plot is drawn on ax
    """
```

#### Plotting Example

```python
import matplotlib.pyplot as plt
from fairlearn.postprocessing import plot_threshold_optimizer

# After fitting threshold_optimizer: draw without showing, so the title
# is applied to the same figure before it is displayed
plot_threshold_optimizer(threshold_optimizer, show_plot=False)
plt.title("Fairness-Accuracy Trade-off")
plt.show()
```

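For demographic parity, each group's curve relates the selection rate to the error rate obtained by sweeping that group's threshold. The plain-Python sketch below computes such points for a single group; the helper name `tradeoff_points` and the toy data are hypothetical, and fairlearn's own computation (which also depends on `grid_size`) differs in detail.

```python
def tradeoff_points(scores, labels, thresholds):
    """Return (threshold, selection_rate, error_rate) for each candidate threshold."""
    n = len(scores)
    points = []
    for t in thresholds:
        preds = [int(s > t) for s in scores]
        selection_rate = sum(preds) / n
        error_rate = sum(p != y for p, y in zip(preds, labels)) / n
        points.append((t, selection_rate, error_rate))
    return points


# Made-up scores and labels for one group
scores = [0.9, 0.7, 0.4, 0.2]
labels = [1, 1, 0, 1]

for point in tradeoff_points(scores, labels, [0.0, 0.5, 1.0]):
    print(point)
# (0.0, 1.0, 0.25)
# (0.5, 0.5, 0.25)
# (1.0, 0.0, 0.75)
```

Under demographic parity, the optimizer effectively picks one selection rate shared by all groups and, for each group, the lowest-error way to reach it.
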
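## Constraint Options

### Demographic Parity

Requires equal selection rates (positive prediction rates) across groups. `ThresholdOptimizer` enforces the constraint on the data passed to `fit`; on new data the group rates are usually close but need not be exactly equal.

```python
from fairlearn.postprocessing import ThresholdOptimizer

# Using the string constraint; base_model is any scikit-learn classifier
optimizer = ThresholdOptimizer(
    estimator=base_model,
    constraints="demographic_parity",
    objective="accuracy_score",
)

# The constraint requires P(Y_hat=1 | A=a) to be equal for all groups a
```
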
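The quantity being equalized is easy to compute directly. This minimal sketch (plain Python, hypothetical helper names, made-up predictions) computes per-group selection rates and the gap between the largest and smallest, which is what `fairlearn.metrics.demographic_parity_difference` reports with its default settings.

```python
from collections import defaultdict


def selection_rates(y_pred, groups):
    """Selection rate P(Y_hat=1 | A=a) for each group a."""
    counts = defaultdict(lambda: [0, 0])  # group -> [positive predictions, total]
    for pred, group in zip(y_pred, groups):
        counts[group][0] += pred
        counts[group][1] += 1
    return {group: pos / total for group, (pos, total) in counts.items()}


def demographic_parity_gap(y_pred, groups):
    """Largest minus smallest group selection rate."""
    rates = selection_rates(y_pred, groups)
    return max(rates.values()) - min(rates.values())


# Made-up predictions for two groups of four samples each
y_pred = [1, 0, 1, 1, 0, 0, 0, 1]
groups = ["a", "a", "a", "a", "b", "b", "b", "b"]

print(selection_rates(y_pred, groups))         # {'a': 0.75, 'b': 0.25}
print(demographic_parity_gap(y_pred, groups))  # 0.5
```
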
### Equalized Odds

Requires equal true positive rates and equal false positive rates across groups.

```python
optimizer = ThresholdOptimizer(
    estimator=base_model,
    constraints="equalized_odds",
    objective="balanced_accuracy_score",  # "accuracy_score" is the only other option here
)

# The constraint requires both:
# - P(Y_hat=1 | Y=1, A=a) to be equal for all groups a (true positive rate)
# - P(Y_hat=1 | Y=0, A=a) to be equal for all groups a (false positive rate)
```

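Both rates come from each group's confusion counts. The sketch below (plain Python; helper names and data are hypothetical) computes them per group and returns the larger of the TPR gap and the FPR gap, the same summary `fairlearn.metrics.equalized_odds_difference` uses by default.

```python
from collections import defaultdict


def group_error_rates(y_true, y_pred, groups):
    """Return {group: (true_positive_rate, false_positive_rate)}.

    Assumes every group contains at least one positive and one negative label.
    """
    counts = defaultdict(lambda: {"tp": 0, "fn": 0, "fp": 0, "tn": 0})
    for truth, pred, group in zip(y_true, y_pred, groups):
        if truth == 1:
            counts[group]["tp" if pred == 1 else "fn"] += 1
        else:
            counts[group]["fp" if pred == 1 else "tn"] += 1
    return {
        group: (c["tp"] / (c["tp"] + c["fn"]), c["fp"] / (c["fp"] + c["tn"]))
        for group, c in counts.items()
    }


def equalized_odds_gap(y_true, y_pred, groups):
    """Larger of the between-group TPR gap and FPR gap."""
    rates = group_error_rates(y_true, y_pred, groups)
    tprs = [tpr for tpr, _ in rates.values()]
    fprs = [fpr for _, fpr in rates.values()]
    return max(max(tprs) - min(tprs), max(fprs) - min(fprs))


y_true = [1, 1, 0, 0, 1, 1, 0, 0]
y_pred = [1, 0, 1, 0, 1, 1, 0, 0]
groups = ["a"] * 4 + ["b"] * 4

print(group_error_rates(y_true, y_pred, groups))  # {'a': (0.5, 0.5), 'b': (1.0, 0.0)}
print(equalized_odds_gap(y_true, y_pred, groups))  # 0.5
```
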
### Equal Opportunity

Requires equal true positive rates across groups. `ThresholdOptimizer` has no `"equal_opportunity"` string; this criterion is selected with `constraints="true_positive_rate_parity"`.

```python
optimizer = ThresholdOptimizer(
    estimator=base_model,
    constraints="true_positive_rate_parity",
    objective="accuracy_score",
)

# The constraint requires P(Y_hat=1 | Y=1, A=a) to be equal for all groups a
```

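Equalizing true positive rates generally requires different cut-offs per group. The sketch below (plain Python; the helper names, scores, and the fixed target rate are hypothetical) finds, for each group, the highest threshold that reaches a common target TPR. `ThresholdOptimizer` goes further: it searches over the target level on its grid to maximize the objective and may mix two thresholds, whereas this sketch fixes the target by hand.

```python
def threshold_for_tpr(scores, labels, target_tpr):
    """Highest cut-off t such that predicting (score >= t) reaches target_tpr.

    Walks the positive examples from the highest score down and stops as soon
    as enough of them are accepted. Assumes 0 < target_tpr <= 1.
    """
    positive_scores = sorted((s for s, y in zip(scores, labels) if y == 1), reverse=True)
    for accepted, t in enumerate(positive_scores, start=1):
        if accepted / len(positive_scores) >= target_tpr:
            return t
    raise ValueError("target_tpr must be at most 1")


def true_positive_rate(scores, labels, t):
    # Fraction of positive examples with score >= t
    positives = [s for s, y in zip(scores, labels) if y == 1]
    return sum(s >= t for s in positives) / len(positives)


# Made-up scores and labels for two groups
group_a = ([0.9, 0.8, 0.3, 0.6], [1, 1, 1, 0])
group_b = ([0.5, 0.4, 0.2, 0.1], [1, 1, 0, 1])

t_a = threshold_for_tpr(*group_a, target_tpr=0.5)
t_b = threshold_for_tpr(*group_b, target_tpr=0.5)
print(t_a, t_b)  # 0.8 0.4
print(true_positive_rate(*group_a, t_a), true_positive_rate(*group_b, t_b))  # both 2/3
```
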
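## Objective Functions

The `objective` is the quantity `ThresholdOptimizer` maximizes on the data passed to `fit`, subject to the constraint. It is always a string naming one of the built-in objectives.

### Accuracy-based Objectives

These two objectives are available with every constraint, including `"equalized_odds"`.

```python
# Standard accuracy
ThresholdOptimizer(estimator=base_model, objective="accuracy_score")

# Balanced accuracy (average of the recall for each class)
ThresholdOptimizer(estimator=base_model, objective="balanced_accuracy_score")
```
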
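### Selection Rate Objective

Maximizes the overall selection rate. It is available with every constraint except `"equalized_odds"`.

```python
# Maximize the overall selection rate
ThresholdOptimizer(
    estimator=base_model,
    constraints="demographic_parity",
    objective="selection_rate",
)
```
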
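### True Positive and True Negative Rate Objectives

`"roc_auc_score"` is not a supported objective: the objective is evaluated on the thresholded 0/1 predictions, where a ranking metric such as ROC AUC does not apply. The remaining built-in objectives maximize the overall true positive rate or true negative rate and, like `"selection_rate"`, cannot be combined with `constraints="equalized_odds"`.

```python
# Maximize the overall true positive rate (recall)
ThresholdOptimizer(estimator=base_model, objective="true_positive_rate")

# Maximize the overall true negative rate (specificity)
ThresholdOptimizer(estimator=base_model, objective="true_negative_rate")
```
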
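The built-in objectives are simple functions of the confusion matrix of the thresholded predictions. This plain-Python sketch (the helper `objective_values` is hypothetical, not a fairlearn function) computes all five for one set of predictions; the imbalanced example shows why balanced accuracy can be the safer objective when one class is rare.

```python
def objective_values(y_true, y_pred):
    """Built-in objective names mapped to their values for 0/1 predictions.

    Assumes both classes occur in y_true.
    """
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    tpr = tp / (tp + fn)
    tnr = tn / (tn + fp)
    return {
        "accuracy_score": (tp + tn) / len(pairs),
        "balanced_accuracy_score": (tpr + tnr) / 2,
        "selection_rate": (tp + fp) / len(pairs),
        "true_positive_rate": tpr,
        "true_negative_rate": tnr,
    }


# Imbalanced toy labels; the predictions reject everyone
values = objective_values([1, 0, 0, 0, 0, 0, 0, 0], [0] * 8)
print(values)
# {'accuracy_score': 0.875, 'balanced_accuracy_score': 0.5, 'selection_rate': 0.0,
#  'true_positive_rate': 0.0, 'true_negative_rate': 1.0}
```

Rejecting everyone scores 87.5% accuracy here, while balanced accuracy stays at 0.5, the level of a trivial classifier.
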
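### Custom Objectives

`objective` does not accept a callable; passing a custom metric function is rejected when `fit` is called. To track an additional metric, optimize one of the built-in objectives and evaluate the custom metric on the postprocessed predictions, for example with `MetricFrame`:

```python
from fairlearn.metrics import MetricFrame
from sklearn.metrics import f1_score

# optimizer is a fitted ThresholdOptimizer with a built-in objective
y_pred = optimizer.predict(X_test, sensitive_features=A_test, random_state=0)

# Evaluate the custom metric overall and per group
custom = MetricFrame(metrics=f1_score, y_true=y_test, y_pred=y_pred, sensitive_features=A_test)
print(custom.overall, custom.by_group, sep="\n")
```
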
## Advanced Usage

### Working with Probability Predictions

`ThresholdOptimizer` thresholds a continuous score from the base model. With `predict_method="auto"` it uses `predict_proba` when the estimator provides it (taking the positive-class probability), otherwise `decision_function`, otherwise `predict`. Setting the method explicitly makes the choice independent of the default:

```python
from fairlearn.postprocessing import ThresholdOptimizer
from sklearn.ensemble import RandomForestClassifier

# Train a probabilistic model
rf_model = RandomForestClassifier(n_estimators=100, random_state=0)
rf_model.fit(X_train, y_train)

# Threshold the predicted probability of the positive class
optimizer = ThresholdOptimizer(
    estimator=rf_model,
    constraints="demographic_parity",
    prefit=True,
    predict_method="predict_proba",
)

optimizer.fit(X_train, y_train, sensitive_features=A_train)
```

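The `"auto"` fallback order (`predict_proba`, then `decision_function`, then `predict`, as documented by fairlearn) can be mimicked with `hasattr`. The sketch below is illustrative only: `resolve_scores`, `ProbaModel`, and `MarginModel` are hypothetical stand-ins, not fairlearn or scikit-learn code.

```python
def resolve_scores(estimator, X, predict_method="auto"):
    """Return one score per row, mimicking the "auto" fallback order.

    Hypothetical helper for illustration; not part of fairlearn.
    """
    if predict_method == "auto":
        for name in ("predict_proba", "decision_function", "predict"):
            if hasattr(estimator, name):
                predict_method = name
                break
    output = getattr(estimator, predict_method)(X)
    if predict_method == "predict_proba":
        # Keep the probability of the positive class (second column)
        return [row[1] for row in output]
    return list(output)


class ProbaModel:
    """Stand-in classifier whose probability of class 1 is the input itself."""

    def predict_proba(self, X):
        return [[1 - x, x] for x in X]


class MarginModel:
    """Stand-in classifier exposing only a decision function."""

    def decision_function(self, X):
        return [2 * x - 1 for x in X]


print(resolve_scores(ProbaModel(), [0.25, 0.75]))   # [0.25, 0.75]
print(resolve_scores(MarginModel(), [0.25, 0.75]))  # [-0.5, 0.5]
```
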
### Multiple Sensitive Features

When `sensitive_features` has several columns, each combination of values (for example, `gender` × `age_group`) is treated as a separate group:

```python
import pandas as pd

# One row per sample in X_train; gender_train and age_group_train are
# assumed to be sensitive-feature arrays aligned with the training data
sensitive_features_train = pd.DataFrame({
    "gender": gender_train,
    "age_group": age_group_train,
})

optimizer = ThresholdOptimizer(estimator=base_model, constraints="demographic_parity", prefit=True)
optimizer.fit(X_train, y_train, sensitive_features=sensitive_features_train)

# The test data must provide the same columns
predictions = optimizer.predict(X_test, sensitive_features=sensitive_features_test)
```

Each combination needs enough samples in the fitting data for its thresholds to be reliable, and a combination that never appears during fitting has no thresholds of its own.

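Intersectional groups multiply quickly, so it helps to check their sizes before fitting. This plain-Python sketch joins each row's values into one label; the helper `combine_sensitive_columns` and the comma-joined label format are assumptions for illustration, and fairlearn's internal labels may look different.

```python
from collections import Counter


def combine_sensitive_columns(*columns):
    """Join each row's values from several sensitive columns into one group label."""
    return [",".join(str(value) for value in row) for row in zip(*columns)]


gender = ["M", "F", "M", "F", "F"]
age_group = ["young", "old", "young", "old", "young"]

combined = combine_sensitive_columns(gender, age_group)
print(combined)           # ['M,young', 'F,old', 'M,young', 'F,old', 'F,young']
print(Counter(combined))  # group sizes; very small groups deserve attention
```
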
### Controlling Randomization

When a group's solution mixes two thresholds, `predict` draws some of that group's predictions at random, so repeated calls can disagree. Pass `random_state` to make the predictions reproducible:

```python
predictions = optimizer.predict(
    X_test,
    sensitive_features=A_test,
    random_state=42,  # same seed, same predictions
)
```

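The randomization can be pictured as follows: for each sample, the probability of predicting 1 is a weighted mix of two threshold decisions for its group. This is a minimal sketch, not fairlearn's code: the dictionary keys `p0`, `t0`, `p1`, `t1` are hypothetical (fairlearn's `interpolation_dict` stores mixing probabilities together with threshold operation objects), and it uses Python's `random` module rather than NumPy.

```python
import random


def randomized_predict(scores, groups, interpolation, random_state=None):
    """Predict 1 with probability p0*[score > t0] + p1*[score > t1] for each sample's group."""
    rng = random.Random(random_state)
    predictions = []
    for score, group in zip(scores, groups):
        spec = interpolation[group]
        prob = spec["p0"] * (score > spec["t0"]) + spec["p1"] * (score > spec["t1"])
        predictions.append(int(rng.random() < prob))
    return predictions


# Hypothetical solution: group "a" mixes a low and a high threshold 50/50
interpolation = {"a": {"p0": 0.5, "t0": 0.3, "p1": 0.5, "t1": 0.7}}
scores = [0.8, 0.5, 0.1, 0.5]
groups = ["a"] * 4

first = randomized_predict(scores, groups, interpolation, random_state=42)
second = randomized_predict(scores, groups, interpolation, random_state=42)
print(first == second)  # True: same seed, same predictions
```

Scores above both thresholds (0.8) are always selected and scores below both (0.1) never are; only scores between the two thresholds (0.5) depend on the random draw.
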
### Accessing Optimization Details

After fitting, the chosen per-group thresholds are stored in the `interpolated_thresholder_` attribute:

```python
# The fitted base estimator (the prefit estimator or a fitted clone)
print(optimizer.estimator_)

# Per-group thresholds and the probabilities used to mix them
thresholder = optimizer.interpolated_thresholder_
for group, solution in thresholder.interpolation_dict.items():
    print(group, solution)
```

## Integration with Assessment

Evaluate the postprocessed predictions with fairlearn's assessment tools:

```python
from fairlearn.metrics import MetricFrame, demographic_parity_difference, selection_rate
from sklearn.metrics import accuracy_score

# Get predictions from the fitted threshold optimizer
optimized_predictions = threshold_optimizer.predict(
    X_test, sensitive_features=A_test, random_state=0
)

# Assess accuracy and selection rate per group
fairness_frame = MetricFrame(
    metrics={
        "accuracy": accuracy_score,
        "selection_rate": selection_rate,
    },
    y_true=y_test,
    y_pred=optimized_predictions,
    sensitive_features=A_test,
)

print("Fairness assessment:")
print(fairness_frame.by_group)

dp_diff = demographic_parity_difference(
    y_test, optimized_predictions, sensitive_features=A_test
)
print(f"Demographic parity difference: {dp_diff}")
```

## Best Practices

### Model Selection

1. **Base Model Quality**: Start with a well-performing base model; postprocessing can only choose among the operating points that the model's scores allow.
2. **Probability Calibration**: Ensure the base model produces well-calibrated probabilities.
3. **Validation**: Fit the thresholds on a validation set that was not used to train the base model, and evaluate on a separate test set.

```python
# Recommended workflow
from fairlearn.postprocessing import ThresholdOptimizer
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Split data (with sensitive features A) into 60% train / 20% validation / 20% test
X_train, X_temp, y_train, y_temp, A_train, A_temp = train_test_split(
    X, y, A, test_size=0.4, random_state=0
)
X_val, X_test, y_val, y_test, A_val, A_test = train_test_split(
    X_temp, y_temp, A_temp, test_size=0.5, random_state=0
)

# Train and calibrate the base model
calibrated_model = CalibratedClassifierCV(LogisticRegression(), cv=3)
calibrated_model.fit(X_train, y_train)

# Optimize thresholds on the validation set
optimizer = ThresholdOptimizer(
    estimator=calibrated_model,
    constraints="demographic_parity",
    prefit=True,
    predict_method="predict_proba",
)
optimizer.fit(X_val, y_val, sensitive_features=A_val)

# Final evaluation on the held-out test set
final_predictions = optimizer.predict(X_test, sensitive_features=A_test, random_state=0)
```

### Constraint Selection

Choose constraints based on your fairness requirements:

- **Demographic Parity** (`"demographic_parity"`): when equal selection rates across groups are the goal, regardless of the true labels
- **Equal Opportunity** (`"true_positive_rate_parity"`): when qualified individuals (true positives) should be selected at equal rates across groups
- **Equalized Odds** (`"equalized_odds"`): when both false positive and false negative rates must match across groups

### Performance Monitoring

Monitor both accuracy and fairness after threshold optimization:

```python
from fairlearn.metrics import demographic_parity_difference
from sklearn.metrics import accuracy_score


def evaluate_postprocessed_model(optimizer, X_test, y_test, A_test, random_state=0):
    """Return accuracy, demographic parity difference, and the predictions."""
    predictions = optimizer.predict(
        X_test, sensitive_features=A_test, random_state=random_state
    )

    # Accuracy metric
    accuracy = accuracy_score(y_test, predictions)

    # Fairness metric
    dp_diff = demographic_parity_difference(
        y_test, predictions, sensitive_features=A_test
    )

    return {
        "accuracy": accuracy,
        "demographic_parity_difference": dp_diff,
        "predictions": predictions,
    }
```