Scikit-Learn API wrapper for Keras models enabling seamless integration of deep learning into scikit-learn workflows.
npx @tessl/cli install tessl/pypi-scikeras@0.13.0

A Python library providing scikit-learn-compatible wrappers for Keras models, enabling seamless integration of deep learning models into scikit-learn workflows. SciKeras is a modern replacement for the deprecated tf.keras.wrappers.scikit_learn module.
Install with pip install scikeras[tensorflow], or with pip install scikeras if you install TensorFlow separately.

from scikeras.wrappers import KerasClassifier, KerasRegressor

For advanced usage:
from scikeras.wrappers import BaseWrapper
from scikeras.utils import loss_name, metric_name
from scikeras.utils.transformers import ClassifierLabelEncoder, TargetReshaper
from scikeras.utils.random_state import tensorflow_random_state

from scikeras.wrappers import KerasClassifier
import keras
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score


# Define a model building function
def create_classifier():
    """Build and compile a small binary classifier for 20-feature inputs.

    The input width is fixed at 20 so that it matches ``n_features=20`` in the
    ``make_classification`` call of this example.

    Returns:
        keras.Sequential: a compiled model ending in a single sigmoid unit,
        trained with binary cross-entropy and reporting accuracy.
    """
    model = keras.Sequential([
        # Declare the input shape with keras.Input instead of passing
        # ``input_dim`` to the first Dense layer: the Keras Sequential guide
        # recommends an Input object as the first layer, and recent Keras
        # versions warn when input_dim/input_shape is given to a layer.
        keras.Input(shape=(20,)),
        keras.layers.Dense(100, activation='relu'),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(50, activation='relu'),
        keras.layers.Dense(1, activation='sigmoid'),
    ])
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model
# Create the classifier wrapper around the model-building function.
clf = KerasClassifier(
    model=create_classifier,
    epochs=10,
    batch_size=32,
    verbose=0,
)

# Use it like any scikit-learn classifier. random_state=0 fixes the synthetic
# dataset so repeated runs evaluate on the same data (Keras weight
# initialisation is still unseeded, so scores may vary slightly).
X, y = make_classification(n_samples=1000, n_features=20, random_state=0)
# With no ``scoring`` argument, cross_val_score uses the estimator's own score().
scores = cross_val_score(clf, X, y, cv=3)
print(f"Accuracy: {scores.mean():.3f} (+/- {scores.std() * 2:.3f})")

from scikeras.wrappers import KerasRegressor
import keras
from sklearn.datasets import make_regression
from sklearn.model_selection import cross_val_score


# Define a model building function
def create_regressor():
    """Build and compile a small regression network for 10-feature inputs.

    The input width is fixed at 10 so that it matches ``n_features=10`` in the
    ``make_regression`` call of this example.

    Returns:
        keras.Sequential: a compiled model with a single linear output unit,
        trained with mean squared error and reporting mean absolute error.
    """
    model = keras.Sequential([
        # An explicit Input layer replaces the ``input_dim`` layer argument,
        # which the Keras Sequential guide advises against.
        keras.Input(shape=(10,)),
        keras.layers.Dense(64, activation='relu'),
        keras.layers.Dense(32, activation='relu'),
        keras.layers.Dense(1),
    ])
    model.compile(optimizer='adam', loss='mse', metrics=['mae'])
    return model
# Create regressor wrapper
reg = KerasRegressor(
model=create_regressor,
epochs=50,
batch_size=32,
verbose=0
)
# Use like any scikit-learn regressor
X, y = make_regression(n_samples=1000, n_features=10)
scores = cross_val_score(reg, X, y, cv=3, scoring='neg_mean_squared_error')
print(f"MSE: {-scores.mean():.3f} (+/- {scores.std() * 2:.3f})")SciKeras provides a bridge between Keras/TensorFlow and scikit-learn through wrapper classes that implement the scikit-learn estimator interface:
This design lets Keras models work seamlessly with the scikit-learn ecosystem, including GridSearchCV, Pipeline, cross-validation utilities, and other standard scikit-learn tools.
Core wrapper classes that provide scikit-learn compatibility for Keras models, supporting both classification and regression tasks with automatic data preprocessing.
class BaseWrapper(BaseEstimator):
    """Base class exposing the scikit-learn estimator interface for Keras models.

    Method bodies are elided in this reference listing; only the public
    signatures are shown.
    """

    def __init__(
        self,
        model=None,
        *,
        optimizer='rmsprop',
        loss=None,
        metrics=None,
        batch_size=None,
        validation_batch_size=None,
        verbose=1,
        callbacks=None,
        validation_split=0.0,
        shuffle=True,
        run_eagerly=None,
        epochs=1,
        **kwargs,
    ):
        """Configure the wrapper.

        ``model`` takes a model-building callable, as in the usage examples.
        Every other option is keyword-only. Their names mirror Keras
        compile/fit arguments; they are presumably forwarded there (confirm
        against the SciKeras API reference). Extra keyword arguments are
        accepted through ``**kwargs``.
        """

    def fit(self, X, y, *, sample_weight=None, **kwargs):
        """scikit-learn ``fit`` entry point; ``sample_weight`` is optional."""

    def predict(self, X, **kwargs):
        """scikit-learn ``predict`` entry point for the samples in ``X``."""

    def score(self, X, y, *, sample_weight=None):
        """scikit-learn ``score`` entry point.

        ``cross_val_score`` falls back to this method when no ``scoring`` is
        given.
        """
class KerasClassifier(BaseWrapper, ClassifierMixin):
    """scikit-learn classifier wrapper for Keras models.

    Builds on ``BaseWrapper`` and mixes in scikit-learn's ``ClassifierMixin``.
    """

    def predict_proba(self, X, **kwargs):
        """Return class-probability estimates for the samples in ``X``.

        Additional keyword arguments are accepted through ``**kwargs``.
        """
class KerasRegressor(BaseWrapper, RegressorMixin): ...Helper functions for normalizing Keras loss and metric names to standardized formats compatible with string-based configuration.
def loss_name(loss):
    """Return the standardized name of a Keras loss.

    Args:
        loss: Union[str, keras.losses.Loss, Callable] - the loss to name,
            given as a string identifier, a ``keras.losses.Loss`` instance,
            or a plain callable.

    Returns:
        str: the loss name in snake_case.
    """
def metric_name(metric):
"""
Retrieve standardized metric function name.
Args:
metric: Union[str, keras.metrics.Metric, Callable] - Metric function identifier
Returns:
str: Standardized metric name
"""Transformer classes for preprocessing targets and features to ensure compatibility between scikit-learn and Keras data formats.
class TargetReshaper(BaseEstimator, TransformerMixin):
    """Transformer for target arrays ``y`` (bodies elided in this listing).

    Presumably adjusts the shape of ``y`` between scikit-learn and Keras
    conventions, as the name suggests; confirm against the SciKeras source.
    """

    def fit(self, y):
        """Fit the transformer on the targets ``y``."""

    def transform(self, y):
        """Transform the targets ``y``."""

    def inverse_transform(self, y):
        """Map transformed targets ``y`` back to their original form."""
class ClassifierLabelEncoder(BaseEstimator, TransformerMixin):
    """Target transformer for classification labels (bodies elided here)."""

    def fit(self, y):
        """Fit the encoder on the class labels ``y``."""

    def transform(self, y):
        """Encode the class labels ``y``."""

    def inverse_transform(self, y_transformed, return_proba=False):
        """Map encoded targets back to the original labels.

        ``return_proba`` defaults to False. When set, it presumably requests
        probabilities instead of labels; confirm in the SciKeras API reference.
        """
class RegressorTargetEncoder(BaseEstimator, TransformerMixin):
def fit(self, y): ...
def transform(self, y): ...
def inverse_transform(self, y): ...Context manager for ensuring reproducible results across Python, NumPy, and TensorFlow random number generators.
@contextmanager
def tensorflow_random_state(seed):
    """Context manager that makes the code inside its ``with`` block reproducible.

    Intended to fix the Python, NumPy and TensorFlow random number generators
    for the duration of the block. The body is elided in this reference
    listing, and whether the previous generator state is restored on exit is
    not shown here.

    Args:
        seed (int): seed used for the random number generators.

    Yields:
        None: control passes to the ``with`` block; nothing is bound by ``as``.
    """