Scikit-Learn API wrapper for Keras models enabling seamless integration of deep learning into scikit-learn workflows.
npx @tessl/cli install tessl/pypi-scikeras@0.13.0
# SciKeras

SciKeras is a Python library that provides scikit-learn compatible wrappers for Keras models, so that deep learning models can be used inside scikit-learn workflows. It serves as a modern replacement for the deprecated `tf.keras.wrappers.scikit_learn` module.
## Package Information

- **Package Name**: scikeras
- **Package Type**: pypi
- **Language**: Python
- **Installation**: `pip install scikeras[tensorflow]` or `pip install scikeras` (requires separate TensorFlow installation)
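Because the plain `pip install scikeras` route leaves TensorFlow to be installed separately, it can help to check which packages are importable before building a model. The snippet below is a minimal sketch using only the standard library; `missing_packages` is a hypothetical helper introduced for illustration and is not part of SciKeras.

```python
from importlib.util import find_spec

def missing_packages(names):
    """Return the top-level package names that cannot be found on the import path."""
    # find_spec returns None for a top-level module that is not installed.
    return [name for name in names if find_spec(name) is None]

# Hypothetical list of the packages this document relies on.
required = ["scikeras", "sklearn", "keras", "tensorflow"]
missing = missing_packages(required)
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All required packages are importable.")
```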
## Core Imports

```python
from scikeras.wrappers import KerasClassifier, KerasRegressor
```
For advanced usage:

```python
from scikeras.wrappers import BaseWrapper
from scikeras.utils import loss_name, metric_name
from scikeras.utils.transformers import ClassifierLabelEncoder, TargetReshaper
from scikeras.utils.random_state import tensorflow_random_state
```
## Basic Usage

### Classification Example

```python
from scikeras.wrappers import KerasClassifier
import keras
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score

# Define a model building function
def create_classifier():
    model = keras.Sequential([
        keras.Input(shape=(20,)),  # 20 features, matching make_classification below
        keras.layers.Dense(100, activation='relu'),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(50, activation='relu'),
        keras.layers.Dense(1, activation='sigmoid')
    ])
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

# Create classifier wrapper
clf = KerasClassifier(
    model=create_classifier,
    epochs=10,
    batch_size=32,
    verbose=0
)

# Use like any scikit-learn classifier
X, y = make_classification(n_samples=1000, n_features=20)
scores = cross_val_score(clf, X, y, cv=3)
print(f"Accuracy: {scores.mean():.3f} (+/- {scores.std() * 2:.3f})")
```
### Regression Example

```python
from scikeras.wrappers import KerasRegressor
import keras
from sklearn.datasets import make_regression
from sklearn.model_selection import cross_val_score

# Define a model building function
def create_regressor():
    model = keras.Sequential([
        keras.Input(shape=(10,)),  # 10 features, matching make_regression below
        keras.layers.Dense(64, activation='relu'),
        keras.layers.Dense(32, activation='relu'),
        keras.layers.Dense(1)
    ])
    model.compile(optimizer='adam', loss='mse', metrics=['mae'])
    return model

# Create regressor wrapper
reg = KerasRegressor(
    model=create_regressor,
    epochs=50,
    batch_size=32,
    verbose=0
)

# Use like any scikit-learn regressor
X, y = make_regression(n_samples=1000, n_features=10)
scores = cross_val_score(reg, X, y, cv=3, scoring='neg_mean_squared_error')
print(f"MSE: {-scores.mean():.3f} (+/- {scores.std() * 2:.3f})")
```
## Architecture

SciKeras provides a bridge between Keras/TensorFlow and scikit-learn through wrapper classes that implement the scikit-learn estimator interface:

- **BaseWrapper**: Abstract base class providing core scikit-learn compatibility
- **KerasClassifier**: Classification wrapper with probability prediction support
- **KerasRegressor**: Regression wrapper with R² scoring
- **Data Transformers**: Automatic preprocessing for different target types
- **Parameter Routing**: Double underscore notation for nested configuration (for example, `optimizer__learning_rate`)
- **Serialization Support**: Pickle/joblib compatibility through custom reducers

This design lets Keras models work with scikit-learn's ecosystem, including `GridSearchCV`, `Pipeline`, and the cross-validation utilities.
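The double-underscore routing convention listed above can be illustrated without Keras at all. The following is a minimal sketch of the idea, not SciKeras's actual implementation: `route_params` is a hypothetical helper that splits a flat keyword dictionary into parameters addressed to one destination (such as `optimizer`) and everything else.

```python
def route_params(params, destination):
    """Split flat params into those routed to `destination` and the rest.

    Keys of the form "<destination>__<name>" are routed, with the prefix removed.
    """
    prefix = destination + "__"
    routed, remaining = {}, {}
    for key, value in params.items():
        if key.startswith(prefix):
            routed[key[len(prefix):]] = value
        else:
            remaining[key] = value
    return routed, remaining

flat = {"epochs": 10, "optimizer": "adam", "optimizer__learning_rate": 0.01}
optimizer_kwargs, other = route_params(flat, "optimizer")
print(optimizer_kwargs)  # {'learning_rate': 0.01}
print(other)             # {'epochs': 10, 'optimizer': 'adam'}
```

With SciKeras itself, the same naming is what allows a grid search to vary a nested setting, for instance through a `param_grid` key such as `optimizer__learning_rate`.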
## Capabilities

### Wrapper Classes

Core wrapper classes that provide scikit-learn compatibility for Keras models, supporting both classification and regression tasks with automatic data preprocessing.

```python { .api }
class BaseWrapper(BaseEstimator):
    def __init__(self, model=None, *, optimizer='rmsprop', loss=None,
                 metrics=None, batch_size=None, validation_batch_size=None,
                 verbose=1, callbacks=None, validation_split=0.0,
                 shuffle=True, run_eagerly=None, epochs=1, **kwargs): ...
    def fit(self, X, y, *, sample_weight=None, **kwargs): ...
    def predict(self, X, **kwargs): ...
    def score(self, X, y, *, sample_weight=None): ...

class KerasClassifier(BaseWrapper, ClassifierMixin):
    def predict_proba(self, X, **kwargs): ...

class KerasRegressor(BaseWrapper, RegressorMixin): ...
```
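To make the estimator contract behind these signatures concrete, here is a minimal pure-Python sketch of an object exposing `fit`, `predict`, and `score`. `MeanRegressor` is a hypothetical toy class, not part of SciKeras; its `score` computes R², the metric the Architecture section associates with `KerasRegressor`.

```python
class MeanRegressor:
    """Toy estimator that always predicts the mean of the training targets."""

    def fit(self, X, y):
        self.mean_ = sum(y) / len(y)  # learned state uses a trailing underscore
        return self                   # returning self allows call chaining

    def predict(self, X):
        return [self.mean_ for _ in X]

    def score(self, X, y):
        """Return R^2 = 1 - SS_res / SS_tot."""
        predictions = self.predict(X)
        y_mean = sum(y) / len(y)
        ss_res = sum((t - p) ** 2 for t, p in zip(y, predictions))
        ss_tot = sum((t - y_mean) ** 2 for t in y)
        return 1.0 - ss_res / ss_tot

X = [[0.0], [1.0], [2.0], [3.0]]
y = [1.0, 3.0, 5.0, 7.0]
model = MeanRegressor().fit(X, y)
print(model.predict([[4.0]]))  # [4.0]
print(model.score(X, y))       # 0.0, a constant mean prediction explains no variance
```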
[Wrapper Classes](./wrappers.md)
### Utility Functions

Helper functions for normalizing Keras loss and metric names to standardized formats compatible with string-based configuration.

```python { .api }
def loss_name(loss):
    """
    Retrieve standardized loss function name.

    Args:
        loss: Union[str, keras.losses.Loss, Callable] - Loss function identifier

    Returns:
        str: Standardized loss name in snake_case
    """

def metric_name(metric):
    """
    Retrieve standardized metric function name.

    Args:
        metric: Union[str, keras.metrics.Metric, Callable] - Metric function identifier

    Returns:
        str: Standardized metric name
    """
```
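As an illustration of what a standardized snake_case name means, the sketch below converts a CamelCase class name into snake_case with a regular expression. `to_snake_case` is a hypothetical stand-in written for this example; it does not reproduce how `loss_name` or `metric_name` resolve Keras objects.

```python
import re

def to_snake_case(name):
    """Convert a CamelCase identifier to snake_case; snake_case input is unchanged."""
    # Insert an underscore before each uppercase letter that follows a
    # lowercase letter or digit, then lowercase the result.
    return re.sub(r"(?<=[a-z0-9])(?=[A-Z])", "_", name).lower()

print(to_snake_case("BinaryCrossentropy"))   # binary_crossentropy
print(to_snake_case("MeanSquaredError"))     # mean_squared_error
print(to_snake_case("binary_crossentropy"))  # binary_crossentropy
```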
[Utility Functions](./utils.md)
### Data Transformers

Transformer classes for preprocessing targets and features to ensure compatibility between scikit-learn and Keras data formats.

```python { .api }
class TargetReshaper(BaseEstimator, TransformerMixin):
    def fit(self, y): ...
    def transform(self, y): ...
    def inverse_transform(self, y): ...

class ClassifierLabelEncoder(BaseEstimator, TransformerMixin):
    def fit(self, y): ...
    def transform(self, y): ...
    def inverse_transform(self, y_transformed, return_proba=False): ...

class RegressorTargetEncoder(BaseEstimator, TransformerMixin):
    def fit(self, y): ...
    def transform(self, y): ...
    def inverse_transform(self, y): ...
```
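The `fit` / `transform` / `inverse_transform` pattern shared by these classes can be sketched with plain Python. `SimpleLabelEncoder` below is a hypothetical toy class, not the SciKeras `ClassifierLabelEncoder`; it only shows how labels might be mapped to integer indices for a model and mapped back afterwards.

```python
class SimpleLabelEncoder:
    """Toy encoder mapping arbitrary labels to integer indices and back."""

    def fit(self, y):
        self.classes_ = sorted(set(y))  # stable ordering of the unique labels
        self._index = {label: i for i, label in enumerate(self.classes_)}
        return self

    def transform(self, y):
        return [self._index[label] for label in y]

    def inverse_transform(self, y_encoded):
        return [self.classes_[i] for i in y_encoded]

labels = ["cat", "dog", "cat", "bird"]
encoder = SimpleLabelEncoder().fit(labels)
encoded = encoder.transform(labels)
print(encoded)                             # [1, 2, 1, 0]
print(encoder.inverse_transform(encoded))  # ['cat', 'dog', 'cat', 'bird']
```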
[Data Transformers](./transformers.md)
### Random State Management

Context manager for ensuring reproducible results across Python, NumPy, and TensorFlow random number generators.

```python { .api }
@contextmanager
def tensorflow_random_state(seed):
    """
    Context manager for reproducible random state.

    Args:
        seed (int): Random seed for reproducibility

    Yields:
        None: Context for reproducible operations
    """
```
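The idea of a seeding context manager can be shown with the standard library alone. The sketch below is a deliberate simplification: `python_random_state` is a hypothetical helper that seeds and then restores only Python's `random` module, whereas the description above says `tensorflow_random_state` also covers NumPy and TensorFlow.

```python
import random
from contextlib import contextmanager

@contextmanager
def python_random_state(seed):
    """Seed Python's `random` inside the block, then restore the prior state."""
    saved_state = random.getstate()
    random.seed(seed)
    try:
        yield
    finally:
        random.setstate(saved_state)  # leave the global generator as it was found

with python_random_state(42):
    first = [random.random() for _ in range(3)]
with python_random_state(42):
    second = [random.random() for _ in range(3)]

print(first == second)  # True
```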
[Random State Management](./random-state.md)