# Utilities and Validation

Helper functions and classes for validating sampling strategies, checking neighbor objects, substituting docstrings, and creating custom samplers from plain functions.

## Overview

Imbalanced-learn provides utility functions that support the core sampling functionality. These utilities handle parameter validation, strategy checking, and neighbor object verification, and provide tools for creating custom sampling workflows.

### Key Features

- **Sampling strategy validation**: Checks sampling parameters and converts strategies to a per-class dictionary
- **Neighbor object validation**: Ensures k-NN objects are properly configured
- **Target type checking**: Validates target arrays for compatibility with samplers
- **Docstring utilities**: Tools for consistent documentation patterns
- **Functional sampling**: Create custom samplers from arbitrary functions
- **Type detection**: Helper to identify sampler objects

## Validation Functions

### check_sampling_strategy

#### check_sampling_strategy

```python { .api }
def check_sampling_strategy(
    sampling_strategy,
    y,
    sampling_type,
    **kwargs
) -> dict
```

Validate the sampling target for a sampler and convert it to a per-class dictionary.

**Parameters:**
- **sampling_strategy** (`float`, `str`, `dict`, `list` or `callable`): Sampling information to sample the data set
  - When `float`: Only available for binary targets, and the value must lie in (0, 1]. For **under-sampling methods**, it corresponds to the ratio α_us = N_m / N_rM, where N_m is the number of samples in the minority class and N_rM is the number of samples in the majority class after resampling. For **over-sampling methods**, it corresponds to the ratio α_os = N_rm / N_M, where N_rm is the number of samples in the minority class after resampling and N_M is the number of samples in the majority class
  - When `str`: Specify the class targeted by the resampling. Possible choices are: `'minority'`, `'majority'`, `'not minority'`, `'not majority'`, `'all'`, `'auto'`
  - When `dict`: The keys correspond to the targeted classes. The values correspond to the desired number of samples for each targeted class
  - When `list`: The list contains the targeted classes. Used only for **cleaning methods**
  - When `callable`: Function taking `y` and returning a `dict`. The keys correspond to the targeted classes. The values correspond to the desired number of samples for each class
- **y** (`ndarray` of shape `(n_samples,)`): The target array
- **sampling_type** (`{'over-sampling', 'under-sampling', 'clean-sampling'}`): The type of sampling
- **kwargs** (`dict`): Additional keyword arguments passed to `sampling_strategy` when it is a callable

**Returns:**
- **sampling_strategy_converted** (`dict`): The converted and validated sampling target. Returns a dictionary with the key being the class target and the value being the desired number of samples. For over-sampling, the stored value is the number of samples to generate for that class rather than its final size
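
The float case is easiest to follow with concrete counts. The sketch below is plain Python that reproduces only the ratio arithmetic, using α_us = N_m / N_rM for under-sampling and α_os = N_rm / N_M for over-sampling. `float_ratio_target` is a hypothetical helper written for illustration, not part of imbalanced-learn, and it assumes a binary target with distinct class sizes.

```python
from collections import Counter


def float_ratio_target(y, ratio, sampling_type):
    """Return (class_label, class size after resampling) for a binary target.

    Hypothetical helper: it only reproduces the arithmetic of the ratio
    definitions and is not the library's implementation.
    """
    if not 0 < ratio <= 1:
        raise ValueError("ratio must lie in (0, 1]")
    counts = Counter(y)
    if len(counts) != 2:
        raise ValueError("a float ratio only makes sense for binary targets")
    minority = min(counts, key=counts.get)
    majority = max(counts, key=counts.get)
    if sampling_type == "under-sampling":
        # alpha_us = N_m / N_rM  =>  N_rM = N_m / alpha_us
        return majority, int(counts[minority] / ratio)
    if sampling_type == "over-sampling":
        # alpha_os = N_rm / N_M  =>  N_rm = alpha_os * N_M
        return minority, int(ratio * counts[majority])
    raise ValueError(f"unsupported sampling_type: {sampling_type!r}")


y = [0] * 80 + [1] * 20
under = float_ratio_target(y, 0.5, "under-sampling")  # majority shrinks to 40
over = float_ratio_target(y, 0.5, "over-sampling")    # minority grows to 40
print(under, over)
```

With 80 majority and 20 minority samples, a ratio of 0.5 cuts the majority class to 40 when under-sampling and raises the minority class to 40 when over-sampling.
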

**Strategy Types:**

##### String Strategies
```python
from imblearn.utils import check_sampling_strategy

# Small imbalanced multiclass target used by the calls below
y = [0, 0, 0, 0, 1, 1, 2]

# Target minority class only (over-sampling)
strategy = check_sampling_strategy('minority', y, 'over-sampling')

# Target majority class only (under-sampling)
strategy = check_sampling_strategy('majority', y, 'under-sampling')

# Target all classes except minority
strategy = check_sampling_strategy('not minority', y, 'under-sampling')

# Target all classes except majority
strategy = check_sampling_strategy('not majority', y, 'over-sampling')

# Target all classes
strategy = check_sampling_strategy('all', y, 'over-sampling')

# Auto strategy (equivalent to 'not majority' for over-sampling, 'not minority' for under-sampling)
strategy = check_sampling_strategy('auto', y, 'over-sampling')
```
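
The string aliases can be read as class-selection rules. The following plain-Python sketch shows which classes each alias would pick for a given target; `targeted_classes` is a hypothetical function written for this illustration and does not compute the per-class sample counts that the real validator returns.

```python
from collections import Counter


def targeted_classes(y, strategy, sampling_type):
    """List the classes a string strategy selects (illustrative only)."""
    counts = Counter(y)
    minority = min(counts, key=counts.get)
    majority = max(counts, key=counts.get)
    if strategy == "auto":
        # 'auto' resolves differently depending on the kind of sampler
        strategy = "not majority" if sampling_type == "over-sampling" else "not minority"
    selectors = {
        "minority": lambda c: c == minority,
        "majority": lambda c: c == majority,
        "not minority": lambda c: c != minority,
        "not majority": lambda c: c != majority,
        "all": lambda c: True,
    }
    return sorted(c for c in counts if selectors[strategy](c))


y = [0] * 50 + [1] * 30 + [2] * 10
print(targeted_classes(y, "auto", "over-sampling"))      # [1, 2]
print(targeted_classes(y, "auto", "under-sampling"))     # [0, 1]
print(targeted_classes(y, "minority", "over-sampling"))  # [2]
```
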

##### Dictionary Strategies
```python
from imblearn.utils import check_sampling_strategy

# Specify exact number of samples per class
y = [0, 0, 0, 1, 1, 2]
strategy = {0: 100, 1: 80, 2: 60}  # Target samples for each class
validated = check_sampling_strategy(strategy, y, 'over-sampling')
```
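
For over-sampling, the dictionary states the desired final size of each class. The sketch below assumes that validation then records how many samples must be generated to reach that size, and reproduces that bookkeeping in plain Python. `samples_to_generate` is a hypothetical name, not a library function.

```python
from collections import Counter


def samples_to_generate(strategy, y):
    """Convert desired final class sizes into per-class generation counts."""
    counts = Counter(y)
    unknown = set(strategy) - set(counts)
    if unknown:
        raise ValueError(f"classes not present in y: {sorted(unknown)}")
    to_generate = {}
    for cls, target in strategy.items():
        if target < counts[cls]:
            # Over-sampling can only add samples, never remove them
            raise ValueError(
                f"class {cls}: target {target} is below current size {counts[cls]}"
            )
        to_generate[cls] = target - counts[cls]
    return to_generate


y = [0, 0, 0, 1, 1, 2]
extra = samples_to_generate({0: 100, 1: 80, 2: 60}, y)
print(extra)  # {0: 97, 1: 78, 2: 59}
```
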

##### Float Strategies (Binary Only)
```python
from imblearn.utils import check_sampling_strategy

# For binary classification - ratio between classes
y_binary = [0, 0, 0, 0, 1]  # Imbalanced binary: 4 majority, 1 minority

# Under-sampling: minority / resampled majority = 0.5, so the majority is cut to 2 samples
strategy = check_sampling_strategy(0.5, y_binary, 'under-sampling')

# Over-sampling: resampled minority / majority = 0.5, so the minority grows to 2 samples
# (a float must lie in (0, 1]; a value such as 1.5 raises ValueError)
strategy = check_sampling_strategy(0.5, y_binary, 'over-sampling')
```

##### Callable Strategies
```python
from collections import Counter

from imblearn.utils import check_sampling_strategy

def custom_strategy(y):
    """Custom sampling strategy function."""
    counter = Counter(y)
    # Cap every class at 80% of the majority class size
    target_size = int(0.8 * max(counter.values()))
    # Under-sampling cannot grow a class, so never request more than it has
    return {cls: min(count, target_size) for cls, count in counter.items()}

# Use callable strategy
y = [0, 0, 0, 0, 0, 1, 1, 2]
strategy = check_sampling_strategy(custom_strategy, y, 'under-sampling')
```

### check_neighbors_object

#### check_neighbors_object

```python { .api }
def check_neighbors_object(
    nn_name,
    nn_object,
    additional_neighbor=0
) -> object
```

Check that the given object can be used as a k-nearest-neighbors estimator.

**Parameters:**
- **nn_name** (`str`): The name associated with the object, used in the error message if one is raised
- **nn_object** (`int` or `KNeighborsMixin`): The object to be checked
- **additional_neighbor** (`int`, default=`0`): Number of extra neighbors to request, for algorithms that need more than `nn_object` neighbors

**Returns:**
- **nn_object** (`KNeighborsMixin`): The k-NN object

**Functionality:**
- If `nn_object` is an integer, creates a `NearestNeighbors` object with `n_neighbors=nn_object + additional_neighbor`
- If `nn_object` is already a neighbors object, returns a clone of it
- Validates that the object has the required k-NN interface

**Usage Examples:**
```python
from imblearn.utils import check_neighbors_object
from sklearn.neighbors import NearestNeighbors

# From integer - creates NearestNeighbors(n_neighbors=5)
nn = check_neighbors_object('k_neighbors', 5)

# From existing object - clones it
existing_nn = NearestNeighbors(n_neighbors=3, metric='manhattan')
nn = check_neighbors_object('k_neighbors', existing_nn)

# With additional neighbors (for algorithms that need k+1 neighbors)
nn = check_neighbors_object('k_neighbors', 5, additional_neighbor=1)  # Creates with 6 neighbors
```

### check_target_type

#### check_target_type

```python { .api }
def check_target_type(
    y,
    indicate_one_vs_all=False
) -> ndarray | tuple[ndarray, bool]
```

Check that the target type conforms to what the current samplers support.

**Parameters:**
- **y** (`ndarray`): The array containing the target
- **indicate_one_vs_all** (`bool`, default=`False`): Whether to also report if the targets are encoded in a one-vs-all fashion

**Returns:**
- **y** (`ndarray`): The returned target
- **is_one_vs_all** (`bool`, optional): Indicates whether the target was originally encoded in a one-vs-all fashion. Only returned if `indicate_one_vs_all=True`

**Target Type Handling:**
- **Binary**: Passes through unchanged
- **Multiclass**: Passes through unchanged
- **Multilabel-indicator**: Converts to multiclass if it represents one-vs-all encoding (each sample has exactly one label)
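
The multilabel-indicator case amounts to taking the position of the active column in each row. Below is a minimal plain-Python sketch of that conversion. `one_vs_all_to_multiclass` is a hypothetical helper for illustration, and with NumPy arrays the same idea is usually written as `y.argmax(axis=1)`.

```python
def one_vs_all_to_multiclass(rows):
    """Collapse one-hot style rows into class indices (illustrative only)."""
    labels = []
    for row in rows:
        if sum(row) > 1:
            # More than one active label means a genuine multilabel target
            raise ValueError("multilabel targets are not supported")
        # Index of the first largest entry, i.e. an argmax over the row
        labels.append(max(range(len(row)), key=row.__getitem__))
    return labels


y_ovr = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]]
y_multiclass = one_vs_all_to_multiclass(y_ovr)
print(y_multiclass)  # [0, 1, 2, 0]
```
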

**Example:**
```python
import numpy as np
from imblearn.utils import check_target_type

# Regular multiclass target
y_multiclass = np.array([0, 1, 2, 0, 1, 2])
y_checked = check_target_type(y_multiclass)

# One-vs-all encoded (multilabel-indicator that's actually multiclass)
y_ovr = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]])
y_converted, is_ovr = check_target_type(y_ovr, indicate_one_vs_all=True)
# y_converted becomes [0, 1, 2, 0], is_ovr is True

# True multilabel (not supported - raises error)
y_multilabel = np.array([[1, 1, 0], [0, 1, 1], [1, 0, 1]])
# check_target_type(y_multilabel)  # Raises ValueError
```

## Documentation Utilities

### Substitution

#### Substitution

```python { .api }
class Substitution:
    def __init__(self, *args, **kwargs): ...
    def __call__(self, obj): ...
```

Decorate the docstring of a function or class to perform string substitution on it.

**Parameters:**
- **args** (`tuple`): Positional arguments for substitution (mutually exclusive with kwargs)
- **kwargs** (`dict`): Keyword arguments for substitution (mutually exclusive with args)

**Usage:**
The decorator performs string formatting on the docstring of the decorated object using the provided arguments, filling `{name}`-style placeholders.
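
To make the mechanism concrete, here is a minimal stand-in that applies `str.format` to a docstring at decoration time. `DocSubstitution` is a hypothetical class written for this sketch, supports keyword placeholders only, and should not be read as the library's source.

```python
class DocSubstitution:
    """Minimal stand-in for a docstring-substitution decorator (hypothetical)."""

    def __init__(self, **params):
        self.params = params

    def __call__(self, obj):
        # Docstrings can be None (for example under `python -OO`), so guard first
        if obj.__doc__:
            obj.__doc__ = obj.__doc__.format(**self.params)
        return obj


@DocSubstitution(unit="seconds")
def wait(duration):
    """Pause for `duration` {unit}."""


print(wait.__doc__)  # Pause for `duration` seconds.
```

Because the docstring goes through `str.format`, any literal braces it contains have to be doubled (`{{` and `}}`) to survive the substitution.
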

**Example:**
```python
from imblearn.utils import Substitution

# Define reusable docstring components
_random_state_docstring = """random_state : int, RandomState instance, default=None
        Control the randomization of the algorithm.

        - If int, random_state is the seed used by the random number generator;
        - If RandomState instance, random_state is the random number generator;
        - If None, the random number generator is the RandomState instance used
          by np.random."""

# Use as decorator with keyword arguments
@Substitution(random_state=_random_state_docstring)
def my_function(X, y, random_state=None):
    """Apply sampling to dataset.

    Parameters
    ----------
    X : array-like
        Input data.
    y : array-like
        Target values.
    {random_state}

    Returns
    -------
    X_resampled, y_resampled : arrays
        Resampled data and targets.
    """
    pass

# A named placeholder also works for the summary line
# (keyword placeholders are the form imbalanced-learn uses internally)
@Substitution(summary="This is a substituted description")
def another_function():
    """{summary}

    More details here.
    """
    pass
```

## Custom Sampling

### FunctionSampler

#### FunctionSampler

```python { .api }
class FunctionSampler:
    def __init__(
        self,
        *,
        func=None,
        accept_sparse=True,
        kw_args=None,
        validate=True
    ): ...
    def fit(self, X, y): ...
    def fit_resample(self, X, y): ...
```

Construct a sampler that delegates resampling to an arbitrary callable.

**Parameters:**
- **func** (`callable`, default=`None`): The callable to use for the resampling. It is called as `func(X, y, **kw_args)` and must return the resampled `X` and `y`. If `func` is `None`, the identity function is used
- **accept_sparse** (`bool`, default=`True`): Whether sparse input is supported. By default, sparse inputs are supported
- **kw_args** (`dict`, default=`None`): The keyword arguments expected by `func`
- **validate** (`bool`, default=`True`): Whether to validate `X` and `y`. Turning validation off allows `FunctionSampler` to be used with any type of data

**Attributes:**
- **sampling_strategy_** (`dict`): Dictionary containing the information to sample the dataset. The keys correspond to the class labels from which to sample and the values are the number of samples to sample
- **n_features_in_** (`int`): Number of features in the input dataset
- **feature_names_in_** (`ndarray` of shape `(n_features_in_,)`): Names of features seen during `fit`. Defined only when `X` has feature names that are all strings

**Methods:**

##### fit

```python
def fit(self, X, y) -> FunctionSampler
```

Check inputs and statistics of the sampler.

##### fit_resample

```python
def fit_resample(self, X, y) -> tuple[ndarray, ndarray]
```

Resample the dataset using the provided function.
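
Conceptually, a function-based sampler just forwards `X`, `y`, and the stored keyword arguments to the wrapped callable. The toy class below sketches that delegation in plain Python. `MiniFunctionSampler` and `keep_every_nth` are hypothetical names, and the sketch skips the input validation and estimator plumbing a real sampler performs.

```python
class MiniFunctionSampler:
    """Toy illustration of wrapping a callable as a sampler (hypothetical)."""

    def __init__(self, func=None, kw_args=None):
        self.func = func
        self.kw_args = kw_args

    def fit_resample(self, X, y):
        if self.func is None:
            # No callable supplied: behave as the identity
            return X, y
        return self.func(X, y, **(self.kw_args or {}))


def keep_every_nth(X, y, step=2):
    """Keep every `step`-th sample."""
    return X[::step], y[::step]


X = [[0], [1], [2], [3], [4], [5]]
y = [0, 0, 0, 1, 1, 1]
sampler = MiniFunctionSampler(func=keep_every_nth, kw_args={"step": 3})
X_res, y_res = sampler.fit_resample(X, y)
print(X_res, y_res)  # [[0], [3]] [0, 1]
```
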

**Basic Usage:**
```python
from imblearn import FunctionSampler
import numpy as np

# Toy data: 20 samples, 2 features, imbalanced binary target
X = np.arange(40).reshape(20, 2)
y = np.array([0, 0, 0, 1] * 5)

# Simple function to select first 10 samples
def select_first_ten(X, y):
    return X[:10], y[:10]

sampler = FunctionSampler(func=select_first_ten)
X_res, y_res = sampler.fit_resample(X, y)
```

**Using Existing Samplers:**
```python
from collections import Counter

from imblearn import FunctionSampler
from imblearn.under_sampling import RandomUnderSampler
from sklearn.datasets import make_classification

# Synthetic imbalanced data so the example runs on its own
X, y = make_classification(n_samples=200, weights=[0.9, 0.1], random_state=0)

def custom_undersampling(X, y, sampling_strategy, random_state):
    """Custom function using existing sampler."""
    return RandomUnderSampler(
        sampling_strategy=sampling_strategy,
        random_state=random_state
    ).fit_resample(X, y)

# Create functional sampler
sampler = FunctionSampler(
    func=custom_undersampling,
    kw_args={
        'sampling_strategy': 'auto',
        'random_state': 42
    }
)

X_res, y_res = sampler.fit_resample(X, y)
print(f'Resampled distribution: {Counter(y_res)}')
```

**Advanced Custom Logic:**
```python
from collections import Counter

import numpy as np
from imblearn import FunctionSampler
from sklearn.cluster import KMeans

def cluster_based_sampling(X, y, n_clusters=3, random_state=None):
    """Custom sampling based on clustering."""
    # Seeded generator so the per-cluster draws honor random_state
    rng = np.random.RandomState(random_state)

    # Get class distribution
    counter = Counter(y)
    majority_class = max(counter, key=counter.get)
    minority_classes = [cls for cls in counter.keys() if cls != majority_class]

    # Keep all minority class samples
    minority_mask = np.isin(y, minority_classes)
    X_minority = X[minority_mask]
    y_minority = y[minority_mask]

    # Cluster majority class and sample from each cluster
    majority_mask = y == majority_class
    X_majority = X[majority_mask]

    # Apply clustering
    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state)
    clusters = kmeans.fit_predict(X_majority)

    # Sample from each cluster (at least one sample per non-empty cluster)
    target_per_cluster = max(1, len(y_minority) // n_clusters)
    X_sampled_list = []

    for cluster_id in range(n_clusters):
        cluster_mask = clusters == cluster_id
        cluster_indices = np.where(cluster_mask)[0]

        if len(cluster_indices) > 0:
            selected = rng.choice(
                cluster_indices,
                size=min(target_per_cluster, len(cluster_indices)),
                replace=False
            )
            X_sampled_list.append(X_majority[selected])

    # Combine results
    X_majority_sampled = np.vstack(X_sampled_list)
    y_majority_sampled = np.full(len(X_majority_sampled), majority_class)

    X_resampled = np.vstack([X_minority, X_majority_sampled])
    y_resampled = np.concatenate([y_minority, y_majority_sampled])

    return X_resampled, y_resampled

# Use custom cluster-based sampling (X and y are existing NumPy arrays)
sampler = FunctionSampler(
    func=cluster_based_sampling,
    kw_args={'n_clusters': 5, 'random_state': 42}
)

X_res, y_res = sampler.fit_resample(X, y)
```

## Type Detection

### is_sampler

#### is_sampler

```python { .api }
def is_sampler(estimator) -> bool
```

Return True if the given estimator is a sampler, False otherwise.

**Parameters:**
- **estimator** (`object`): Estimator to test

**Returns:**
- **is_sampler** (`bool`): True if estimator is a sampler, otherwise False

**Detection Logic:**
1. Checks for `_estimator_type == "sampler"` attribute
2. Checks for `sampler_tags` in estimator tags
3. Returns False if neither condition is met
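
The first check in the list can be sketched with ordinary duck typing. The snippet below covers only the `_estimator_type` marker and leaves out the tag-based check, whose exact form is an implementation detail not shown here. `looks_like_sampler`, `ToySampler`, and `ToyClassifier` are hypothetical names for illustration.

```python
def looks_like_sampler(estimator):
    """Duck-typed check based on the `_estimator_type` marker (illustrative)."""
    return getattr(estimator, "_estimator_type", None) == "sampler"


class ToySampler:
    _estimator_type = "sampler"

    def fit_resample(self, X, y):
        return X, y


class ToyClassifier:
    _estimator_type = "classifier"


print(looks_like_sampler(ToySampler()))     # True
print(looks_like_sampler(ToyClassifier()))  # False
print(looks_like_sampler(object()))         # False
```
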

**Example:**
```python
from imblearn import FunctionSampler
from imblearn.utils import is_sampler
from imblearn.over_sampling import SMOTE
from sklearn.ensemble import RandomForestClassifier

# Test imblearn sampler
smote = SMOTE()
print(is_sampler(smote))  # True

# Test sklearn classifier
rf = RandomForestClassifier()
print(is_sampler(rf))  # False

# Test custom sampler
custom_sampler = FunctionSampler()
print(is_sampler(custom_sampler))  # True
```

## Integration Patterns

### Pipeline Integration

```python
from imblearn.pipeline import Pipeline
from imblearn import FunctionSampler
from sklearn.ensemble import RandomForestClassifier

# Create custom sampling function
def outlier_removal_sampling(X, y, contamination=0.1):
    """Remove outliers before standard sampling."""
    from sklearn.ensemble import IsolationForest
    from imblearn.under_sampling import RandomUnderSampler

    # Remove outliers: IsolationForest labels inliers as 1 and outliers as -1
    iso_forest = IsolationForest(contamination=contamination, random_state=42)
    inlier_mask = iso_forest.fit_predict(X) == 1

    X_clean = X[inlier_mask]
    y_clean = y[inlier_mask]

    # Apply standard sampling
    sampler = RandomUnderSampler(random_state=42)
    return sampler.fit_resample(X_clean, y_clean)

# Use in pipeline (X, y and X_test are existing NumPy arrays)
pipeline = Pipeline([
    ('outlier_sampling', FunctionSampler(func=outlier_removal_sampling)),
    ('classifier', RandomForestClassifier())
])

pipeline.fit(X, y)
predictions = pipeline.predict(X_test)
```

### Cross-Validation Compatibility

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from imblearn import FunctionSampler
from imblearn.over_sampling import SMOTE
from imblearn.pipeline import Pipeline
from imblearn.utils import check_sampling_strategy

# Validate strategy before cross-validation
def safe_sampler_factory(strategy_type='auto'):
    """Create sampler with validated strategy."""
    def create_sampler(X, y):
        # Validate strategy for current fold (raises ValueError if unusable)
        check_sampling_strategy(strategy_type, y, 'over-sampling')

        # Pass the original strategy on: the validated dict stores samples to
        # generate, which SMOTE would misread as final class sizes
        return SMOTE(sampling_strategy=strategy_type, random_state=42).fit_resample(X, y)

    return FunctionSampler(func=create_sampler)

# Use in cross-validation (X and y are existing NumPy arrays)
sampler = safe_sampler_factory('not majority')
pipeline = Pipeline([('sampling', sampler), ('classifier', RandomForestClassifier())])
scores = cross_val_score(pipeline, X, y, cv=5)
```

## Best Practices

### Validation Best Practices

1. **Always validate sampling strategies** before creating samplers
2. **Use check_neighbors_object** for consistent k-NN parameter handling
3. **Check target types** early to catch incompatible data formats
4. **Validate custom functions** thoroughly before using in FunctionSampler

### Custom Sampler Guidelines

1. **Keep functions pure**: Avoid side effects in sampling functions
2. **Handle edge cases**: Check for empty classes, insufficient samples
3. **Document parameters**: Use clear docstrings and parameter validation
4. **Test thoroughly**: Verify behavior with different data distributions
5. **Consider performance**: Optimize for large datasets when necessary
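
As one way to act on the edge-case guideline, a custom sampling function can check the class counts before doing any work. The helper below is a plain-Python sketch. `check_class_counts` and its `min_samples_per_class` threshold are hypothetical and should be adapted to what the sampling logic actually needs.

```python
from collections import Counter


def check_class_counts(y, min_samples_per_class=2):
    """Raise early if the target cannot support resampling (hypothetical helper)."""
    counts = Counter(y)
    if len(counts) < 2:
        raise ValueError("need at least two classes to resample")
    too_small = {cls: n for cls, n in counts.items() if n < min_samples_per_class}
    if too_small:
        raise ValueError(f"classes with too few samples: {too_small}")
    return counts


counts = check_class_counts([0, 0, 0, 1, 1])
print(counts[0], counts[1])  # 3 2
```
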

### Error Handling

```python
from imblearn.over_sampling import SMOTE
from imblearn.utils import check_sampling_strategy, check_target_type

def robust_sampling_pipeline(X, y, sampling_strategy='auto'):
    """Example of robust sampling with proper validation."""
    try:
        # Validate target type
        y_validated = check_target_type(y)

        # Validate sampling strategy (raises ValueError if unusable)
        check_sampling_strategy(sampling_strategy, y_validated, 'over-sampling')

        # Apply sampling with the original strategy: the validated dict stores
        # samples to generate, which SMOTE would misread as final class sizes
        sampler = SMOTE(sampling_strategy=sampling_strategy)
        return sampler.fit_resample(X, y_validated)

    except ValueError as e:
        print(f"Validation error: {e}")
        # Fallback to identity transformation
        return X, y
    except Exception as e:
        print(f"Sampling error: {e}")
        return X, y

# Use robust pipeline (X and y are existing NumPy arrays)
X_res, y_res = robust_sampling_pipeline(X, y, 'not majority')
```