Scikit-Learn API wrapper for Keras models enabling seamless integration of deep learning into scikit-learn workflows.
—
Core wrapper classes that provide scikit-learn compatibility for Keras models. These classes implement the scikit-learn estimator interface, enabling Keras models to work seamlessly with scikit-learn's ecosystem including GridSearchCV, Pipeline, and cross-validation.
Abstract base class that implements the core scikit-learn estimator API for Keras models. Provides shared functionality between classification and regression wrappers.
class BaseWrapper(BaseEstimator):
    """Abstract base class implementing the core scikit-learn estimator API for Keras models.

    Provides the shared functionality (fit / predict / score machinery and
    fitted-state properties) used by both the classification and regression
    wrappers.
    """

    def __init__(
        self,
        model=None,
        *,
        build_fn=None,
        warm_start=False,
        random_state=None,
        optimizer='rmsprop',
        loss=None,
        metrics=None,
        batch_size=None,
        validation_batch_size=None,
        verbose=1,
        callbacks=None,
        validation_split=0.0,
        validation_steps=None,
        validation_freq=1,
        shuffle=True,
        run_eagerly=None,
        epochs=1,
        initial_epoch=0,
        **kwargs
    ):
        """
        Initialize BaseWrapper.

        Args:
            model: Union[None, Callable[..., keras.Model], keras.Model] - Keras model or callable that returns a compiled model
            build_fn: Union[None, Callable[..., keras.Model], keras.Model] - Deprecated alias for the model parameter
            warm_start: bool - Whether to preserve model parameters between fits
            random_state: Union[int, np.random.RandomState, None] - Random seed for reproducibility
            optimizer: Union[str, keras.optimizers.Optimizer, Type[keras.optimizers.Optimizer]] - Optimizer for training
            loss: Union[str, keras.losses.Loss, Type[keras.losses.Loss], Callable, None] - Loss function
            metrics: List of metrics to monitor during training
            batch_size: Union[int, None] - Number of samples per gradient update
            validation_batch_size: Union[int, None] - Batch size for validation
            verbose: int - Verbosity level (0=silent, 1=progress bar, 2=one line per epoch)
            callbacks: List of Keras callbacks
            validation_split: float - Fraction of training data to use for validation
            validation_steps: Union[int, None] - Number of steps to draw from validation generator
            validation_freq: int - Only run validation every N epochs
            shuffle: bool - Whether to shuffle training data
            run_eagerly: Union[bool, None] - Whether to run in eager mode
            epochs: int - Number of training epochs
            initial_epoch: int - Epoch at which to start training
            **kwargs: Additional parameters passed to the model building function
        """

    def fit(self, X, y, *, sample_weight=None, **kwargs):
        """
        Train the Keras model.

        Args:
            X: array-like of shape (n_samples, n_features) - Training data
            y: array-like of shape (n_samples,) or (n_samples, n_outputs) - Target values
            sample_weight: array-like of shape (n_samples,), optional - Sample weights
            **kwargs: Additional arguments passed to model.fit()

        Returns:
            self: Fitted estimator
        """

    def partial_fit(self, X, y, *, sample_weight=None, **kwargs):
        """
        Train the model for a single epoch.

        Args:
            X: array-like of shape (n_samples, n_features) - Training data
            y: array-like of shape (n_samples,) or (n_samples, n_outputs) - Target values
            sample_weight: array-like of shape (n_samples,), optional - Sample weights
            **kwargs: Additional arguments passed to model.fit()

        Returns:
            self: Fitted estimator
        """

    def predict(self, X, **kwargs):
        """
        Make predictions using the trained model.

        Args:
            X: array-like of shape (n_samples, n_features) - Input data
            **kwargs: Additional arguments passed to model.predict()

        Returns:
            array-like: Predictions
        """

    def score(self, X, y, *, sample_weight=None):
        """
        Return the score of the model on the given test data.

        Args:
            X: array-like of shape (n_samples, n_features) - Test data
            y: array-like of shape (n_samples,) or (n_samples, n_outputs) - True values
            sample_weight: array-like of shape (n_samples,), optional - Sample weights

        Returns:
            float: Model score
        """

    def initialize(self, X, y=None):
        """
        Initialize the model without training.

        Args:
            X: array-like of shape (n_samples, n_features) - Sample data for initialization
            y: array-like, optional - Sample targets for initialization

        Returns:
            self: Initialized estimator
        """

    @property
    def current_epoch(self):
        """Get current training epoch."""

    @property
    def initialized_(self):
        """Check if model is initialized."""

    @property
    def target_encoder(self):
        """Get target transformation pipeline."""

    @property
    def feature_encoder(self):
        """Get feature transformation pipeline."""

    @property
    def model_(self):
        """Get the instantiated and compiled Keras Model."""

    @property
    def history_(self):
        """Get training history dictionary."""

    @property
    def n_outputs_expected_(self):
        """Get expected number of outputs."""

    @property
    def target_type_(self):
        """Get target type string."""

    @property
    def classes_(self):
        """Get class labels (classification only)."""

    @property
    def n_classes_(self):
        """Get number of classes (classification only)."""

    @property
    def X_shape_(self):
        """Get input data shape from fitting."""

    @property
    def y_shape_(self):
        """Get target data shape from fitting."""

    @property
    def X_dtype_(self):
        """Get input data dtype from fitting."""

    @property
    def y_dtype_(self):
        """Get target data dtype from fitting."""

    @property
    def n_features_in_(self):
        """Get number of features seen during fit."""

Scikit-learn compatible classifier wrapper for Keras models. Supports binary and multiclass classification with probability predictions.
class KerasClassifier(BaseWrapper, ClassifierMixin):
    """Scikit-learn compatible classifier wrapper for Keras models.

    Supports binary and multiclass classification, with class probabilities
    available through ``predict_proba``.
    """

    def __init__(self, class_weight=None, **kwargs):
        """
        Initialize KerasClassifier.

        Args:
            class_weight: dict or 'balanced', optional - Weights for class balancing
            **kwargs: All arguments from BaseWrapper
        """

    def fit(self, X, y, *, sample_weight=None, **kwargs):
        """
        Train the classifier.

        Args:
            X: array-like of shape (n_samples, n_features) - Training data
            y: array-like of shape (n_samples,) - Target class labels
            sample_weight: array-like of shape (n_samples,), optional - Sample weights
            **kwargs: Additional arguments passed to model.fit()

        Returns:
            self: Fitted classifier
        """

    def partial_fit(self, X, y, *, classes=None, sample_weight=None, **kwargs):
        """
        Train the classifier for a single epoch.

        Args:
            X: array-like of shape (n_samples, n_features) - Training data
            y: array-like of shape (n_samples,) - Target class labels
            classes: array-like of shape (n_classes,), optional - List of all possible classes
            sample_weight: array-like of shape (n_samples,), optional - Sample weights
            **kwargs: Additional arguments passed to model.fit()

        Returns:
            self: Fitted classifier
        """

    def predict_proba(self, X, **kwargs):
        """
        Predict class probabilities.

        Args:
            X: array-like of shape (n_samples, n_features) - Input data
            **kwargs: Additional arguments passed to model.predict()

        Returns:
            array-like of shape (n_samples, n_classes): Class probabilities
        """

    @property
    def classes_(self):
        """Get class labels."""

    @property
    def n_classes_(self):
        """Get number of classes."""

Scikit-learn compatible regressor wrapper for Keras models. Uses R² score as the default scoring metric.
class KerasRegressor(BaseWrapper, RegressorMixin):
    """Scikit-learn compatible regressor wrapper for Keras models.

    Uses the R² score (coefficient of determination) as the default
    scoring metric.
    """

    def __init__(self, **kwargs):
        """
        Initialize KerasRegressor.

        Args:
            **kwargs: All arguments from BaseWrapper
        """

from scikeras.wrappers import KerasClassifier
from sklearn.model_selection import GridSearchCV
import keras
def create_model(units=50, optimizer='adam'):
    """Build and compile a small binary-classification MLP for the examples.

    Args:
        units: int - Width of the hidden Dense layer.
        optimizer: str or keras optimizer - Optimizer passed to ``compile``.

    Returns:
        A compiled ``keras.Model`` with a sigmoid output for binary targets.
    """
    network = keras.Sequential()
    network.add(keras.layers.Dense(units, activation='relu', input_dim=10))
    network.add(keras.layers.Dense(1, activation='sigmoid'))
    network.compile(
        optimizer=optimizer,
        loss='binary_crossentropy',
        metrics=['accuracy'],
    )
    return network
# Create classifier with parameter routing
clf = KerasClassifier(
model=create_model,
epochs=10,
batch_size=32,
verbose=0
)
# Use with GridSearchCV
param_grid = {
'model__units': [25, 50, 100],
'model__optimizer': ['adam', 'sgd'],
'epochs': [5, 10, 15]
}
grid = GridSearchCV(clf, param_grid, cv=3, scoring='accuracy')
grid.fit(X_train, y_train)

from scikeras.wrappers import KerasRegressor
# Enable warm start to preserve model weights between fit calls
reg = KerasRegressor(
model=create_model,
epochs=10,
warm_start=True
)
# Initial training
reg.fit(X_train, y_train)
# Continue training from previous state
reg.set_params(epochs=5) # Train for 5 more epochs
reg.fit(X_train, y_train)  # Continues from epoch 10

from scikeras.wrappers import KerasClassifier
import keras
# Define custom callbacks
early_stopping = keras.callbacks.EarlyStopping(
monitor='val_loss', patience=5, restore_best_weights=True
)
reduce_lr = keras.callbacks.ReduceLROnPlateau(
monitor='val_loss', factor=0.2, patience=3, min_lr=0.001
)
clf = KerasClassifier(
model=create_model,
epochs=100,
validation_split=0.2,
callbacks=[early_stopping, reduce_lr]
)
clf.fit(X_train, y_train)

SciKeras implements a sophisticated parameter routing system that enables passing arguments to nested components using double underscore notation. This allows fine-grained control over all aspects of the model creation, compilation, and training process.
Parameters can be routed to different destinations:
- model__*: Parameters passed to the model building function
- compile__*: Parameters passed to model.compile()
- fit__*: Parameters passed to model.fit()
- predict__*: Parameters passed to model.predict()

# Route parameters to model building function
clf = KerasClassifier(model=create_model)
clf.set_params(
model__units=100, # Passed to create_model(units=100)
model__dropout_rate=0.2, # Passed to create_model(dropout_rate=0.2)
compile__optimizer='adam', # Passed to model.compile(optimizer='adam')
compile__loss='binary_crossentropy', # Passed to model.compile(loss=...)
fit__validation_split=0.2, # Passed to fit(validation_split=0.2)
fit__callbacks=[early_stop], # Passed to fit(callbacks=...)
epochs=50 # Direct parameter to wrapper
)

Parameters can be routed to nested objects within the routed target:
# Route to optimizer parameters within compile
clf.set_params(
compile__optimizer__learning_rate=0.001, # optimizer.learning_rate = 0.001
compile__optimizer__beta_1=0.9, # optimizer.beta_1 = 0.9
)

# Model building function signature
ModelBuildingFunction = Callable[..., keras.Model]
# Supported parameter types
OptimizerType = Union[str, keras.optimizers.Optimizer, Type[keras.optimizers.Optimizer]]
LossType = Union[str, keras.losses.Loss, Type[keras.losses.Loss], Callable, None]
MetricsType = Union[List[Union[str, keras.metrics.Metric]], None]
CallbacksType = Union[List[keras.callbacks.Callback], None]

Install with Tessl CLI
npx tessl i tessl/pypi-scikeras