# Imbalanced-learn

A comprehensive Python toolbox for dealing with imbalanced datasets in machine learning. It provides over-sampling, under-sampling, and combined methods that follow scikit-learn's estimator API. Through imbalanced-learn's own `Pipeline`, they also work with scikit-learn's model-selection tools, so resampling can be paired with any scikit-learn classifier when training on class-imbalanced data.

## Package Information

- **Package Name**: imbalanced-learn
- **Language**: Python
- **Installation**: `pip install imbalanced-learn`

## Core Imports

```python
import imblearn
```

Common imports for sampling algorithms:

```python
from imblearn.over_sampling import SMOTE, ADASYN, RandomOverSampler
from imblearn.under_sampling import RandomUnderSampler, TomekLinks, EditedNearestNeighbours
from imblearn.combine import SMOTEENN, SMOTETomek
```

Pipeline integration:

```python
from imblearn.pipeline import Pipeline, make_pipeline
```

## Basic Usage

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import SMOTE
from imblearn.pipeline import Pipeline

# Create an imbalanced example dataset (about 90% class 0, 10% class 1)
rng = np.random.default_rng(42)
X = rng.random((1000, 4))
y = rng.choice([0, 1], size=1000, p=[0.9, 0.1])

# Split data, keeping the class ratio in both splits
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

# Apply SMOTE over-sampling to the training set only
smote = SMOTE(random_state=42)
X_resampled, y_resampled = smote.fit_resample(X_train, y_train)

# Or use it in a pipeline: SMOTE runs during fit() and is skipped during predict()
pipeline = Pipeline([
    ('sampling', SMOTE(random_state=42)),
    ('classifier', RandomForestClassifier(random_state=42)),
])
pipeline.fit(X_train, y_train)
predictions = pipeline.predict(X_test)
```

## Architecture

Imbalanced-learn follows scikit-learn's design patterns and API conventions:

- **Base Classes**: All samplers inherit from `BaseSampler` and implement the `fit_resample()` method.
- **Sampling Strategy**: A consistent `sampling_strategy` parameter across all algorithms controls which classes are resampled and to what size.
- **Pipeline Integration**: scikit-learn's own `Pipeline` does not accept samplers as steps, so imbalanced-learn provides its own `Pipeline` class that applies samplers during `fit` only.
- **Validation**: Input validation follows scikit-learn's conventions.
- **Randomization**: A consistent `random_state` parameter for reproducibility.

## Capabilities

### Over-sampling Methods

Methods to increase minority class samples by generating synthetic examples or duplicating existing ones. Includes SMOTE variants, ADASYN, and random over-sampling approaches.

```python { .api }
class SMOTE:
    def __init__(self, sampling_strategy='auto', random_state=None, k_neighbors=5, n_jobs=None): ...
    def fit_resample(self, X, y): ...

class ADASYN:
    def __init__(self, sampling_strategy='auto', random_state=None, n_neighbors=5, n_jobs=None): ...
    def fit_resample(self, X, y): ...

class RandomOverSampler:
    def __init__(self, sampling_strategy='auto', random_state=None, shrinkage=None): ...
    def fit_resample(self, X, y): ...
```

[Over-sampling Methods](./over-sampling.md)

### Under-sampling Methods

Methods that reduce the number of majority-class samples, either at random or by removing redundant or noisy examples. Includes random under-sampling, prototype selection, and cleaning techniques.

```python { .api }
class RandomUnderSampler:
    def __init__(self, sampling_strategy='auto', random_state=None, replacement=False): ...
    def fit_resample(self, X, y): ...

class TomekLinks:
    def __init__(self, sampling_strategy='auto', n_jobs=None): ...
    def fit_resample(self, X, y): ...

class EditedNearestNeighbours:
    def __init__(self, sampling_strategy='auto', n_neighbors=3, kind_sel='all', n_jobs=None): ...
    def fit_resample(self, X, y): ...
```

[Under-sampling Methods](./under-sampling.md)

### Combination Methods

Methods that chain SMOTE over-sampling with a cleaning under-sampling step (Edited Nearest Neighbours or Tomek links). The first step balances the classes and the second removes noisy samples near the class boundary.

```python { .api }
class SMOTEENN:
    def __init__(self, sampling_strategy='auto', random_state=None, smote=None, enn=None, n_jobs=None): ...
    def fit_resample(self, X, y): ...

class SMOTETomek:
    def __init__(self, sampling_strategy='auto', random_state=None, smote=None, tomek=None, n_jobs=None): ...
    def fit_resample(self, X, y): ...
```

[Combination Methods](./combination.md)

### Ensemble Methods

Ensemble classifiers that resample the training data for each base estimator, so that every member of the ensemble is trained on a more balanced subset.

```python { .api }
class BalancedBaggingClassifier:
    def __init__(self, base_estimator=None, n_estimators=10, max_samples=1.0, max_features=1.0,
                 bootstrap=True, bootstrap_features=False, oob_score=False, warm_start=False,
                 sampling_strategy='auto', replacement=False, n_jobs=None, random_state=None, verbose=0): ...
    def fit(self, X, y, sample_weight=None): ...
    def predict(self, X): ...

class BalancedRandomForestClassifier:
    def __init__(self, n_estimators=100, criterion='gini', max_depth=None, min_samples_split=2,
                 min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features='auto',
                 max_leaf_nodes=None, min_impurity_decrease=0.0, bootstrap=True, oob_score=False,
                 sampling_strategy='auto', replacement=False, n_jobs=None, random_state=None,
                 verbose=0, warm_start=False, class_weight=None, ccp_alpha=0.0, max_samples=None): ...
    def fit(self, X, y, sample_weight=None): ...
    def predict(self, X): ...
```

[Ensemble Methods](./ensemble.md)

### Metrics for Imbalanced Datasets

Metrics that evaluate performance on each class separately, such as sensitivity, specificity, and their geometric mean. They are more informative than plain accuracy when one class dominates the data.

```python { .api }
def sensitivity_score(y_true, y_pred, labels=None, pos_label=1, average='binary',
                      sample_weight=None, zero_division='warn'): ...

def specificity_score(y_true, y_pred, labels=None, pos_label=1, average='binary',
                      sample_weight=None, zero_division='warn'): ...

def geometric_mean_score(y_true, y_pred, labels=None, pos_label=1, average='binary',
                         sample_weight=None, correction=0.0, zero_division='warn'): ...

def classification_report_imbalanced(y_true, y_pred, labels=None, target_names=None,
                                     sample_weight=None, digits=2, output_dict=False,
                                     zero_division='warn'): ...
```

[Metrics](./metrics.md)

### Pipeline Integration

An extended `Pipeline` that accepts samplers as intermediate steps alongside transformers and a final estimator. Samplers are applied only when the pipeline is fitted. During cross-validation, only the training folds are resampled and the evaluation folds are left untouched, which avoids data leakage.

```python { .api }
class Pipeline:
    def __init__(self, steps, memory=None, verbose=False): ...
    def fit(self, X, y=None, **fit_params): ...
    def predict(self, X, **predict_params): ...
    def fit_resample(self, X, y): ...

def make_pipeline(*steps, memory=None, verbose=False): ...
```

[Pipeline](./pipeline.md)

### Model Selection

Cross-validation and model selection tools adapted for imbalanced datasets, including instance hardness-based splitting strategies.

```python { .api }
class InstanceHardnessCV:
    def __init__(self, estimator, cv=5, n_jobs=None, verbose=0, pre_dispatch='2*n_jobs',
                 scoring=None, return_train_score=False): ...
    def fit(self, X, y, groups=None, **fit_params): ...
    def split(self, X, y, groups=None): ...
```

[Model Selection](./model-selection.md)
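
The signature above does not show how instance hardness itself is computed. The sketch below is conceptual only and uses plain scikit-learn rather than `InstanceHardnessCV`. It estimates instance hardness the way it is commonly defined: one minus the out-of-fold predicted probability of each sample's true class. The classifier and fold count are assumptions.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_predict

# Imbalanced toy problem (roughly 90% / 10%).
X, y = make_classification(n_samples=500, weights=[0.9, 0.1], random_state=0)

# Out-of-fold probabilities: each sample is scored by a model that never saw it.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
proba = cross_val_predict(LogisticRegression(max_iter=1000), X, y,
                          cv=cv, method="predict_proba")

# Instance hardness = 1 - estimated probability of the sample's true class.
# Columns of proba follow the sorted class labels (0, 1), so y indexes them.
hardness = 1.0 - proba[np.arange(len(y)), y]

# The hardest samples (largest values) are those the model gets most wrong.
hardest = np.argsort(hardness)[::-1][:10]
print(hardness[hardest])
```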

### Deep Learning Integration

Utilities for handling imbalanced datasets in deep learning frameworks, including balanced batch generators for Keras and TensorFlow.

```python { .api }
class BalancedBatchGenerator:
    def __init__(self, X, y, sampling_strategy='auto', random_state=None, **kwargs): ...
    def __call__(self): ...

def balanced_batch_generator(X, y, sampling_strategy='auto', batch_size=32, random_state=None): ...
```

[Deep Learning](./deep-learning.md)

### Utilities and Validation

Helper functions for validating sampling strategies, checking neighbor objects, and other utility operations.

```python { .api }
def check_sampling_strategy(sampling_strategy, y, sampling_type, **kwargs): ...

def check_neighbors_object(nn_name, nn_object, additional_neighbor=0): ...

def check_target_type(y, indicate_one_vs_all=False): ...

class FunctionSampler:
    def __init__(self, func=None, accept_sparse=True, kw_args=None, validate=True): ...
    def fit(self, X, y): ...
    def fit_resample(self, X, y): ...
```

[Utilities](./utilities.md)

### Datasets

Functions for creating imbalanced datasets and fetching benchmark datasets for testing and evaluation.

```python { .api }
def make_imbalance(X, y, sampling_strategy=None, random_state=None, verbose=False, **kwargs): ...

def fetch_datasets(data_home=None, filter_data=None, download_if_missing=True,
                   return_X_y=False, as_frame=False): ...
```

[Datasets](./datasets.md)

## Base Classes and Types

Core base classes and type definitions used throughout the package:

```python { .api }
class BaseSampler:
    def __init__(self): ...
    def fit(self, X, y): ...
    def fit_resample(self, X, y): ...
    def _validate_params(self): ...

class SamplerMixin:
    def fit_resample(self, X, y): ...

def is_sampler(estimator): ...
```