# Wrapper Classes

Core wrapper classes that provide scikit-learn compatibility for Keras models. These classes implement the scikit-learn estimator interface, so Keras models can be used with scikit-learn tools such as GridSearchCV, Pipeline, and cross-validation.

## Capabilities

### BaseWrapper

Abstract base class that implements the core scikit-learn estimator API for Keras models. It provides the functionality shared by the classification and regression wrappers.

```python { .api }
class BaseWrapper(BaseEstimator):
    def __init__(
        self,
        model=None,
        *,
        build_fn=None,
        warm_start=False,
        random_state=None,
        optimizer='rmsprop',
        loss=None,
        metrics=None,
        batch_size=None,
        validation_batch_size=None,
        verbose=1,
        callbacks=None,
        validation_split=0.0,
        validation_steps=None,
        validation_freq=1,
        shuffle=True,
        run_eagerly=None,
        epochs=1,
        initial_epoch=0,
        **kwargs
    ):
        """
        Initialize BaseWrapper.

        Args:
            model: Union[None, Callable[..., keras.Model], keras.Model] - Keras model or callable that returns compiled model
            build_fn: Union[None, Callable[..., keras.Model], keras.Model] - Deprecated alias for model parameter
            warm_start: bool - Whether to preserve model parameters between fits
            random_state: Union[int, np.random.RandomState, None] - Random seed for reproducibility
            optimizer: Union[str, keras.optimizers.Optimizer, Type[keras.optimizers.Optimizer]] - Optimizer for training
            loss: Union[str, keras.losses.Loss, Type[keras.losses.Loss], Callable, None] - Loss function
            metrics: List of metrics to monitor during training
            batch_size: Union[int, None] - Number of samples per gradient update
            validation_batch_size: Union[int, None] - Batch size for validation
            verbose: int - Verbosity level (0=silent, 1=progress bar, 2=one line per epoch)
            callbacks: List of Keras callbacks
            validation_split: float - Fraction of training data to use for validation
            validation_steps: Union[int, None] - Number of steps to draw from validation generator
            validation_freq: int - Only run validation every N epochs
            shuffle: bool - Whether to shuffle training data
            run_eagerly: Union[bool, None] - Whether to run in eager mode
            epochs: int - Number of training epochs
            initial_epoch: int - Epoch at which to start training
            **kwargs: Additional parameters passed to model building function
        """

    def fit(self, X, y, *, sample_weight=None, **kwargs):
        """
        Train the Keras model.

        Args:
            X: array-like of shape (n_samples, n_features) - Training data
            y: array-like of shape (n_samples,) or (n_samples, n_outputs) - Target values
            sample_weight: array-like of shape (n_samples,), optional - Sample weights
            **kwargs: Additional arguments passed to model.fit()

        Returns:
            self: Fitted estimator
        """

    def partial_fit(self, X, y, *, sample_weight=None, **kwargs):
        """
        Train the model for a single epoch.

        Args:
            X: array-like of shape (n_samples, n_features) - Training data
            y: array-like of shape (n_samples,) or (n_samples, n_outputs) - Target values
            sample_weight: array-like of shape (n_samples,), optional - Sample weights
            **kwargs: Additional arguments passed to model.fit()

        Returns:
            self: Fitted estimator
        """

    def predict(self, X, **kwargs):
        """
        Make predictions using the trained model.

        Args:
            X: array-like of shape (n_samples, n_features) - Input data
            **kwargs: Additional arguments passed to model.predict()

        Returns:
            array-like: Predictions
        """

    def score(self, X, y, *, sample_weight=None):
        """
        Return the score of the model on the given test data.

        Args:
            X: array-like of shape (n_samples, n_features) - Test data
            y: array-like of shape (n_samples,) or (n_samples, n_outputs) - True values
            sample_weight: array-like of shape (n_samples,), optional - Sample weights

        Returns:
            float: Model score
        """

    def initialize(self, X, y=None):
        """
        Initialize the model without training.

        Args:
            X: array-like of shape (n_samples, n_features) - Sample data for initialization
            y: array-like, optional - Sample targets for initialization

        Returns:
            self: Initialized estimator
        """

    @property
    def current_epoch(self):
        """Get current training epoch."""

    @property
    def initialized_(self):
        """Check if model is initialized."""

    @property
    def target_encoder(self):
        """Get target transformation pipeline."""

    @property
    def feature_encoder(self):
        """Get feature transformation pipeline."""

    @property
    def model_(self):
        """Get the instantiated and compiled Keras Model."""

    @property
    def history_(self):
        """Get training history dictionary."""

    @property
    def n_outputs_expected_(self):
        """Get expected number of outputs."""

    @property
    def target_type_(self):
        """Get target type string."""

    @property
    def classes_(self):
        """Get class labels (classification only)."""

    @property
    def n_classes_(self):
        """Get number of classes (classification only)."""

    @property
    def X_shape_(self):
        """Get input data shape from fitting."""

    @property
    def y_shape_(self):
        """Get target data shape from fitting."""

    @property
    def X_dtype_(self):
        """Get input data dtype from fitting."""

    @property
    def y_dtype_(self):
        """Get target data dtype from fitting."""

    @property
    def n_features_in_(self):
        """Get number of features seen during fit."""
```

### KerasClassifier

Scikit-learn compatible classifier wrapper for Keras models. Supports binary and multiclass classification with probability predictions.

```python { .api }
class KerasClassifier(BaseWrapper, ClassifierMixin):
    def __init__(self, class_weight=None, **kwargs):
        """
        Initialize KerasClassifier.

        Args:
            class_weight: dict or 'balanced', optional - Weights for class balancing
            **kwargs: All arguments from BaseWrapper
        """

    def fit(self, X, y, *, sample_weight=None, **kwargs):
        """
        Train the classifier.

        Args:
            X: array-like of shape (n_samples, n_features) - Training data
            y: array-like of shape (n_samples,) - Target class labels
            sample_weight: array-like of shape (n_samples,), optional - Sample weights
            **kwargs: Additional arguments passed to model.fit()

        Returns:
            self: Fitted classifier
        """

    def partial_fit(self, X, y, *, classes=None, sample_weight=None, **kwargs):
        """
        Train the classifier for a single epoch.

        Args:
            X: array-like of shape (n_samples, n_features) - Training data
            y: array-like of shape (n_samples,) - Target class labels
            classes: array-like of shape (n_classes,), optional - List of all possible classes
            sample_weight: array-like of shape (n_samples,), optional - Sample weights
            **kwargs: Additional arguments passed to model.fit()

        Returns:
            self: Fitted classifier
        """

    def predict_proba(self, X, **kwargs):
        """
        Predict class probabilities.

        Args:
            X: array-like of shape (n_samples, n_features) - Input data
            **kwargs: Additional arguments passed to model.predict()

        Returns:
            array-like of shape (n_samples, n_classes): Class probabilities
        """

    @property
    def classes_(self):
        """Get class labels."""

    @property
    def n_classes_(self):
        """Get number of classes."""
```
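
To make the `class_weight='balanced'` option concrete, the sketch below reproduces the scikit-learn convention of weighting each class inversely to its frequency, `n_samples / (n_classes * class_count)`, in plain Python. The helper `balanced_class_weights` is hypothetical and written only for illustration; it assumes SciKeras follows this convention and is not the library's own code.

```python
from collections import Counter

def balanced_class_weights(y):
    """Hypothetical helper: weight each class by n_samples / (n_classes * class_count)."""
    counts = Counter(y)
    n_samples = len(y)
    n_classes = len(counts)
    return {label: n_samples / (n_classes * count) for label, count in counts.items()}

# 8 samples of class 0 and 2 samples of class 1
y = [0] * 8 + [1] * 2
weights = balanced_class_weights(y)
print(weights)  # {0: 0.625, 1: 2.5}
```

The minority class receives the larger weight, so its errors contribute more to the loss during training.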

### KerasRegressor

Scikit-learn compatible regressor wrapper for Keras models. Uses R² score as the default scoring metric.

```python { .api }
class KerasRegressor(BaseWrapper, RegressorMixin):
    def __init__(self, **kwargs):
        """
        Initialize KerasRegressor.

        Args:
            **kwargs: All arguments from BaseWrapper
        """
```
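
The R² score mentioned above is defined as `1 - SS_res / SS_tot`, the residual sum of squares relative to the variance of the targets. As a plain-Python illustration of that formula, here is a minimal sketch; the helper `r2_score_sketch` is hypothetical and ignores sample weights and multi-output targets.

```python
def r2_score_sketch(y_true, y_pred):
    """Hypothetical helper: coefficient of determination, 1 - SS_res / SS_tot."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot

y_true = [1.0, 2.0, 3.0, 4.0]
perfect = r2_score_sketch(y_true, y_true)      # 1.0: predictions match exactly
baseline = r2_score_sketch(y_true, [2.5] * 4)  # 0.0: always predicting the mean
```

A score of 1.0 is a perfect fit, 0.0 matches a constant prediction of the target mean, and negative values are worse than that baseline.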

## Usage Examples

### Basic Classification with Grid Search

```python
from scikeras.wrappers import KerasClassifier
from sklearn.model_selection import GridSearchCV
import keras

def create_model(units=50, optimizer='adam'):
    # Binary classifier for inputs with 10 features
    model = keras.Sequential([
        keras.Input(shape=(10,)),
        keras.layers.Dense(units, activation='relu'),
        keras.layers.Dense(1, activation='sigmoid')
    ])
    model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])
    return model

# Create classifier; model__* parameters are routed to create_model
clf = KerasClassifier(
    model=create_model,
    epochs=10,
    batch_size=32,
    verbose=0
)

# Use with GridSearchCV
param_grid = {
    'model__units': [25, 50, 100],
    'model__optimizer': ['adam', 'sgd'],
    'epochs': [5, 10, 15]
}

# X_train and y_train are assumed to be defined: 10 feature columns, binary labels
grid = GridSearchCV(clf, param_grid, cv=3, scoring='accuracy')
grid.fit(X_train, y_train)
```

### Warm Start Training

```python
import keras
from scikeras.wrappers import KerasRegressor

def create_model():
    # Small regression network; assumes inputs with 10 features
    model = keras.Sequential([
        keras.Input(shape=(10,)),
        keras.layers.Dense(50, activation='relu'),
        keras.layers.Dense(1)
    ])
    model.compile(optimizer='adam', loss='mse')
    return model

# Enable warm start to preserve model weights between fit calls
reg = KerasRegressor(
    model=create_model,
    epochs=10,
    warm_start=True
)

# Initial training (X_train and y_train are assumed to be defined)
reg.fit(X_train, y_train)

# Continue training from previous state
reg.set_params(epochs=5)   # Train for 5 more epochs
reg.fit(X_train, y_train)  # Continues from epoch 10
```

### Custom Callbacks

```python
from scikeras.wrappers import KerasClassifier
import keras

# Define custom callbacks
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=5, restore_best_weights=True
)

# min_lr must be below the optimizer's starting learning rate
# (Adam defaults to 0.001), otherwise the callback can never lower it
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss', factor=0.2, patience=3, min_lr=1e-5
)

# create_model (the classification builder), X_train and y_train are
# assumed to be defined as in the grid search example
clf = KerasClassifier(
    model=create_model,
    epochs=100,
    validation_split=0.2,  # Provides the val_loss monitored by both callbacks
    callbacks=[early_stopping, reduce_lr]
)

clf.fit(X_train, y_train)
```

### Parameter Routing

SciKeras routes parameters to nested components using double-underscore notation. This gives control over model creation, compilation, and training directly from the wrapper's parameters.

#### Routing Targets

Parameters can be routed to different destinations:

- `model__*`: Parameters passed to the model building function
- `compile__*`: Parameters passed to `model.compile()`
- `fit__*`: Parameters passed to `model.fit()`
- `predict__*`: Parameters passed to `model.predict()`
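
The prefix convention in the list above can be sketched in plain Python. The helper `split_routed_params` and the `ROUTING_PREFIXES` tuple are hypothetical names used only for illustration; this is a simplified model of the idea, not SciKeras's internal implementation.

```python
ROUTING_PREFIXES = ("model", "compile", "fit", "predict")

def split_routed_params(params):
    """Hypothetical helper: group 'prefix__name' keys by destination."""
    routed = {prefix: {} for prefix in ROUTING_PREFIXES}
    direct = {}
    for key, value in params.items():
        prefix, sep, rest = key.partition("__")
        if sep and prefix in routed:
            # Strip the prefix and file the parameter under its destination
            routed[prefix][rest] = value
        else:
            # No recognized prefix: treat as a parameter of the wrapper itself
            direct[key] = value
    return routed, direct

routed, direct = split_routed_params({
    "model__units": 100,
    "compile__loss": "binary_crossentropy",
    "fit__validation_split": 0.2,
    "epochs": 50,
})
print(routed["model"])  # {'units': 100}
print(direct)           # {'epochs': 50}
```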

#### Examples

```python
import keras
from scikeras.wrappers import KerasClassifier

# Assumes create_model accepts `units` and `dropout_rate` keyword arguments
early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)

# Route parameters to the model building function, compile(), and fit()
clf = KerasClassifier(model=create_model)
clf.set_params(
    model__units=100,                     # Passed to create_model(units=100)
    model__dropout_rate=0.2,              # Passed to create_model(dropout_rate=0.2)
    compile__optimizer='adam',            # Passed to model.compile(optimizer='adam')
    compile__loss='binary_crossentropy',  # Passed to model.compile(loss=...)
    fit__validation_split=0.2,            # Passed to model.fit(validation_split=0.2)
    fit__callbacks=[early_stop],          # Passed to model.fit(callbacks=...)
    epochs=50                             # Direct parameter to wrapper
)
```

#### Nested Routing

Parameters can be routed to nested objects within the routed target:

```python
# Route to optimizer parameters within compile
clf.set_params(
    compile__optimizer__learning_rate=0.001,  # optimizer.learning_rate = 0.001
    compile__optimizer__beta_1=0.9,           # optimizer.beta_1 = 0.9
)
```
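
One way to read such nested keys is as paths into a nested dictionary, with each double underscore descending one level. The sketch below shows that interpretation; `unflatten` is a hypothetical helper for illustration only, and it assumes no key is used both as a value and as a parent of other keys. It is not how SciKeras necessarily resolves these parameters internally.

```python
def unflatten(params, sep="__"):
    """Hypothetical helper: turn 'a__b__c' keys into nested dictionaries."""
    nested = {}
    for key, value in params.items():
        *parents, leaf = key.split(sep)
        node = nested
        for name in parents:
            # Create intermediate levels on demand
            node = node.setdefault(name, {})
        node[leaf] = value
    return nested

nested = unflatten({
    "compile__optimizer__learning_rate": 0.001,
    "compile__optimizer__beta_1": 0.9,
})
print(nested)
# {'compile': {'optimizer': {'learning_rate': 0.001, 'beta_1': 0.9}}}
```

Single underscores, as in `beta_1`, are left intact because only the double underscore acts as a separator.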

## Types

```python { .api }
# Model building function signature
ModelBuildingFunction = Callable[..., keras.Model]

# Supported parameter types
OptimizerType = Union[str, keras.optimizers.Optimizer, Type[keras.optimizers.Optimizer]]
LossType = Union[str, keras.losses.Loss, Type[keras.losses.Loss], Callable, None]
MetricsType = Union[List[Union[str, keras.metrics.Metric]], None]
CallbacksType = Union[List[keras.callbacks.Callback], None]
```