
# Utilities and Validation

Helper functions and classes for validating sampling strategies, checking neighbor objects, docstring substitution, and creating custom samplers with functional approaches.

## Overview

Imbalanced-learn provides comprehensive utility functions that support the core sampling functionality. These utilities handle parameter validation, strategy checking, neighbor object verification, and provide tools for creating custom sampling workflows.

### Key Features

- **Sampling strategy validation**: Robust checking of sampling parameters and strategies
- **Neighbor object validation**: Ensures k-NN objects are properly configured
- **Target type checking**: Validates target arrays for compatibility with samplers
- **Docstring utilities**: Tools for consistent documentation patterns
- **Functional sampling**: Create custom samplers from arbitrary functions
- **Type detection**: Helper to identify sampler objects

## Validation Functions

### check_sampling_strategy

```python { .api }
def check_sampling_strategy(
    sampling_strategy,
    y,
    sampling_type,
    **kwargs
) -> dict
```

Sampling target validation for samplers.

**Parameters:**
- **sampling_strategy** (`float`, `str`, `dict`, `list` or `callable`): Sampling information used to resample the data set
  - When `float`: For **under-sampling methods**, it corresponds to the ratio α_us = N_m / N_rM, where N_m is the number of samples in the minority class and N_rM is the number of samples in the majority class after resampling. For **over-sampling methods**, it corresponds to the ratio α_os = N_rm / N_M, where N_rm is the number of samples in the minority class after resampling and N_M is the number of samples in the majority class. A `float` is only supported for binary classification
  - When `str`: Specify the class targeted by the resampling. Possible choices are: `'minority'`, `'majority'`, `'not minority'`, `'not majority'`, `'all'`, `'auto'`
  - When `dict`: The keys correspond to the targeted classes. The values correspond to the desired number of samples for each targeted class
  - When `list`: The list contains the targeted classes. Used only for **cleaning methods**
  - When `callable`: Function taking `y` and returning a `dict`. The keys correspond to the targeted classes. The values correspond to the desired number of samples for each class
- **y** (`ndarray` of shape `(n_samples,)`): The target array
- **sampling_type** (`{'over-sampling', 'under-sampling', 'clean-sampling'}`): The type of sampling
- **kwargs** (`dict`): Additional keyword arguments passed to `sampling_strategy` when it is a callable

**Returns:**
- **sampling_strategy_converted** (`dict`): The converted and validated sampling target. Returns a dictionary with the key being the class target and the value being the desired number of samples
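The float ratios can be checked by hand. The sketch below is not the library's implementation; it only works through the arithmetic, assuming α_us = N_m / N_rM for under-sampling and α_os = N_rm / N_M for over-sampling. Both helper names are hypothetical and introduced for illustration.

```python
from collections import Counter


def majority_size_after_under_sampling(y, alpha_us):
    """Hypothetical helper: N_rM = N_m / alpha_us, truncated to an integer."""
    counts = Counter(y)
    n_minority = min(counts.values())
    return int(n_minority / alpha_us)


def minority_size_after_over_sampling(y, alpha_os):
    """Hypothetical helper: N_rm = alpha_os * N_M, truncated to an integer."""
    counts = Counter(y)
    n_majority = max(counts.values())
    return int(alpha_os * n_majority)


y = [0] * 80 + [1] * 20  # 80 majority samples, 20 minority samples

print(majority_size_after_under_sampling(y, 0.5))  # 40
print(minority_size_after_over_sampling(y, 0.5))   # 40
```

With a ratio of 0.5, under-sampling shrinks the majority class to twice the minority size, while over-sampling grows the minority class to half the majority size.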

**Strategy Types:**

##### String Strategies

```python
from imblearn.utils import check_sampling_strategy

y = [0, 0, 0, 0, 1, 1, 2]  # class 0 is the majority, class 2 the minority

# Target minority class only (over-sampling)
strategy = check_sampling_strategy('minority', y, 'over-sampling')

# Target majority class only (under-sampling)
strategy = check_sampling_strategy('majority', y, 'under-sampling')

# Target all classes except minority
strategy = check_sampling_strategy('not minority', y, 'under-sampling')

# Target all classes except majority
strategy = check_sampling_strategy('not majority', y, 'over-sampling')

# Target all classes
strategy = check_sampling_strategy('all', y, 'over-sampling')

# Auto strategy (equivalent to 'not majority' for over-sampling, 'not minority' for under-sampling)
strategy = check_sampling_strategy('auto', y, 'over-sampling')
```
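To make the string aliases concrete, here is a minimal pure-Python sketch of which class labels each alias selects. `targeted_classes` is a hypothetical helper, not part of imbalanced-learn, and it assumes the data has a single majority and a single minority class.

```python
from collections import Counter


def targeted_classes(alias, y, sampling_type):
    """Hypothetical helper: return the set of class labels an alias selects."""
    counts = Counter(y)
    majority = max(counts, key=counts.get)
    minority = min(counts, key=counts.get)
    if alias == "auto":
        # 'auto' resolves to 'not majority' when over-sampling,
        # and to 'not minority' otherwise
        alias = "not majority" if sampling_type == "over-sampling" else "not minority"
    selectors = {
        "minority": {minority},
        "majority": {majority},
        "not minority": set(counts) - {minority},
        "not majority": set(counts) - {majority},
        "all": set(counts),
    }
    return selectors[alias]


y = [0, 0, 0, 0, 1, 1, 2]
print(targeted_classes("not majority", y, "over-sampling"))  # {1, 2}
print(targeted_classes("auto", y, "under-sampling"))         # {0, 1}
```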

##### Dictionary Strategies

```python
from imblearn.utils import check_sampling_strategy

# Specify exact number of samples per class
y = [0, 0, 0, 1, 1, 2]
strategy = {0: 100, 1: 80, 2: 60}  # Target samples for each class
validated = check_sampling_strategy(strategy, y, 'over-sampling')
```

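##### Float Strategies (Binary Only)

```python
from imblearn.utils import check_sampling_strategy

# For binary classification - ratio between classes
y_binary = [0, 0, 0, 0, 1]  # Imbalanced binary: 4 majority, 1 minority

# Under-sampling: ratio = minority / majority after resampling,
# so 0.5 keeps 1 / 0.5 = 2 majority samples
strategy = check_sampling_strategy(0.5, y_binary, 'under-sampling')

# Over-sampling: ratio = minority after resampling / majority,
# so 0.5 grows the minority class to 0.5 * 4 = 2 samples.
# A float strategy is expected to lie in (0, 1]; larger values are rejected.
strategy = check_sampling_strategy(0.5, y_binary, 'over-sampling')
```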

##### Callable Strategies

```python
from collections import Counter

from imblearn.utils import check_sampling_strategy

def custom_strategy(y):
    """Custom sampling strategy function."""
    counter = Counter(y)
    # Cap every class at 80% of the majority class size. Classes that are
    # already smaller keep their current count, because under-sampling
    # cannot request more samples than a class contains.
    target_size = int(0.8 * max(counter.values()))
    return {cls: min(count, target_size) for cls, count in counter.items()}

# Use callable strategy
y = [0, 0, 0, 0, 0, 1, 1, 1, 2, 2]
strategy = check_sampling_strategy(custom_strategy, y, 'under-sampling')
```

### check_neighbors_object

```python { .api }
def check_neighbors_object(
    nn_name,
    nn_object,
    additional_neighbor=0
) -> object
```

Check that an object is consistent with a k-nearest-neighbors estimator.

**Parameters:**
- **nn_name** (`str`): The name associated with the object, used to raise an informative error if needed
- **nn_object** (`int` or `KNeighborsMixin`): The object to be checked
- **additional_neighbor** (`int`, default=`0`): Number of extra neighbors to add, since some algorithms need an additional neighbor

**Returns:**
- **nn_object** (`KNeighborsMixin`): The k-NN object

**Functionality:**
- If `nn_object` is an integer, creates a `NearestNeighbors` object with `n_neighbors=nn_object + additional_neighbor`
- If `nn_object` is already a neighbors object, returns a clone of it
- Validates that the object has the required k-NN interface

**Usage Examples:**

```python
from imblearn.utils import check_neighbors_object
from sklearn.neighbors import NearestNeighbors

# From integer - creates NearestNeighbors(n_neighbors=5)
nn = check_neighbors_object('k_neighbors', 5)

# From existing object - clones it
existing_nn = NearestNeighbors(n_neighbors=3, metric='manhattan')
nn = check_neighbors_object('k_neighbors', existing_nn)

# With additional neighbors (for algorithms that need k+1 neighbors)
nn = check_neighbors_object('k_neighbors', 5, additional_neighbor=1)  # Creates with 6 neighbors
```
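One common reason an algorithm asks for k+1 neighbors is that a query point drawn from the fitted data comes back as its own nearest neighbor. The brute-force sketch below illustrates this on 1-D points in plain Python. `k_nearest` is a hypothetical helper, not a library function.

```python
def k_nearest(points, query, k):
    """Return indices of the k points closest to `query` (1-D, brute force)."""
    order = sorted(range(len(points)), key=lambda i: abs(points[i] - query))
    return order[:k]


points = [0.0, 1.0, 2.5, 4.0, 7.0]
query_index = 2  # the query is itself one of the fitted points

# Asking for k neighbours returns the query point first (distance 0)...
with_self = k_nearest(points, points[query_index], k=3)
# ...so request k + 1 and drop the first hit to get k genuine neighbours.
without_self = k_nearest(points, points[query_index], k=3 + 1)[1:]

print(with_self)     # [2, 1, 3]
print(without_self)  # [1, 3, 0]
```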

### check_target_type

```python { .api }
def check_target_type(
    y,
    indicate_one_vs_all=False
) -> ndarray | tuple[ndarray, bool]
```

Check that the target type conforms to what the current samplers support.

**Parameters:**
- **y** (`ndarray`): The array containing the target
- **indicate_one_vs_all** (`bool`, default=`False`): Whether to also return a flag indicating if the targets are encoded in a one-vs-all fashion

**Returns:**
- **y** (`ndarray`): The returned target
- **is_one_vs_all** (`bool`, optional): Indicates if the target was originally encoded in a one-vs-all fashion. Only returned if `indicate_one_vs_all=True`

**Target Type Handling:**
- **Binary**: Passes through unchanged
- **Multiclass**: Passes through unchanged
- **Multilabel-indicator**: Converts to multiclass if it represents one-vs-all encoding (each sample has exactly one label)
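The one-vs-all conversion can be sketched in plain Python: each indicator row with exactly one active label maps to the index of that label. `one_vs_all_to_multiclass` is a hypothetical helper written for illustration and is not the library's implementation.

```python
def one_vs_all_to_multiclass(rows):
    """Hypothetical helper: map indicator rows to class indices.

    Raises ValueError when a row does not mark exactly one label,
    which is the case for genuinely multilabel targets.
    """
    labels = []
    for row in rows:
        if sum(row) != 1:
            raise ValueError("each sample must have exactly one active label")
        labels.append(row.index(1))
    return labels


print(one_vs_all_to_multiclass([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]]))  # [0, 1, 2, 0]
```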

**Example:**
```python
import numpy as np
from imblearn.utils import check_target_type

# Regular multiclass target
y_multiclass = np.array([0, 1, 2, 0, 1, 2])
y_checked = check_target_type(y_multiclass)

# One-vs-all encoded (multilabel-indicator that's actually multiclass)
y_ovr = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]])
y_converted, is_ovr = check_target_type(y_ovr, indicate_one_vs_all=True)
# y_converted becomes [0, 1, 2, 0], is_ovr is True

# True multilabel (not supported - raises error)
y_multilabel = np.array([[1, 1, 0], [0, 1, 1], [1, 0, 1]])
# check_target_type(y_multilabel)  # Raises ValueError
```

## Documentation Utilities

### Substitution

```python { .api }
class Substitution:
    def __init__(self, *args, **kwargs): ...
    def __call__(self, obj): ...
```

Decorate the docstring of a function or class to perform string substitution on it.

**Parameters:**
- **args** (`tuple`): Positional arguments for substitution (mutually exclusive with kwargs)
- **kwargs** (`dict`): Keyword arguments for substitution (mutually exclusive with args)

**Usage:**
The decorator performs string formatting on docstrings using the provided arguments.
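As a rough illustration of the mechanism, a docstring-substitution decorator can be written in a few lines. `DocSubstitution` below is a hypothetical stand-in, not the library class; it assumes keyword-only substitution via `str.format`.

```python
class DocSubstitution:
    """Hypothetical stand-in that formats a docstring with named values."""

    def __init__(self, **kwargs):
        self.params = kwargs

    def __call__(self, obj):
        # Docstrings can be None (for example under `python -OO`), so guard first
        if obj.__doc__:
            obj.__doc__ = obj.__doc__.format(**self.params)
        return obj


@DocSubstitution(unit="seconds")
def wait(duration):
    """Pause for `duration` {unit}."""


print(wait.__doc__)  # Pause for `duration` seconds.
```

Because the substitution runs once at decoration time, the same docstring fragment can be shared across many functions without any runtime cost.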

**Example:**
```python
from imblearn.utils import Substitution

# Define reusable docstring components
_random_state_docstring = """random_state : int, RandomState instance, default=None
        Control the randomization of the algorithm.

        - If int, random_state is the seed used by the random number generator;
        - If RandomState instance, random_state is the random number generator;
        - If None, the random number generator is the RandomState instance used
          by np.random."""

# Use as decorator with keyword arguments
@Substitution(random_state=_random_state_docstring)
def my_function(X, y, random_state=None):
    """Apply sampling to dataset.

    Parameters
    ----------
    X : array-like
        Input data.
    y : array-like
        Target values.
    {random_state}

    Returns
    -------
    X_resampled, y_resampled : arrays
        Resampled data and targets.
    """
    pass

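
# Use a named placeholder for a short description.
# Note: positional arguments are accepted by the constructor, but named
# placeholders are the safer choice (assumption: recent releases format the
# docstring with keyword arguments only).
@Substitution(summary="This is a substituted description")
def another_function():
    """{summary}

    More details here.
    """
    pass
```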

## Custom Sampling

### FunctionSampler

```python { .api }
class FunctionSampler:
    def __init__(
        self,
        *,
        func=None,
        accept_sparse=True,
        kw_args=None,
        validate=True
    ): ...
    def fit(self, X, y): ...
    def fit_resample(self, X, y): ...
```

Construct a sampler from an arbitrary callable.

**Parameters:**
- **func** (`callable`, default=`None`): The callable to use for the resampling. It is called as `func(X, y, **kw_args)` and must return the resampled `X` and `y`. If `func` is `None`, the identity function is used
- **accept_sparse** (`bool`, default=`True`): Whether sparse inputs are supported. By default, sparse inputs are supported
- **kw_args** (`dict`, default=`None`): The keyword arguments expected by `func`
- **validate** (`bool`, default=`True`): Whether to validate `X` and `y`. Setting it to `False` bypasses validation, which allows `FunctionSampler` to be used with any type of data

**Attributes:**
- **sampling_strategy_** (`dict`): Dictionary containing the information to sample the dataset. The keys correspond to the class labels from which to sample and the values are the number of samples to sample
- **n_features_in_** (`int`): Number of features in the input dataset
- **feature_names_in_** (`ndarray` of shape `(n_features_in_,)`): Names of features seen during `fit`. Defined only when `X` has feature names that are all strings

**Methods:**

##### fit

```python
def fit(self, X, y) -> FunctionSampler
```

Check inputs and statistics of the sampler.

##### fit_resample

```python
def fit_resample(self, X, y) -> tuple[ndarray, ndarray]
```

Resample the dataset using the provided function.
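The dispatch idea behind a function-backed sampler can be sketched without any dependencies. `TinyFunctionSampler` is a hypothetical, heavily simplified stand-in: it performs no validation and only shows how `func` and `kw_args` are combined, with the identity as the default.

```python
class TinyFunctionSampler:
    """Hypothetical, simplified stand-in for a function-backed sampler."""

    def __init__(self, *, func=None, kw_args=None):
        self.func = func
        self.kw_args = kw_args

    def fit_resample(self, X, y):
        # No function given: behave as the identity
        if self.func is None:
            return X, y
        # Otherwise forward the data plus any configured keyword arguments
        return self.func(X, y, **(self.kw_args or {}))


def keep_every_nth(X, y, step=2):
    """Keep every `step`-th sample (illustrative resampling function)."""
    return X[::step], y[::step]


X = [[0], [1], [2], [3], [4], [5]]
y = [0, 0, 0, 0, 1, 1]

print(TinyFunctionSampler().fit_resample(X, y) == (X, y))  # True
X_res, y_res = TinyFunctionSampler(func=keep_every_nth, kw_args={"step": 3}).fit_resample(X, y)
print(X_res, y_res)  # [[0], [3]] [0, 0]
```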

**Basic Usage:**
```python
from imblearn import FunctionSampler
from sklearn.datasets import make_classification

# Illustrative imbalanced data set
X, y = make_classification(n_samples=100, weights=[0.9, 0.1], random_state=0)

# Simple function to select first 10 samples
def select_first_ten(X, y):
    return X[:10], y[:10]

sampler = FunctionSampler(func=select_first_ten)
X_res, y_res = sampler.fit_resample(X, y)
```

**Using Existing Samplers:**
```python
from collections import Counter

from imblearn import FunctionSampler
from imblearn.under_sampling import RandomUnderSampler
from sklearn.datasets import make_classification

def custom_undersampling(X, y, sampling_strategy, random_state):
    """Custom function using existing sampler."""
    return RandomUnderSampler(
        sampling_strategy=sampling_strategy,
        random_state=random_state
    ).fit_resample(X, y)

# Illustrative imbalanced data set
X, y = make_classification(n_samples=500, weights=[0.9, 0.1], random_state=0)

# Create functional sampler
sampler = FunctionSampler(
    func=custom_undersampling,
    kw_args={
        'sampling_strategy': 'auto',
        'random_state': 42
    }
)

X_res, y_res = sampler.fit_resample(X, y)
print(f'Resampled distribution: {Counter(y_res)}')
```

**Advanced Custom Logic:**
```python
from collections import Counter

import numpy as np
from imblearn import FunctionSampler
from sklearn.cluster import KMeans
from sklearn.datasets import make_classification

def cluster_based_sampling(X, y, n_clusters=3, random_state=None):
    """Custom sampling based on clustering."""
    # Seeded generator so the random draws honor `random_state`
    rng = np.random.RandomState(random_state)

    # Get class distribution
    counter = Counter(y)
    majority_class = max(counter, key=counter.get)
    minority_classes = [cls for cls in counter if cls != majority_class]

    # Keep all minority class samples
    minority_mask = np.isin(y, minority_classes)
    X_minority = X[minority_mask]
    y_minority = y[minority_mask]

    # Cluster majority class and sample from each cluster
    majority_mask = y == majority_class
    X_majority = X[majority_mask]

    # Apply clustering
    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=random_state)
    clusters = kmeans.fit_predict(X_majority)

    # Sample from each cluster (at least one sample per cluster)
    target_per_cluster = max(1, len(y_minority) // n_clusters)
    X_sampled_list = []

    for cluster_id in range(n_clusters):
        cluster_indices = np.where(clusters == cluster_id)[0]

        if len(cluster_indices) > 0:
            selected = rng.choice(
                cluster_indices,
                size=min(target_per_cluster, len(cluster_indices)),
                replace=False
            )
            X_sampled_list.append(X_majority[selected])

    # Combine results
    X_majority_sampled = np.vstack(X_sampled_list)
    y_majority_sampled = np.full(len(X_majority_sampled), majority_class)

    X_resampled = np.vstack([X_minority, X_majority_sampled])
    y_resampled = np.concatenate([y_minority, y_majority_sampled])

    return X_resampled, y_resampled

# Illustrative imbalanced data set
X, y = make_classification(n_samples=500, weights=[0.9, 0.1], random_state=0)

# Use custom cluster-based sampling
sampler = FunctionSampler(
    func=cluster_based_sampling,
    kw_args={'n_clusters': 5, 'random_state': 42}
)

X_res, y_res = sampler.fit_resample(X, y)
```

## Type Detection

### is_sampler

```python { .api }
def is_sampler(estimator) -> bool
```

Return True if the given estimator is a sampler, False otherwise.

**Parameters:**
- **estimator** (`object`): Estimator to test

**Returns:**
- **is_sampler** (`bool`): True if estimator is a sampler, otherwise False

**Detection Logic:**
1. Checks for `_estimator_type == "sampler"` attribute
2. Checks for `sampler_tags` in estimator tags
3. Returns False if neither condition is met
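The first detection rule amounts to a duck-typing check on an attribute. The sketch below covers only that rule and leaves out the tag-based check, which depends on scikit-learn. `looks_like_sampler` and both dummy classes are hypothetical names used for illustration.

```python
def looks_like_sampler(estimator):
    """Hypothetical check: only the `_estimator_type` rule is sketched here."""
    return getattr(estimator, "_estimator_type", None) == "sampler"


class DummySampler:
    _estimator_type = "sampler"


class DummyClassifier:
    _estimator_type = "classifier"


print(looks_like_sampler(DummySampler()))     # True
print(looks_like_sampler(DummyClassifier()))  # False
print(looks_like_sampler(object()))           # False
```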

**Example:**
```python
from imblearn import FunctionSampler
from imblearn.utils import is_sampler
from imblearn.over_sampling import SMOTE
from sklearn.ensemble import RandomForestClassifier

# Test imblearn sampler
smote = SMOTE()
print(is_sampler(smote))  # True

# Test sklearn classifier
rf = RandomForestClassifier()
print(is_sampler(rf))  # False

# Test custom sampler
custom_sampler = FunctionSampler()
print(is_sampler(custom_sampler))  # True
```

## Integration Patterns

### Pipeline Integration

```python
from imblearn import FunctionSampler
from imblearn.pipeline import Pipeline
from imblearn.under_sampling import RandomUnderSampler
from sklearn.datasets import make_classification
from sklearn.ensemble import IsolationForest, RandomForestClassifier
from sklearn.model_selection import train_test_split

# Create custom sampling function
def outlier_removal_sampling(X, y, contamination=0.1):
    """Remove outliers before standard sampling."""
    # Remove outliers (IsolationForest labels inliers as 1, outliers as -1)
    iso_forest = IsolationForest(contamination=contamination, random_state=42)
    inlier_mask = iso_forest.fit_predict(X) == 1

    X_clean = X[inlier_mask]
    y_clean = y[inlier_mask]

    # Apply standard sampling
    sampler = RandomUnderSampler(random_state=42)
    return sampler.fit_resample(X_clean, y_clean)

# Illustrative data, split so that predictions use unseen samples
X, y = make_classification(n_samples=1000, weights=[0.9, 0.1], random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)

# Use in pipeline
pipeline = Pipeline([
    ('outlier_sampling', FunctionSampler(func=outlier_removal_sampling)),
    ('classifier', RandomForestClassifier())
])

pipeline.fit(X_train, y_train)
predictions = pipeline.predict(X_test)
```

### Cross-Validation Compatibility

```python
from imblearn import FunctionSampler
from imblearn.over_sampling import SMOTE
from imblearn.pipeline import Pipeline
from imblearn.utils import check_sampling_strategy
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score

# Validate strategy before cross-validation
def safe_sampler_factory(strategy_type='auto'):
    """Create sampler with validated strategy."""
    def create_sampler(X, y):
        # Validate the strategy against the current fold; this raises
        # ValueError if it is incompatible with the fold's classes
        check_sampling_strategy(strategy_type, y, 'over-sampling')

        # Pass the original strategy to SMOTE: the converted dict is an
        # internal representation and is not meant to be fed back to a sampler
        return SMOTE(sampling_strategy=strategy_type, random_state=42).fit_resample(X, y)

    return FunctionSampler(func=create_sampler)

# Illustrative imbalanced data set
X, y = make_classification(n_samples=1000, weights=[0.9, 0.1], random_state=0)

# Use in cross-validation
sampler = safe_sampler_factory('not majority')
pipeline = Pipeline([('sampling', sampler), ('classifier', RandomForestClassifier())])
scores = cross_val_score(pipeline, X, y, cv=5)
```

## Best Practices

### Validation Best Practices

1. **Always validate sampling strategies** before creating samplers
2. **Use check_neighbors_object** for consistent k-NN parameter handling
3. **Check target types** early to catch incompatible data formats
4. **Validate custom functions** thoroughly before using in FunctionSampler

### Custom Sampler Guidelines

1. **Keep functions pure**: Avoid side effects in sampling functions
2. **Handle edge cases**: Check for empty classes, insufficient samples
3. **Document parameters**: Use clear docstrings and parameter validation
4. **Test thoroughly**: Verify behavior with different data distributions
5. **Consider performance**: Optimize for large datasets when necessary
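The edge-case guideline can be made concrete with a small guard that runs before any resampling. `check_min_class_size` is a hypothetical helper and the threshold is an assumption: a neighbor-based over-sampler using k neighbors, for example, generally needs more than k samples in each class it resamples.

```python
from collections import Counter


def check_min_class_size(y, min_samples):
    """Hypothetical guard: fail early when a class is too small to resample."""
    counts = Counter(y)
    too_small = {cls: n for cls, n in counts.items() if n < min_samples}
    if too_small:
        raise ValueError(
            f"classes with fewer than {min_samples} samples: {too_small}"
        )
    return counts


y = [0] * 50 + [1] * 8 + [2] * 3

print(check_min_class_size(y, min_samples=3))  # passes: every class has >= 3 samples

try:
    check_min_class_size(y, min_samples=6)
except ValueError as exc:
    print(exc)  # reports class 2, which has only 3 samples
```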

### Error Handling

```python
from imblearn.over_sampling import SMOTE
from imblearn.utils import check_sampling_strategy, check_target_type
from sklearn.datasets import make_classification

def robust_sampling_pipeline(X, y, sampling_strategy='auto'):
    """Example of robust sampling with proper validation."""
    try:
        # Validate target type
        y_validated = check_target_type(y)

        # Validate sampling strategy (raises ValueError if incompatible)
        check_sampling_strategy(sampling_strategy, y_validated, 'over-sampling')

        # Apply sampling with the original strategy; the converted dict is
        # an internal representation and should not be passed back to SMOTE
        sampler = SMOTE(sampling_strategy=sampling_strategy)
        return sampler.fit_resample(X, y_validated)

    except ValueError as e:
        print(f"Validation error: {e}")
        # Fallback to identity transformation
        return X, y
    except Exception as e:
        print(f"Sampling error: {e}")
        return X, y

# Use robust pipeline
X, y = make_classification(n_samples=500, weights=[0.9, 0.1], random_state=0)
X_res, y_res = robust_sampling_pipeline(X, y, 'not majority')
```