
tessl/pypi-scikeras

Scikit-Learn API wrapper for Keras models enabling seamless integration of deep learning into scikit-learn workflows.


SciKeras

A Python library providing Scikit-Learn compatible wrappers for Keras models, enabling seamless integration of deep learning models into scikit-learn workflows. SciKeras serves as a modern replacement for the deprecated tf.keras.wrappers.scikit_learn.

Package Information

  • Package Name: scikeras
  • Package Type: pypi
  • Language: Python
  • Installation: pip install scikeras[tensorflow] or pip install scikeras (requires separate TensorFlow installation)

Core Imports

from scikeras.wrappers import KerasClassifier, KerasRegressor

For advanced usage:

from scikeras.wrappers import BaseWrapper
from scikeras.utils import loss_name, metric_name
from scikeras.utils.transformers import ClassifierLabelEncoder, TargetReshaper
from scikeras.utils.random_state import tensorflow_random_state

Basic Usage

Classification Example

from scikeras.wrappers import KerasClassifier
import keras
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score

# Define a model building function
def create_classifier():
    model = keras.Sequential([
        keras.Input(shape=(20,)),  # declare the 20 input features (Keras 3 style)
        keras.layers.Dense(100, activation='relu'),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(50, activation='relu'),
        keras.layers.Dense(1, activation='sigmoid')
    ])
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

# Create classifier wrapper
clf = KerasClassifier(
    model=create_classifier,
    epochs=10,
    batch_size=32,
    verbose=0
)

# Use like any scikit-learn classifier
X, y = make_classification(n_samples=1000, n_features=20)
scores = cross_val_score(clf, X, y, cv=3)
print(f"Accuracy: {scores.mean():.3f} (+/- {scores.std() * 2:.3f})")

Regression Example

from scikeras.wrappers import KerasRegressor
import keras
from sklearn.datasets import make_regression
from sklearn.model_selection import cross_val_score

# Define a model building function
def create_regressor():
    model = keras.Sequential([
        keras.Input(shape=(10,)),  # declare the 10 input features (Keras 3 style)
        keras.layers.Dense(64, activation='relu'),
        keras.layers.Dense(32, activation='relu'),
        keras.layers.Dense(1)
    ])
    model.compile(optimizer='adam', loss='mse', metrics=['mae'])
    return model

# Create regressor wrapper
reg = KerasRegressor(
    model=create_regressor,
    epochs=50,
    batch_size=32,
    verbose=0
)

# Use like any scikit-learn regressor
X, y = make_regression(n_samples=1000, n_features=10)
scores = cross_val_score(reg, X, y, cv=3, scoring='neg_mean_squared_error')
print(f"MSE: {-scores.mean():.3f} (+/- {scores.std() * 2:.3f})")

Architecture

SciKeras provides a bridge between Keras/TensorFlow and scikit-learn through wrapper classes that implement the scikit-learn estimator interface:

  • BaseWrapper: Base class providing core scikit-learn compatibility
  • KerasClassifier: Classification wrapper with probability prediction support
  • KerasRegressor: Regression wrapper with R² scoring
  • Data Transformers: Automatic preprocessing for different target types
  • Parameter Routing: Double underscore notation for nested configuration
  • Serialization Support: Pickle/joblib compatibility through custom reducers

This design lets Keras models work with scikit-learn tooling such as GridSearchCV, Pipeline, and the cross-validation utilities.
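Parameter routing means a prefixed keyword such as optimizer__learning_rate is forwarded, with its prefix stripped, to the component named by the prefix (here the optimizer; model__ prefixes go to the model-building function). The plain-Python sketch below illustrates only that prefix-stripping idea. route_params and the hidden_units parameter are hypothetical names for this illustration, not SciKeras internals, which also handle defaults and more destinations.

```python
from typing import Any, Dict


def route_params(params: Dict[str, Any], destination: str) -> Dict[str, Any]:
    """Hypothetical helper: collect the params addressed to `destination`.

    Keys of the form "<destination>__<name>" are returned as {"<name>": value};
    every other key is ignored.
    """
    prefix = destination + "__"
    return {
        key[len(prefix):]: value
        for key, value in params.items()
        if key.startswith(prefix)
    }


# Keyword arguments as they might be passed to a wrapper constructor.
# "model__hidden_units" is a hypothetical argument of a model-building function.
wrapper_kwargs = {
    "epochs": 10,
    "optimizer__learning_rate": 0.001,
    "model__hidden_units": 64,
}

print(route_params(wrapper_kwargs, "optimizer"))  # {'learning_rate': 0.001}
print(route_params(wrapper_kwargs, "model"))      # {'hidden_units': 64}
```

Because the double-underscore convention matches scikit-learn's nested-parameter naming (as in Pipeline's step__param), a routed name like optimizer__learning_rate can appear directly in a GridSearchCV parameter grid.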

Capabilities

Wrapper Classes

Core wrapper classes that provide scikit-learn compatibility for Keras models, supporting both classification and regression tasks with automatic data preprocessing.

class BaseWrapper(BaseEstimator):
    def __init__(self, model=None, *, optimizer='rmsprop', loss=None, 
                 metrics=None, batch_size=None, validation_batch_size=None,
                 verbose=1, callbacks=None, validation_split=0.0, 
                 shuffle=True, run_eagerly=None, epochs=1, **kwargs): ...
    def fit(self, X, y, *, sample_weight=None, **kwargs): ...
    def predict(self, X, **kwargs): ...
    def score(self, X, y, *, sample_weight=None): ...

class KerasClassifier(BaseWrapper, ClassifierMixin):
    def predict_proba(self, X, **kwargs): ...

class KerasRegressor(BaseWrapper, RegressorMixin): ...
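KerasRegressor's score reports the coefficient of determination R² (the same metric scikit-learn's RegressorMixin.score uses), while KerasClassifier's score reports accuracy. As a minimal sketch of the R² formula, 1 - SS_res / SS_tot, here is a pure-Python version; r2_score_sketch is a hypothetical name, and sklearn.metrics.r2_score is the practical choice.

```python
from typing import Sequence


def r2_score_sketch(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Hypothetical sketch of R^2 = 1 - SS_res / SS_tot."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("y_true and y_pred must be non-empty and of equal length")
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))  # residual sum of squares
    ss_tot = sum((t - mean) ** 2 for t in y_true)               # total sum of squares
    if ss_tot == 0:
        # Constant targets make R^2 undefined; this sketch simply refuses them.
        raise ValueError("R^2 is undefined for constant y_true")
    return 1.0 - ss_res / ss_tot


print(r2_score_sketch([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]))  # 1.0 (perfect fit)
print(r2_score_sketch([1.0, 2.0, 3.0], [2.0, 2.0, 2.0]))  # 0.0 (always predicting the mean)
```

A score of 1.0 means a perfect fit, 0.0 means no better than predicting the mean, and negative values mean worse than the mean.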


Utility Functions

Helper functions for normalizing Keras loss and metric names to standardized formats compatible with string-based configuration.

def loss_name(loss):
    """
    Retrieve standardized loss function name.
    
    Args:
        loss: Union[str, keras.losses.Loss, Callable] - Loss function identifier
    
    Returns:
        str: Standardized loss name in snake_case
    """

def metric_name(metric):
    """
    Retrieve standardized metric function name.
    
    Args:
        metric: Union[str, keras.metrics.Metric, Callable] - Metric function identifier
    
    Returns:
        str: Standardized metric name
    """


Data Transformers

Transformer classes for preprocessing targets and features to ensure compatibility between scikit-learn and Keras data formats.

class TargetReshaper(BaseEstimator, TransformerMixin):
    def fit(self, y): ...
    def transform(self, y): ...
    def inverse_transform(self, y): ...

class ClassifierLabelEncoder(BaseEstimator, TransformerMixin):
    def fit(self, y): ...
    def transform(self, y): ...
    def inverse_transform(self, y_transformed, return_proba=False): ...

class RegressorTargetEncoder(BaseEstimator, TransformerMixin):
    def fit(self, y): ...
    def transform(self, y): ...
    def inverse_transform(self, y): ...
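As a rough illustration of what a classifier label encoder does, the sketch below maps arbitrary class labels to the integers 0..n-1 that a Keras loss expects and turns predicted probability rows back into labels. SimpleLabelEncoder is a hypothetical pure-Python stand-in, not ClassifierLabelEncoder itself; it assumes sortable labels and treats a single-column row as a sigmoid output for the second class.

```python
from typing import Hashable, List, Sequence


class SimpleLabelEncoder:
    """Hypothetical stand-in illustrating label encoding for a Keras classifier."""

    def fit(self, y: Sequence[Hashable]) -> "SimpleLabelEncoder":
        # Sorted unique labels define each class's integer code (labels must be sortable).
        self.classes_: List[Hashable] = sorted(set(y))
        self._code = {label: i for i, label in enumerate(self.classes_)}
        return self

    def transform(self, y: Sequence[Hashable]) -> List[int]:
        return [self._code[label] for label in y]

    def inverse_transform(self, y_transformed: Sequence[Sequence[float]],
                          return_proba: bool = False):
        # Expand single-column (sigmoid) rows to [P(class 0), P(class 1)].
        proba = [[1.0 - row[0], row[0]] if len(row) == 1 else list(row)
                 for row in y_transformed]
        if return_proba:
            return proba
        # Most probable class per row; ties resolve to the first class.
        return [self.classes_[max(range(len(row)), key=row.__getitem__)]
                for row in proba]


enc = SimpleLabelEncoder().fit(["spam", "ham", "spam"])
print(enc.classes_)                           # ['ham', 'spam']
print(enc.transform(["spam", "ham"]))         # [1, 0]
print(enc.inverse_transform([[0.9], [0.2]]))  # ['spam', 'ham']
```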


Random State Management

Context manager for ensuring reproducible results across Python, NumPy, and TensorFlow random number generators.

@contextmanager
def tensorflow_random_state(seed):
    """
    Context manager for reproducible random state.
    
    Args:
        seed (int): Random seed for reproducibility
    
    Yields:
        None: Context for reproducible operations
    """


Workspace: tessl
Visibility: Public
Describes: pkg:pypi/scikeras@0.13.x