A Python package to assess and improve fairness of machine learning models
npx @tessl/cli install tessl/pypi-fairlearn@0.12.00
# Fairlearn

A Python library for assessing and improving the fairness of machine learning models. Fairlearn focuses on group fairness, covering both allocation harms and quality-of-service harms. It provides metrics for measuring disparities across groups and algorithms for mitigating them.

## Package Information

- **Package Name**: fairlearn
- **Language**: Python
- **Installation**: `pip install fairlearn`

## Core Imports

```python
import fairlearn
```

Common imports for specific functionality:

```python
# For fairness assessment
from fairlearn.metrics import MetricFrame, demographic_parity_difference, equalized_odds_difference

# For fairness mitigation algorithms
from fairlearn.reductions import ExponentiatedGradient, GridSearch, DemographicParity, EqualizedOdds

# For preprocessing and postprocessing
from fairlearn.preprocessing import CorrelationRemover
from fairlearn.postprocessing import ThresholdOptimizer

# For datasets
from fairlearn.datasets import fetch_adult, fetch_acs_income
```

## Basic Usage

```python
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from fairlearn.datasets import fetch_adult
from fairlearn.metrics import MetricFrame, demographic_parity_difference, selection_rate
from fairlearn.reductions import ExponentiatedGradient, DemographicParity

# Load sample dataset; with return_X_y=True, fetch_adult returns only (X, y)
X_raw, y_raw = fetch_adult(return_X_y=True, as_frame=True)
sensitive_features = X_raw["sex"]    # the sensitive feature is a column of X
X = pd.get_dummies(X_raw)            # one-hot encode the categorical columns
y = (y_raw == ">50K").astype(int)    # binary target: 1 if income is above 50K

X_train, X_test, y_train, y_test, A_train, A_test = train_test_split(
    X, y, sensitive_features, test_size=0.3, random_state=42, stratify=y
)

# Train baseline model
baseline_model = LogisticRegression(max_iter=1000, random_state=42)
baseline_model.fit(X_train, y_train)
y_pred_baseline = baseline_model.predict(X_test)

# Assess fairness using MetricFrame
mf = MetricFrame(
    metrics={"accuracy": accuracy_score, "selection_rate": selection_rate},
    y_true=y_test,
    y_pred=y_pred_baseline,
    sensitive_features=A_test,
)

print("Baseline model fairness:")
print(mf.by_group)
print(f"Demographic parity difference: {demographic_parity_difference(y_test, y_pred_baseline, sensitive_features=A_test)}")

# Mitigate unfairness using ExponentiatedGradient
constraint = DemographicParity()
mitigator = ExponentiatedGradient(baseline_model, constraint)
mitigator.fit(X_train, y_train, sensitive_features=A_train)
y_pred_mitigated = mitigator.predict(X_test)

print(f"Mitigated demographic parity difference: {demographic_parity_difference(y_test, y_pred_mitigated, sensitive_features=A_test)}")
```

## Architecture

Fairlearn is structured around four main components:

- **Assessment** (`fairlearn.metrics`): Tools to measure fairness through disaggregated metrics and fairness-specific functions
- **Mitigation**: Multiple approaches to improve fairness:
  - **Reduction methods** (`fairlearn.reductions`): Cast fairness as constrained optimization problems
  - **Preprocessing** (`fairlearn.preprocessing`): Transform features to reduce correlation with sensitive attributes
  - **Postprocessing** (`fairlearn.postprocessing`): Adjust model outputs to satisfy fairness constraints
  - **Adversarial** (`fairlearn.adversarial`): Neural network approaches using adversarial training
- **Datasets** (`fairlearn.datasets`): Standard datasets for benchmarking fairness algorithms
- **Utilities** (`fairlearn.utils`): Supporting functions for data processing and validation

This design supports a full fairness workflow, from assessment through mitigation, under several fairness definitions and constraints. It addresses two kinds of harm: allocation harms, where opportunities or resources are distributed unevenly across groups, and quality-of-service harms, where a system works less well for some groups than for others.

## Capabilities

### Fairness Assessment

Tools for measuring fairness through metrics disaggregated across sensitive groups. The `MetricFrame` class computes any metric overall and per group, while dedicated functions summarize specific fairness criteria as a single number. In recent releases `MetricFrame` takes keyword-only arguments and requires `sensitive_features`.

```python { .api }
class MetricFrame:
    def __init__(self, *, metrics, y_true, y_pred, sensitive_features,
                 control_features=None, sample_params=None,
                 n_boot=None, ci_quantiles=None, random_state=None): ...

def demographic_parity_difference(y_true, y_pred, *, sensitive_features,
                                  sample_weight=None): ...
def equalized_odds_difference(y_true, y_pred, *, sensitive_features,
                              sample_weight=None): ...
def equal_opportunity_difference(y_true, y_pred, *, sensitive_features,
                                 sample_weight=None): ...
```

[Fairness Assessment](./assessment.md)

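To make the quantities behind these functions concrete, here is a minimal pure-Python sketch of what the between-group "difference" metrics measure. It is not Fairlearn's implementation: the helper names (`rate_by_group`, `spread`, `demographic_parity_gap`, `conditional_rate_gap`, `equalized_odds_gap`) and the toy data are assumptions made for illustration, and the sketch ignores sample weights.

```python
from collections import defaultdict

def rate_by_group(values, groups):
    """Mean of `values` within each group (the disaggregation step)."""
    totals, counts = defaultdict(float), defaultdict(int)
    for v, g in zip(values, groups):
        totals[g] += v
        counts[g] += 1
    return {g: totals[g] / counts[g] for g in counts}

def spread(rates):
    """Largest between-group gap: highest group rate minus lowest."""
    return max(rates.values()) - min(rates.values())

def demographic_parity_gap(y_pred, groups):
    # Selection rate = fraction of positive predictions in each group.
    return spread(rate_by_group(y_pred, groups))

def conditional_rate_gap(y_true, y_pred, groups, label):
    # Keep only samples whose true label equals `label`:
    # label=1 compares true positive rates, label=0 false positive rates.
    kept = [(p, g) for t, p, g in zip(y_true, y_pred, groups) if t == label]
    return spread(rate_by_group([p for p, _ in kept], [g for _, g in kept]))

def equalized_odds_gap(y_true, y_pred, groups):
    # Worst case of the true-positive-rate gap and the false-positive-rate gap.
    return max(conditional_rate_gap(y_true, y_pred, groups, 1),
               conditional_rate_gap(y_true, y_pred, groups, 0))

# Hypothetical toy data: two groups of four samples each.
y_true = [1, 1, 0, 0, 1, 1, 0, 0]
y_pred = [1, 1, 1, 0, 1, 0, 0, 0]
groups = ["a", "a", "a", "a", "b", "b", "b", "b"]

print(rate_by_group(y_pred, groups))                    # per-group selection rates
print(demographic_parity_gap(y_pred, groups))           # selection-rate gap
print(conditional_rate_gap(y_true, y_pred, groups, 1))  # equal-opportunity (TPR) gap
print(equalized_odds_gap(y_true, y_pred, groups))       # equalized-odds gap
```

A value of 0 means the groups are treated identically under that criterion. In practice the library functions listed above should be preferred, since they also handle sample weights and input validation.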
### Reduction Algorithms

In-processing mitigation techniques that treat fair learning as a constrained optimization problem, with the fairness constraints enforced through Lagrange multipliers. The algorithms repeatedly refit the wrapped estimator on reweighted data to satisfy the constraints while preserving as much predictive performance as possible.

```python { .api }
class ExponentiatedGradient:
    def __init__(self, estimator, constraints, *, objective=None, eps=0.01,
                 max_iter=50, nu=None, eta0=2.0, run_linprog_step=True,
                 sample_weight_name="sample_weight"): ...
    def fit(self, X, y, **kwargs): ...  # pass sensitive_features=... as a keyword argument
    def predict(self, X, random_state=None): ...

class GridSearch:
    def __init__(self, estimator, constraints, *, selection_rule="tradeoff_optimization",
                 constraint_weight=0.5, grid_size=10, grid_limit=2.0,
                 grid_offset=None, grid=None, sample_weight_name="sample_weight"): ...
```

[Reduction Algorithms](./reductions.md)

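For readers who want the optimization problem spelled out, the following is a sketch of the usual formulation from the reductions literature (Agarwal et al., 2018). The notation is an assumption brought in for illustration and does not appear elsewhere in this document: $Q$ is a randomized classifier, $\mu(Q)$ is a vector of group-conditional statistics (for demographic parity, the per-group selection rates), and $M$ and $c$ encode the fairness constraints with their allowed slack.

```latex
% Constrained problem: minimize error subject to linear fairness constraints
\min_{Q \in \Delta} \ \operatorname{err}(Q)
\quad \text{subject to} \quad M \mu(Q) \le c

% Lagrangian with one multiplier per constraint
L(Q, \lambda) = \operatorname{err}(Q) + \lambda^{\top} \bigl( M \mu(Q) - c \bigr)

% Saddle-point problem over classifiers and bounded multipliers
\min_{Q \in \Delta} \ \max_{\lambda \ge 0,\ \lVert \lambda \rVert_1 \le B} \ L(Q, \lambda)
```

Roughly speaking, `ExponentiatedGradient` searches for an approximate saddle point by alternating between updating $\lambda$ and refitting the estimator on reweighted data, while `GridSearch` tries a fixed grid of $\lambda$ values and selects one of the resulting models.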
### Preprocessing

Preprocessing techniques that transform features to reduce correlation with sensitive attributes, addressing fairness at the data preparation stage.

```python { .api }
class CorrelationRemover:
    def __init__(self, *, sensitive_feature_ids, alpha=1.0): ...
    def fit(self, X, y=None): ...
    def transform(self, X): ...
    def fit_transform(self, X, y=None): ...
```

[Preprocessing](./preprocessing.md)

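The idea of removing linear correlation can be sketched for a single feature and a single numeric sensitive column. This is an illustrative simplification, not the library's code: the function names and data are hypothetical, and the blending role of `alpha` is assumed to follow the documented behavior (1.0 removes the linear correlation fully, 0.0 leaves the feature unchanged).

```python
def mean(xs):
    return sum(xs) / len(xs)

def remove_linear_correlation(feature, sensitive, alpha=1.0):
    """Subtract the part of `feature` that a least-squares line on the
    centered `sensitive` column explains, then blend with the original.
    Assumes `sensitive` is numeric and not constant."""
    s_mean, f_mean = mean(sensitive), mean(feature)
    centered_s = [s - s_mean for s in sensitive]
    # Least-squares slope of the feature on the centered sensitive column.
    slope = (sum(c * (f - f_mean) for c, f in zip(centered_s, feature))
             / sum(c * c for c in centered_s))
    filtered = [f - slope * c for f, c in zip(feature, centered_s)]
    # alpha=1.0 keeps only the filtered values; alpha=0.0 returns the input.
    return [alpha * x + (1 - alpha) * f for x, f in zip(filtered, feature)]

def covariance(xs, ys):
    mx, my = mean(xs), mean(ys)
    return sum((x - mx) * (y - my) for x, y in zip(xs, ys)) / len(xs)

# Hypothetical data: a binary sensitive column and a correlated feature.
sensitive = [0, 0, 0, 1, 1, 1]
income = [30.0, 35.0, 40.0, 50.0, 55.0, 60.0]

adjusted = remove_linear_correlation(income, sensitive)
print(covariance(income, sensitive))    # positive before the transform
print(covariance(adjusted, sensitive))  # approximately zero afterwards
```

Only linear dependence is removed this way; a downstream model may still recover the sensitive attribute through non-linear relationships or through other features.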
### Postprocessing

Postprocessing techniques that adjust the outputs of an already trained model to satisfy fairness constraints, without retraining it, by choosing group-specific decision thresholds. Because the thresholds depend on group membership, `predict` requires `sensitive_features`.

```python { .api }
class ThresholdOptimizer:
    def __init__(self, *, estimator=None, constraints="demographic_parity",
                 objective="accuracy_score", grid_size=1000,
                 flip=False, prefit=False): ...
    def fit(self, X, y, *, sensitive_features, sample_weight=None, **kwargs): ...
    def predict(self, X, *, sensitive_features, random_state=None): ...
```

[Postprocessing](./postprocessing.md)

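The following sketch illustrates why group-specific thresholds can equalize outcomes where a single threshold cannot. It is a deliberately simplified, deterministic stand-in: `ThresholdOptimizer` itself optimizes an objective under a constraint and may randomize between thresholds, which this code does not attempt. The function names and scores are hypothetical.

```python
def group_thresholds(scores, groups, target_rate):
    """Pick one score threshold per group so that roughly `target_rate`
    of each group is selected (assumes no tied scores within a group)."""
    thresholds = {}
    for g in set(groups):
        ranked = sorted((s for s, gg in zip(scores, groups) if gg == g),
                        reverse=True)
        k = max(1, round(target_rate * len(ranked)))  # members to select
        thresholds[g] = ranked[k - 1]                 # score of the k-th member
    return thresholds

def apply_thresholds(scores, groups, thresholds):
    # Group membership is needed at prediction time to look up the threshold.
    return [int(s >= thresholds[g]) for s, g in zip(scores, groups)]

# Hypothetical model scores: group "a" scores systematically higher than "b".
scores = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2]
groups = ["a", "a", "a", "a", "b", "b", "b", "b"]

single = apply_thresholds(scores, groups, {"a": 0.55, "b": 0.55})
per_group = apply_thresholds(
    scores, groups, group_thresholds(scores, groups, target_rate=0.5)
)
print(single)     # one shared threshold selects all of "a" and none of "b"
print(per_group)  # per-group thresholds select half of each group
```

The trade-off is that samples with the same score can receive different decisions depending on their group, which may or may not be acceptable in a given application.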
### Adversarial Training

Neural network-based mitigation using adversarial training: a predictor is trained alongside an adversary that tries to recover the sensitive features from the predictor's output, and the predictor is penalized whenever the adversary succeeds.

```python { .api }
class AdversarialFairnessClassifier:
    def __init__(self, *, backend="auto", predictor_model=None, adversary_model=None,
                 alpha=1.0, epochs=1, batch_size=32, shuffle=False, progress_updates=None,
                 skip_validation=False, callbacks=None, random_state=None): ...
    def fit(self, X, y, *, sensitive_features, sample_weight=None): ...
    def predict(self, X): ...
    def predict_proba(self, X): ...

class AdversarialFairnessRegressor:
    def __init__(self, *, backend="auto", predictor_model=None, adversary_model=None,
                 alpha=1.0, epochs=1, batch_size=32, shuffle=False, progress_updates=None,
                 skip_validation=False, callbacks=None, random_state=None): ...
```

[Adversarial Training](./adversarial.md)

### Datasets

Standard datasets commonly used for fairness research and benchmarking. All loaders share a consistent interface; sensitive features are ordinary columns of the returned data that you select yourself.

```python { .api }
def fetch_adult(*, cache=True, data_home=None, as_frame=True, return_X_y=False): ...
def fetch_acs_income(*, cache=True, data_home=None, as_frame=True, return_X_y=False,
                     states=None): ...
def fetch_bank_marketing(*, cache=True, data_home=None, as_frame=True, return_X_y=False): ...
def fetch_boston(*, cache=True, data_home=None, as_frame=True, return_X_y=False, warn=True): ...
```

[Datasets](./datasets.md)

## Common Types

```python { .api }
# Fairness constraint types for reduction algorithms
class Moment: ...
class ClassificationMoment(Moment): ...
class LossMoment(Moment): ...

# Common constraint implementations
class DemographicParity(ClassificationMoment): ...
class EqualizedOdds(ClassificationMoment): ...
class TruePositiveRateParity(ClassificationMoment): ...
class ErrorRateParity(ClassificationMoment): ...

# Loss functions for bounded group loss constraints
class SquareLoss: ...
class AbsoluteLoss: ...
class ZeroOneLoss(AbsoluteLoss): ...

# Custom warning for dataset fairness issues
class DataFairnessWarning(UserWarning): ...
```