# Model Selection and Evaluation

This document covers all model selection, cross-validation, hyperparameter tuning, and evaluation capabilities in scikit-learn.

## Cross-Validation

### Cross-Validation Iterators

#### KFold { .api }
```python
from sklearn.model_selection import KFold

KFold(
    n_splits: int = 5,
    shuffle: bool = False,
    random_state: int | RandomState | None = None
)
```
K-Folds cross-validator.
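
All of the iterators below share the same `split(X, y=None, groups=None)` interface, yielding `(train_index, test_index)` pairs of integer arrays. A minimal sketch with KFold (the data here is made up for illustration):

```python
import numpy as np
from sklearn.model_selection import KFold

X = np.arange(20).reshape(10, 2)  # 10 samples, 2 features

kf = KFold(n_splits=5, shuffle=True, random_state=0)
n_folds = 0
for train_idx, test_idx in kf.split(X):
    # Each test fold holds 10 / 5 = 2 samples; the remaining 8 are training data
    assert len(test_idx) == 2 and len(train_idx) == 8
    n_folds += 1
assert n_folds == kf.get_n_splits() == 5
```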

#### StratifiedKFold { .api }
```python
from sklearn.model_selection import StratifiedKFold

StratifiedKFold(
    n_splits: int = 5,
    shuffle: bool = False,
    random_state: int | RandomState | None = None
)
```
Stratified K-Folds cross-validator.

#### GroupKFold { .api }
```python
from sklearn.model_selection import GroupKFold

GroupKFold(
    n_splits: int = 5
)
```
K-fold iterator variant with non-overlapping groups.

#### StratifiedGroupKFold { .api }
```python
from sklearn.model_selection import StratifiedGroupKFold

StratifiedGroupKFold(
    n_splits: int = 5,
    shuffle: bool = False,
    random_state: int | RandomState | None = None
)
```
Stratified K-Folds iterator variant with non-overlapping groups.
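
The group-aware iterators guarantee that no group appears in both the training and test sides of any split. A small illustration with GroupKFold (the data and group labels are invented for the sketch):

```python
import numpy as np
from sklearn.model_selection import GroupKFold

X = np.arange(16).reshape(8, 2)
y = np.array([0, 0, 1, 1, 0, 1, 0, 1])
groups = np.array([1, 1, 2, 2, 3, 3, 4, 4])  # two samples per group

gkf = GroupKFold(n_splits=4)
for train_idx, test_idx in gkf.split(X, y, groups):
    # No group label may straddle the train/test boundary
    assert set(groups[train_idx]).isdisjoint(groups[test_idx])
```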

#### TimeSeriesSplit { .api }
```python
from sklearn.model_selection import TimeSeriesSplit

TimeSeriesSplit(
    n_splits: int = 5,
    max_train_size: int | None = None,
    test_size: int | None = None,
    gap: int = 0
)
```
Time Series cross-validator.
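
Unlike the shuffling iterators, TimeSeriesSplit always places the test fold after the training fold, so successive splits use expanding windows of past data. A quick sketch on toy data:

```python
import numpy as np
from sklearn.model_selection import TimeSeriesSplit

X = np.arange(12).reshape(6, 2)  # 6 ordered observations

tscv = TimeSeriesSplit(n_splits=3)
for train_idx, test_idx in tscv.split(X):
    # Training indices always precede test indices
    assert train_idx.max() < test_idx.min()
```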

#### LeaveOneOut { .api }
```python
from sklearn.model_selection import LeaveOneOut

LeaveOneOut()
```
Leave-One-Out cross-validator.

#### LeavePOut { .api }
```python
from sklearn.model_selection import LeavePOut

LeavePOut(
    p: int
)
```
Leave-P-Out cross-validator.

#### LeaveOneGroupOut { .api }
```python
from sklearn.model_selection import LeaveOneGroupOut

LeaveOneGroupOut()
```
Leave One Group Out cross-validator.

#### LeavePGroupsOut { .api }
```python
from sklearn.model_selection import LeavePGroupsOut

LeavePGroupsOut(
    n_groups: int
)
```
Leave P Group(s) Out cross-validator.

#### ShuffleSplit { .api }
```python
from sklearn.model_selection import ShuffleSplit

ShuffleSplit(
    n_splits: int = 10,
    test_size: float | int | None = None,
    train_size: float | int | None = None,
    random_state: int | RandomState | None = None
)
```
Random permutation cross-validator.

#### StratifiedShuffleSplit { .api }
```python
from sklearn.model_selection import StratifiedShuffleSplit

StratifiedShuffleSplit(
    n_splits: int = 10,
    test_size: float | int | None = None,
    train_size: float | int | None = None,
    random_state: int | RandomState | None = None
)
```
Stratified ShuffleSplit cross-validator.

#### GroupShuffleSplit { .api }
```python
from sklearn.model_selection import GroupShuffleSplit

GroupShuffleSplit(
    n_splits: int = 5,
    test_size: float | int | None = None,
    train_size: float | int | None = None,
    random_state: int | RandomState | None = None
)
```
Shuffle-Group(s)-Out cross-validation iterator.

#### PredefinedSplit { .api }
```python
from sklearn.model_selection import PredefinedSplit

PredefinedSplit(
    test_fold: ArrayLike
)
```
Predefined split cross-validator.

#### RepeatedKFold { .api }
```python
from sklearn.model_selection import RepeatedKFold

RepeatedKFold(
    n_splits: int = 5,
    n_repeats: int = 10,
    random_state: int | RandomState | None = None
)
```
Repeated K-Fold cross-validator.

#### RepeatedStratifiedKFold { .api }
```python
from sklearn.model_selection import RepeatedStratifiedKFold

RepeatedStratifiedKFold(
    n_splits: int = 5,
    n_repeats: int = 10,
    random_state: int | RandomState | None = None
)
```
Repeated Stratified K-Fold cross-validator.

### Base Cross-Validation Classes

#### BaseCrossValidator { .api }
```python
from sklearn.model_selection import BaseCrossValidator

BaseCrossValidator()
```
Base class for all cross-validators.

#### BaseShuffleSplit { .api }
```python
from sklearn.model_selection import BaseShuffleSplit

BaseShuffleSplit(
    n_splits: int = 10,
    test_size: float | int | None = None,
    train_size: float | int | None = None,
    random_state: int | RandomState | None = None
)
```
Base class for ShuffleSplit cross-validators.

## Hyperparameter Tuning

### Grid Search

#### GridSearchCV { .api }
```python
from sklearn.model_selection import GridSearchCV

GridSearchCV(
    estimator: BaseEstimator,
    param_grid: dict | list[dict],
    scoring: str | Callable | list | tuple | dict | None = None,
    n_jobs: int | None = None,
    refit: bool | str | Callable = True,
    cv: int | BaseCrossValidator | Iterable | None = None,
    verbose: int = 0,
    pre_dispatch: int | str = "2*n_jobs",
    error_score: float | str = ...,
    return_train_score: bool = False
)
```
Exhaustive search over specified parameter values for an estimator.

### Randomized Search

#### RandomizedSearchCV { .api }
```python
from sklearn.model_selection import RandomizedSearchCV

RandomizedSearchCV(
    estimator: BaseEstimator,
    param_distributions: dict | list[dict],
    n_iter: int = 10,
    scoring: str | Callable | list | tuple | dict | None = None,
    n_jobs: int | None = None,
    refit: bool | str | Callable = True,
    cv: int | BaseCrossValidator | Iterable | None = None,
    verbose: int = 0,
    pre_dispatch: int | str = "2*n_jobs",
    random_state: int | RandomState | None = None,
    error_score: float | str = ...,
    return_train_score: bool = False
)
```
Randomized search over hyperparameters.
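
`param_distributions` accepts either lists (sampled uniformly) or `scipy.stats` distribution objects exposing an `rvs` method. A hedged sketch on a toy dataset (the parameter ranges are illustrative, not recommendations):

```python
from scipy.stats import randint
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV

X, y = load_iris(return_X_y=True)

search = RandomizedSearchCV(
    RandomForestClassifier(random_state=0),
    param_distributions={
        "n_estimators": randint(10, 50),  # drawn from a discrete uniform
        "max_depth": [2, 4, None],        # drawn uniformly from the list
    },
    n_iter=5,
    cv=3,
    random_state=0,
)
search.fit(X, y)
assert len(search.cv_results_["params"]) == 5  # one entry per sampled candidate
```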

### Parameter Generation

#### ParameterGrid { .api }
```python
from sklearn.model_selection import ParameterGrid

ParameterGrid(
    param_grid: dict | list[dict]
)
```
Grid of parameters with a discrete number of values for each.

#### ParameterSampler { .api }
```python
from sklearn.model_selection import ParameterSampler

ParameterSampler(
    param_distributions: dict,
    n_iter: int,
    random_state: int | RandomState | None = None
)
```
Generator on parameters sampled from given distributions.
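
Both helpers are plain iterables of parameter dicts, which makes them handy for writing custom search loops. ParameterGrid enumerates the full cross-product, while ParameterSampler draws a random subset of the same space:

```python
from sklearn.model_selection import ParameterGrid, ParameterSampler

grid = list(ParameterGrid({"a": [1, 2], "b": ["x", "y", "z"]}))
assert len(grid) == 6  # 2 * 3 combinations
assert {"a": 1, "b": "x"} in grid

sampled = list(ParameterSampler({"a": [1, 2], "b": ["x", "y", "z"]},
                                n_iter=4, random_state=0))
assert len(sampled) == 4  # a random subset of the six combinations
```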

## Threshold Optimization

#### TunedThresholdClassifierCV { .api }
```python
from sklearn.model_selection import TunedThresholdClassifierCV

TunedThresholdClassifierCV(
    estimator: BaseClassifier,
    scoring: str | Callable = "balanced_accuracy",
    response_method: str = "auto",
    thresholds: int | ArrayLike = 100,
    cv: int | BaseCrossValidator | Iterable | None = None,
    refit: bool = True,
    n_jobs: int | None = None,
    verbose: int = 0,
    random_state: int | RandomState | None = None,
    store_cv_results: bool = False
)
```
Classifier that post-tunes the decision threshold using cross-validation.

#### FixedThresholdClassifier { .api }
```python
from sklearn.model_selection import FixedThresholdClassifier

FixedThresholdClassifier(
    estimator: BaseClassifier,
    threshold: float | str = 0.5,
    response_method: str = "auto"
)
```
Binary classifier that manually sets the decision threshold.

## Cross-Validation Functions

### Basic Cross-Validation

#### cross_val_score { .api }
```python
from sklearn.model_selection import cross_val_score

cross_val_score(
    estimator: BaseEstimator,
    X: ArrayLike,
    y: ArrayLike | None = None,
    groups: ArrayLike | None = None,
    scoring: str | Callable | None = None,
    cv: int | BaseCrossValidator | Iterable | None = None,
    n_jobs: int | None = None,
    verbose: int = 0,
    fit_params: dict | None = None,
    pre_dispatch: int | str = "2*n_jobs",
    error_score: float | str = ...,
    params: dict | None = None
) -> ArrayLike
```
Evaluate a score by cross-validation.

#### cross_validate { .api }
```python
from sklearn.model_selection import cross_validate

cross_validate(
    estimator: BaseEstimator,
    X: ArrayLike,
    y: ArrayLike | None = None,
    groups: ArrayLike | None = None,
    scoring: str | Callable | list | tuple | dict | None = None,
    cv: int | BaseCrossValidator | Iterable | None = None,
    n_jobs: int | None = None,
    verbose: int = 0,
    fit_params: dict | None = None,
    pre_dispatch: int | str = "2*n_jobs",
    return_train_score: bool = False,
    return_estimator: bool = False,
    return_indices: bool = False,
    error_score: float | str = ...,
    params: dict | None = None
) -> dict[str, ArrayLike]
```
Evaluate metric(s) by cross-validation and also record fit/score times.
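
Unlike cross_val_score, cross_validate can track several metrics at once; the returned dict is keyed `test_<metric>` (plus `fit_time`/`score_time`, and `train_<metric>` when requested). A short sketch:

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_validate

X, y = load_iris(return_X_y=True)

results = cross_validate(
    LogisticRegression(max_iter=1000),
    X, y, cv=5,
    scoring=["accuracy", "f1_macro"],
    return_train_score=True,
)
assert {"fit_time", "score_time", "test_accuracy",
        "test_f1_macro", "train_accuracy"} <= set(results)
assert len(results["test_accuracy"]) == 5  # one score per fold
```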

#### cross_val_predict { .api }
```python
from sklearn.model_selection import cross_val_predict

cross_val_predict(
    estimator: BaseEstimator,
    X: ArrayLike,
    y: ArrayLike | None = None,
    groups: ArrayLike | None = None,
    cv: int | BaseCrossValidator | Iterable | None = None,
    n_jobs: int | None = None,
    verbose: int = 0,
    fit_params: dict | None = None,
    pre_dispatch: int | str = "2*n_jobs",
    method: str = "predict",
    params: dict | None = None
) -> ArrayLike
```
Generate cross-validated estimates for each input data point.
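
Every prediction returned by cross_val_predict comes from a model that never saw that sample during fitting, which makes the output useful for error analysis (though not for reporting a single generalization score). Sketch:

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict

X, y = load_iris(return_X_y=True)

# One out-of-fold prediction per input row
y_pred = cross_val_predict(LogisticRegression(max_iter=1000), X, y, cv=5)
assert y_pred.shape == y.shape

# method="predict_proba" yields one column per class instead
proba = cross_val_predict(LogisticRegression(max_iter=1000), X, y,
                          cv=5, method="predict_proba")
assert proba.shape == (len(y), 3)  # iris has 3 classes
```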

### Data Splitting

#### train_test_split { .api }
```python
from sklearn.model_selection import train_test_split

train_test_split(
    *arrays: ArrayLike,
    test_size: float | int | None = None,
    train_size: float | int | None = None,
    random_state: int | RandomState | None = None,
    shuffle: bool = True,
    stratify: ArrayLike | None = None
) -> list[ArrayLike]
```
Split arrays or matrices into random train and test subsets.
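
Passing `stratify=y` keeps the class proportions approximately equal across the two subsets, which matters for imbalanced labels. A sketch on made-up data:

```python
import numpy as np
from sklearn.model_selection import train_test_split

X = np.arange(40).reshape(20, 2)
y = np.array([0] * 15 + [1] * 5)  # 75% / 25% class balance

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=0
)
assert len(X_train) == 15 and len(X_test) == 5
# Both classes survive into the small test split thanks to stratification
assert set(y_test) == {0, 1}
```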

### Validation Curves

#### validation_curve { .api }
```python
from sklearn.model_selection import validation_curve

validation_curve(
    estimator: BaseEstimator,
    X: ArrayLike,
    y: ArrayLike,
    param_name: str,
    param_range: ArrayLike,
    groups: ArrayLike | None = None,
    cv: int | BaseCrossValidator | Iterable | None = None,
    scoring: str | Callable | None = None,
    n_jobs: int | None = None,
    pre_dispatch: int | str = "all",
    verbose: int = 0,
    error_score: float | str = ...,
    fit_params: dict | None = None,
    params: dict | None = None
) -> tuple[ArrayLike, ArrayLike]
```
Determine training and test scores for varying values of a single hyperparameter.

#### learning_curve { .api }
```python
from sklearn.model_selection import learning_curve

learning_curve(
    estimator: BaseEstimator,
    X: ArrayLike,
    y: ArrayLike,
    groups: ArrayLike | None = None,
    train_sizes: ArrayLike = ...,
    cv: int | BaseCrossValidator | Iterable | None = None,
    scoring: str | Callable | None = None,
    exploit_incremental_learning: bool = False,
    n_jobs: int | None = None,
    pre_dispatch: int | str = "all",
    verbose: int = 0,
    shuffle: bool = False,
    random_state: int | RandomState | None = None,
    error_score: float | str = ...,
    return_times: bool = False,
    fit_params: dict | None = None,
    params: dict | None = None
) -> tuple[ArrayLike, ArrayLike, ArrayLike] | tuple[ArrayLike, ArrayLike, ArrayLike, ArrayLike, ArrayLike]
```
Determine cross-validated training and test scores for different training set sizes.

### Permutation Testing

#### permutation_test_score { .api }
```python
from sklearn.model_selection import permutation_test_score

permutation_test_score(
    estimator: BaseEstimator,
    X: ArrayLike,
    y: ArrayLike,
    groups: ArrayLike | None = None,
    cv: int | BaseCrossValidator | Iterable | None = None,
    n_permutations: int = 100,
    n_jobs: int | None = None,
    random_state: int | RandomState | None = None,
    verbose: int = 0,
    scoring: str | Callable | None = None,
    fit_params: dict | None = None
) -> tuple[float, ArrayLike, float]
```
Evaluate the significance of a cross-validated score with permutations.
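
The function refits the model on many label-shuffled copies of the data; a model that has learned real structure should beat nearly all shuffled runs, giving a small p-value. A hedged sketch (the dataset and permutation count are illustrative):

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import permutation_test_score

X, y = load_iris(return_X_y=True)

score, perm_scores, pvalue = permutation_test_score(
    LogisticRegression(max_iter=1000), X, y,
    cv=3, n_permutations=30, random_state=0,
)
assert perm_scores.shape == (30,)  # one score per shuffled-label run
# Iris is easily separable, so the real score dwarfs the permuted ones
assert score > perm_scores.mean()
assert pvalue < 0.05
```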

### Utility Functions

#### check_cv { .api }
```python
from sklearn.model_selection import check_cv

check_cv(
    cv: int | BaseCrossValidator | Iterable | None = 5,
    y: ArrayLike | None = None,
    classifier: bool = False
) -> BaseCrossValidator
```
Input checker utility for building a cross-validator.
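
check_cv normalizes the many accepted forms of the `cv` argument into a concrete splitter object: an integer becomes StratifiedKFold for classification targets (when `classifier=True`) and plain KFold otherwise, while an existing splitter passes through. Sketch:

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold, check_cv

y = np.array([0, 1] * 10)

# Integer cv + classifier target -> stratified splitting
cv = check_cv(3, y, classifier=True)
assert isinstance(cv, StratifiedKFold) and cv.get_n_splits() == 3

# Without the classifier flag, plain KFold is used
assert isinstance(check_cv(3, y, classifier=False), KFold)

# An existing splitter object is returned unchanged
kf = KFold(n_splits=5)
assert check_cv(kf) is kf
```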

## Visualization Classes

### Learning Curves

#### LearningCurveDisplay { .api }
```python
from sklearn.model_selection import LearningCurveDisplay

LearningCurveDisplay(
    train_sizes: ArrayLike,
    train_scores: ArrayLike,
    test_scores: ArrayLike,
    train_scores_std: ArrayLike | None = None,
    test_scores_std: ArrayLike | None = None
)
```
Learning Curve visualization.

### Validation Curves

#### ValidationCurveDisplay { .api }
```python
from sklearn.model_selection import ValidationCurveDisplay

ValidationCurveDisplay(
    param_name: str,
    param_range: ArrayLike,
    train_scores: ArrayLike,
    test_scores: ArrayLike,
    train_scores_std: ArrayLike | None = None,
    test_scores_std: ArrayLike | None = None
)
```
Validation Curve visualization.

## Calibration

### Probability Calibration

#### CalibratedClassifierCV { .api }
```python
from sklearn.calibration import CalibratedClassifierCV

CalibratedClassifierCV(
    estimator: BaseClassifier | None = None,
    method: str = "sigmoid",
    cv: int | BaseCrossValidator | Iterable | str | None = None,
    n_jobs: int | None = None,
    ensemble: bool = True,
    base_estimator: BaseClassifier = "deprecated"
)
```
Probability calibration with isotonic regression or logistic regression.

### Calibration Functions

#### calibration_curve { .api }
```python
from sklearn.calibration import calibration_curve

calibration_curve(
    y_true: ArrayLike,
    y_prob: ArrayLike,
    pos_label: int | str | None = None,
    normalize: bool = "deprecated",
    n_bins: int = 5,
    strategy: str = "uniform"
) -> tuple[ArrayLike, ArrayLike]
```
Compute true and predicted probabilities for a calibration curve.

### Calibration Display

#### CalibrationDisplay { .api }
```python
from sklearn.calibration import CalibrationDisplay

CalibrationDisplay(
    prob_true: ArrayLike,
    prob_pred: ArrayLike,
    y_prob: ArrayLike,
    estimator_name: str | None = None,
    pos_label: int | str | None = None
)
```
Calibration curve visualization.

## Examples

### Basic Cross-Validation Example

```python
from sklearn.model_selection import cross_val_score, KFold
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris

# Load data
X, y = load_iris(return_X_y=True)

# Create model
model = RandomForestClassifier(n_estimators=100, random_state=42)

# Cross-validation with a shuffled K-fold strategy
kfold = KFold(n_splits=5, shuffle=True, random_state=42)
scores = cross_val_score(model, X, y, cv=kfold, scoring='accuracy')

print(f"CV Accuracy: {scores.mean():.3f} (+/- {scores.std() * 2:.3f})")
```

### Hyperparameter Tuning Example

```python
from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_digits

# Load data
X, y = load_digits(return_X_y=True)

# Define parameter grid
param_grid = {
    'n_estimators': [50, 100, 200],
    'max_depth': [3, 5, 7, None],
    'min_samples_split': [2, 5, 10]
}

# Grid search
grid_search = GridSearchCV(
    estimator=RandomForestClassifier(random_state=42),
    param_grid=param_grid,
    cv=5,
    scoring='accuracy',
    n_jobs=-1,
    verbose=1
)

# Fit and get results
grid_search.fit(X, y)
print(f"Best parameters: {grid_search.best_params_}")
print(f"Best cross-validation score: {grid_search.best_score_:.3f}")
```

### Learning Curve Example

```python
import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import learning_curve, LearningCurveDisplay

# Load data
X, y = load_digits(return_X_y=True)

# Generate learning curve
train_sizes, train_scores, test_scores = learning_curve(
    RandomForestClassifier(n_estimators=100, random_state=42),
    X, y, cv=5, n_jobs=-1,
    train_sizes=np.linspace(0.1, 1.0, 10)
)

# Plot learning curve
display = LearningCurveDisplay(
    train_sizes=train_sizes,
    train_scores=train_scores,
    test_scores=test_scores
)
display.plot()
plt.show()
```

### Validation Curve Example

```python
import matplotlib.pyplot as plt

from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import validation_curve, ValidationCurveDisplay

# Load data
X, y = load_digits(return_X_y=True)

# Generate validation curve for the max_depth parameter
param_range = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
train_scores, test_scores = validation_curve(
    RandomForestClassifier(n_estimators=100, random_state=42),
    X, y, param_name='max_depth', param_range=param_range,
    cv=5, scoring='accuracy', n_jobs=-1
)

# Plot validation curve
display = ValidationCurveDisplay(
    param_name='max_depth',
    param_range=param_range,
    train_scores=train_scores,
    test_scores=test_scores
)
display.plot()
plt.show()
```

### Calibration Example

```python
from sklearn.calibration import CalibratedClassifierCV, calibration_curve
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Generate a binary problem (calibration_curve expects binary labels)
X, y = make_classification(n_samples=1000, random_state=42)

# Split data
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Train uncalibrated classifier
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)

# Calibrate classifier
calibrated_clf = CalibratedClassifierCV(clf, method='isotonic', cv=3)
calibrated_clf.fit(X_train, y_train)

# Get calibrated probabilities for the positive class
y_prob = calibrated_clf.predict_proba(X_test)[:, 1]

# Evaluate calibration
fraction_of_positives, mean_predicted_value = calibration_curve(
    y_test, y_prob, n_bins=10
)
```