
tessl/pypi-scikeras

Scikit-Learn API wrapper for Keras models enabling seamless integration of deep learning into scikit-learn workflows.

Workspace: tessl
Visibility: Public
Describes: pkg:pypi/scikeras@0.13.x

To install, run

npx @tessl/cli install tessl/pypi-scikeras@0.13.0


SciKeras

A Python library that provides scikit-learn-compatible wrappers for Keras models, so that deep learning models can be used directly in scikit-learn workflows. SciKeras is a modern replacement for the deprecated tf.keras.wrappers.scikit_learn module.

Package Information

  • Package Name: scikeras
  • Package Type: pypi
  • Language: Python
  • Installation: pip install scikeras[tensorflow] (also installs TensorFlow) or pip install scikeras (TensorFlow must then be installed separately)

Core Imports

from scikeras.wrappers import KerasClassifier, KerasRegressor

For advanced usage:

from scikeras.wrappers import BaseWrapper
from scikeras.utils import loss_name, metric_name
from scikeras.utils.transformers import ClassifierLabelEncoder, TargetReshaper
from scikeras.utils.random_state import tensorflow_random_state

Basic Usage

Classification Example

from scikeras.wrappers import KerasClassifier
import keras
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score

# Define a model building function
def create_classifier():
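    model = keras.Sequential([
        keras.Input(shape=(20,)),  # 20 input features, matching make_classification below
        keras.layers.Dense(100, activation='relu'),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(50, activation='relu'),
        keras.layers.Dense(1, activation='sigmoid')  # single probability for the positive class
    ])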
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

# Create classifier wrapper
clf = KerasClassifier(
    model=create_classifier,
    epochs=10,
    batch_size=32,
    verbose=0
)

# Use like any scikit-learn classifier
X, y = make_classification(n_samples=1000, n_features=20)
scores = cross_val_score(clf, X, y, cv=3)
print(f"Accuracy: {scores.mean():.3f} (+/- {scores.std() * 2:.3f})")

Regression Example

from scikeras.wrappers import KerasRegressor
import keras
from sklearn.datasets import make_regression
from sklearn.model_selection import cross_val_score

# Define a model building function
def create_regressor():
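    model = keras.Sequential([
        keras.Input(shape=(10,)),  # 10 input features, matching make_regression below
        keras.layers.Dense(64, activation='relu'),
        keras.layers.Dense(32, activation='relu'),
        keras.layers.Dense(1)  # linear output for a continuous target
    ])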
    model.compile(optimizer='adam', loss='mse', metrics=['mae'])
    return model

# Create regressor wrapper
reg = KerasRegressor(
    model=create_regressor,
    epochs=50,
    batch_size=32,
    verbose=0
)

# Use like any scikit-learn regressor
X, y = make_regression(n_samples=1000, n_features=10)
scores = cross_val_score(reg, X, y, cv=3, scoring='neg_mean_squared_error')
print(f"MSE: {-scores.mean():.3f} (+/- {scores.std() * 2:.3f})")

Architecture

SciKeras provides a bridge between Keras/TensorFlow and scikit-learn through wrapper classes that implement the scikit-learn estimator interface:

  • BaseWrapper: Base class providing core scikit-learn compatibility
  • KerasClassifier: Classification wrapper with probability prediction support
  • KerasRegressor: Regression wrapper with R² scoring
  • Data Transformers: Automatic preprocessing for different target types
  • Parameter Routing: Double-underscore prefixes (for example model__ and fit__) that route parameters to the model-building function or to fit
  • Serialization Support: Pickle/joblib compatibility through custom reducers

This design lets wrapped Keras models be used with scikit-learn tooling such as GridSearchCV, Pipeline, and the cross-validation utilities.
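The double-underscore parameter routing can be illustrated in plain Python. This is a sketch, not SciKeras's implementation: `route_params_sketch` is a hypothetical helper, and the rule it encodes (a prefixed key such as `fit__batch_size` overrides the matching unprefixed key, but only for that destination) is an assumption about how routing behaves. It avoids Keras entirely so it runs anywhere.

```python
from typing import Any, Dict, Iterable, Optional


def route_params_sketch(
    params: Dict[str, Any],
    destination: str,
    pass_filter: Optional[Iterable[str]] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: select the parameters meant for one destination.

    Unprefixed keys listed in ``pass_filter`` are passed through as defaults;
    keys written as ``destination__name`` have their prefix stripped and
    override any default with the same name.
    """
    prefix = f"{destination}__"
    allowed = set(pass_filter or ())
    # Start from the generic (unprefixed) parameters this destination accepts.
    routed = {k: v for k, v in params.items() if "__" not in k and k in allowed}
    # Prefixed parameters are more specific, so they win on conflicts.
    for key, value in params.items():
        if key.startswith(prefix):
            routed[key[len(prefix):]] = value
    return routed


# A flat parameter dict, like the keyword arguments given to a wrapper.
params = {
    "epochs": 10,
    "batch_size": 32,
    "fit__batch_size": 64,       # overrides batch_size for fit only
    "model__hidden_units": 128,  # forwarded to the model-building function
}

fit_kwargs = route_params_sketch(params, "fit", pass_filter={"epochs", "batch_size"})
model_kwargs = route_params_sketch(params, "model")
print(fit_kwargs)    # {'epochs': 10, 'batch_size': 64}
print(model_kwargs)  # {'hidden_units': 128}
```

Keeping every setting in one flat namespace is what allows tools like GridSearchCV to tune `model__hidden_units` through the usual `get_params`/`set_params` protocol.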

Capabilities

Wrapper Classes

Core wrapper classes that provide scikit-learn compatibility for Keras models, supporting both classification and regression tasks with automatic data preprocessing.

class BaseWrapper(BaseEstimator):
    def __init__(self, model=None, *, optimizer='rmsprop', loss=None, 
                 metrics=None, batch_size=None, validation_batch_size=None,
                 verbose=1, callbacks=None, validation_split=0.0, 
                 shuffle=True, run_eagerly=None, epochs=1, **kwargs): ...
    def fit(self, X, y, *, sample_weight=None, **kwargs): ...
    def predict(self, X, **kwargs): ...
    def score(self, X, y, *, sample_weight=None): ...

class KerasClassifier(BaseWrapper, ClassifierMixin):
    def predict_proba(self, X, **kwargs): ...

class KerasRegressor(BaseWrapper, RegressorMixin): ...
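KerasRegressor reports R² from score. To make that number concrete, here is a NumPy sketch of the coefficient of determination, R² = 1 - SS_res / SS_tot. `r2_score_sketch` is a hypothetical helper written for illustration and is not part of SciKeras.

```python
import numpy as np


def r2_score_sketch(y_true, y_pred) -> float:
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)         # residual sum of squares
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)  # total sum of squares
    return float(1.0 - ss_res / ss_tot)


y = np.array([1.0, 2.0, 3.0, 4.0])
print(r2_score_sketch(y, y))                          # 1.0: perfect fit
print(r2_score_sketch(y, np.full_like(y, y.mean())))  # 0.0: no better than predicting the mean
print(r2_score_sketch(y, y[::-1]))                    # -3.0: worse than predicting the mean
```

Because R² can be negative, a poorly trained regressor can score below zero, which is worth remembering when reading cross-validation results.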


Utility Functions

Helper functions for normalizing Keras loss and metric names to standardized formats compatible with string-based configuration.

def loss_name(loss):
    """
    Retrieve standardized loss function name.
    
    Args:
        loss: Union[str, keras.losses.Loss, Callable] - Loss function identifier
    
    Returns:
        str: Standardized loss name in snake_case
    """

def metric_name(metric):
    """
    Retrieve standardized metric function name.
    
    Args:
        metric: Union[str, keras.metrics.Metric, Callable] - Metric function identifier
    
    Returns:
        str: Standardized metric name
    """


Data Transformers

Transformer classes for preprocessing targets and features to ensure compatibility between scikit-learn and Keras data formats.

class TargetReshaper(BaseEstimator, TransformerMixin):
    def fit(self, y): ...
    def transform(self, y): ...
    def inverse_transform(self, y): ...

class ClassifierLabelEncoder(BaseEstimator, TransformerMixin):
    def fit(self, y): ...
    def transform(self, y): ...
    def inverse_transform(self, y_transformed, return_proba=False): ...

class RegressorTargetEncoder(BaseEstimator, TransformerMixin):
    def fit(self, y): ...
    def transform(self, y): ...
    def inverse_transform(self, y): ...
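The roles of TargetReshaper and ClassifierLabelEncoder can be sketched with NumPy. Both classes below are hypothetical, simplified stand-ins that only mirror the method names listed above: the reshaper turns a 1D target into the 2D column Keras expects and back, and the encoder maps labels to integers and turns raw model output back into labels or probabilities. The assumption that a single sigmoid column is expanded into two probability columns is illustrative, not taken from the library source.

```python
import numpy as np


class ReshaperSketch:
    """Hypothetical stand-in for TargetReshaper: 1D targets -> 2D column and back."""

    def fit(self, y):
        self.ndim_ = np.ndim(y)  # remember the original dimensionality
        return self

    def transform(self, y):
        y = np.asarray(y)
        return y.reshape(-1, 1) if y.ndim == 1 else y  # Keras expects (n, k)

    def inverse_transform(self, y):
        y = np.asarray(y)
        # Undo the reshape only if the fitted target was 1D.
        if self.ndim_ == 1 and y.ndim == 2 and y.shape[1] == 1:
            return y.reshape(-1)
        return y


class LabelEncoderSketch:
    """Hypothetical stand-in for ClassifierLabelEncoder on a 1D label vector."""

    def fit(self, y):
        self.classes_ = np.unique(y)  # sorted unique labels
        return self

    def transform(self, y):
        # Map each label to its index in classes_ (0 .. n_classes - 1).
        return np.searchsorted(self.classes_, y)

    def inverse_transform(self, y_transformed, return_proba=False):
        # y_transformed is raw model output (probabilities), shape (n,) or (n, k).
        proba = np.asarray(y_transformed, dtype=float)
        if proba.ndim == 1 or proba.shape[1] == 1:
            # A single sigmoid column holds P(class 1); prepend P(class 0).
            p1 = proba.reshape(-1)
            proba = np.column_stack([1.0 - p1, p1])
        if return_proba:
            return proba
        return self.classes_[np.argmax(proba, axis=1)]


y = np.array(["cat", "dog", "dog", "cat"])
enc = LabelEncoderSketch().fit(y)
print(enc.transform(y))                               # [0 1 1 0]
sigmoid_out = np.array([[0.2], [0.9], [0.7], [0.1]])  # shape (n, 1)
print(enc.inverse_transform(sigmoid_out))             # ['cat' 'dog' 'dog' 'cat']
print(enc.inverse_transform(sigmoid_out, return_proba=True).shape)  # (4, 2)

reshaper = ReshaperSketch().fit(np.array([1.5, 2.5, 3.5]))
col = reshaper.transform(np.array([1.5, 2.5, 3.5]))
print(col.shape, reshaper.inverse_transform(col).shape)  # (3, 1) (3,)
```

Splitting the work into fit/transform/inverse_transform keeps the target handling reversible, which is what lets predict return labels in the user's original format.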


Random State Management

Context manager for ensuring reproducible results across Python, NumPy, and TensorFlow random number generators.

@contextmanager
def tensorflow_random_state(seed):
    """
    Context manager for reproducible random state.
    
    Args:
        seed (int): Random seed for reproducibility
    
    Yields:
        None: Context for reproducible operations
    """
