tessl/pypi-scikeras

Scikit-Learn API wrapper for Keras models enabling seamless integration of deep learning into scikit-learn workflows.

Wrapper Classes

Core wrapper classes that provide scikit-learn compatibility for Keras models. These classes implement the scikit-learn estimator interface, enabling Keras models to work seamlessly with scikit-learn's ecosystem including GridSearchCV, Pipeline, and cross-validation.

Capabilities

BaseWrapper

Abstract base class that implements the core scikit-learn estimator API for Keras models. Provides shared functionality between classification and regression wrappers.

class BaseWrapper(BaseEstimator):
    def __init__(
        self,
        model=None,
        *,
        build_fn=None,
        warm_start=False,
        random_state=None,
        optimizer='rmsprop',
        loss=None,
        metrics=None,
        batch_size=None,
        validation_batch_size=None,
        verbose=1,
        callbacks=None,
        validation_split=0.0,
        validation_steps=None,
        validation_freq=1,
        shuffle=True,
        run_eagerly=None,
        epochs=1,
        initial_epoch=0,
        **kwargs
    ):
        """
        Initialize BaseWrapper.
        
        Args:
            model: Union[None, Callable[..., keras.Model], keras.Model] - Keras model or callable that returns compiled model
            build_fn: Union[None, Callable[..., keras.Model], keras.Model] - Deprecated alias for model parameter
            warm_start: bool - Whether to preserve model parameters between fits
            random_state: Union[int, np.random.RandomState, None] - Random seed for reproducibility
            optimizer: Union[str, keras.optimizers.Optimizer, Type[keras.optimizers.Optimizer]] - Optimizer for training
            loss: Union[str, keras.losses.Loss, Type[keras.losses.Loss], Callable, None] - Loss function
            metrics: List of metrics to monitor during training
            batch_size: Union[int, None] - Number of samples per gradient update
            validation_batch_size: Union[int, None] - Batch size for validation
            verbose: int - Verbosity level (0=silent, 1=progress bar, 2=one line per epoch)
            callbacks: List of Keras callbacks
            validation_split: float - Fraction of training data to use for validation
            validation_steps: Union[int, None] - Number of steps to draw from validation generator
            validation_freq: int - Only run validation every N epochs
            shuffle: bool - Whether to shuffle training data
            run_eagerly: Union[bool, None] - Whether to run in eager mode
            epochs: int - Number of training epochs
            initial_epoch: int - Epoch at which to start training
            **kwargs: Additional parameters passed to model building function
        """
    
    def fit(self, X, y, *, sample_weight=None, **kwargs):
        """
        Train the Keras model.
        
        Args:
            X: array-like of shape (n_samples, n_features) - Training data
            y: array-like of shape (n_samples,) or (n_samples, n_outputs) - Target values
            sample_weight: array-like of shape (n_samples,), optional - Sample weights
            **kwargs: Additional arguments passed to model.fit()
        
        Returns:
            self: Fitted estimator
        """
    
    def partial_fit(self, X, y, *, sample_weight=None, **kwargs):
        """
        Train the model for a single epoch.
        
        Args:
            X: array-like of shape (n_samples, n_features) - Training data
            y: array-like of shape (n_samples,) or (n_samples, n_outputs) - Target values
            sample_weight: array-like of shape (n_samples,), optional - Sample weights
            **kwargs: Additional arguments passed to model.fit()
        
        Returns:
            self: Fitted estimator
        """
    
    def predict(self, X, **kwargs):
        """
        Make predictions using the trained model.
        
        Args:
            X: array-like of shape (n_samples, n_features) - Input data
            **kwargs: Additional arguments passed to model.predict()
        
        Returns:
            array-like: Predictions
        """
    
    def score(self, X, y, *, sample_weight=None):
        """
        Return the score of the model on the given test data.
        
        Args:
            X: array-like of shape (n_samples, n_features) - Test data
            y: array-like of shape (n_samples,) or (n_samples, n_outputs) - True values
            sample_weight: array-like of shape (n_samples,), optional - Sample weights
        
        Returns:
            float: Model score
        """
    
    def initialize(self, X, y=None):
        """
        Initialize the model without training.
        
        Args:
            X: array-like of shape (n_samples, n_features) - Sample data for initialization
            y: array-like, optional - Sample targets for initialization
        
        Returns:
            self: Initialized estimator
        """
    
    @property
    def current_epoch(self):
        """Get current training epoch."""
    
    @property
    def initialized_(self):
        """Check if model is initialized."""
    
    @property
    def target_encoder(self):
        """Get target transformation pipeline."""
    
    @property
    def feature_encoder(self):
        """Get feature transformation pipeline."""
    
    @property
    def model_(self):
        """Get the instantiated and compiled Keras Model."""
    
    @property
    def history_(self):
        """Get training history dictionary."""
    
    @property
    def n_outputs_expected_(self):
        """Get expected number of outputs."""
    
    @property
    def target_type_(self):
        """Get target type string."""
    
    @property
    def classes_(self):
        """Get class labels (classification only)."""
    
    @property
    def n_classes_(self):
        """Get number of classes (classification only)."""
    
    @property
    def X_shape_(self):
        """Get input data shape from fitting."""
    
    @property
    def y_shape_(self):
        """Get target data shape from fitting."""
    
    @property
    def X_dtype_(self):
        """Get input data dtype from fitting."""
    
    @property
    def y_dtype_(self):
        """Get target data dtype from fitting."""
    
    @property
    def n_features_in_(self):
        """Get number of features seen during fit."""

KerasClassifier

Scikit-learn compatible classifier wrapper for Keras models. Supports binary and multiclass classification with probability predictions.

class KerasClassifier(BaseWrapper, ClassifierMixin):
    def __init__(self, class_weight=None, **kwargs):
        """
        Initialize KerasClassifier.
        
        Args:
            class_weight: dict or 'balanced', optional - Weights for class balancing
            **kwargs: All arguments from BaseWrapper
        """
    
    def fit(self, X, y, *, sample_weight=None, **kwargs):
        """
        Train the classifier.
        
        Args:
            X: array-like of shape (n_samples, n_features) - Training data
            y: array-like of shape (n_samples,) - Target class labels
            sample_weight: array-like of shape (n_samples,), optional - Sample weights
            **kwargs: Additional arguments passed to model.fit()
        
        Returns:
            self: Fitted classifier
        """
    
    def partial_fit(self, X, y, *, classes=None, sample_weight=None, **kwargs):
        """
        Train the classifier for a single epoch.
        
        Args:
            X: array-like of shape (n_samples, n_features) - Training data
            y: array-like of shape (n_samples,) - Target class labels
            classes: array-like of shape (n_classes,), optional - List of all possible classes
            sample_weight: array-like of shape (n_samples,), optional - Sample weights
            **kwargs: Additional arguments passed to model.fit()
        
        Returns:
            self: Fitted classifier
        """
    
    def predict_proba(self, X, **kwargs):
        """
        Predict class probabilities.
        
        Args:
            X: array-like of shape (n_samples, n_features) - Input data
            **kwargs: Additional arguments passed to model.predict()
        
        Returns:
            array-like of shape (n_samples, n_classes): Class probabilities
        """
    
    @property
    def classes_(self):
        """Get class labels."""
    
    @property
    def n_classes_(self):
        """Get number of classes."""

KerasRegressor

Scikit-learn compatible regressor wrapper for Keras models. Uses R² score as the default scoring metric.

class KerasRegressor(BaseWrapper, RegressorMixin):
    def __init__(self, **kwargs):
        """
        Initialize KerasRegressor.
        
        Args:
            **kwargs: All arguments from BaseWrapper
        """
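
For reference, the R² score mentioned above is the coefficient of determination, R² = 1 - SS_res / SS_tot, where SS_res is the sum of squared residuals and SS_tot is the total sum of squares around the mean of the targets. The plain-Python sketch below only illustrates the formula. The name r2_score_sketch is hypothetical, and in practice the value would come from the regressor's score() method.

```python
def r2_score_sketch(y_true, y_pred):
    """Hypothetical helper: coefficient of determination, 1 - SS_res / SS_tot."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot

# Perfect predictions versus always predicting the mean of the targets
perfect = r2_score_sketch([1.0, 2.0, 3.0], [1.0, 2.0, 3.0])
baseline = r2_score_sketch([1.0, 2.0, 3.0], [2.0, 2.0, 2.0])
```

A score of 1.0 means the predictions match the targets exactly, 0.0 means the model does no better than predicting the mean, and negative values are possible for worse models.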

Usage Examples

Basic Classification with Grid Search

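import numpy as np
import keras
from scikeras.wrappers import KerasClassifier
from sklearn.model_selection import GridSearchCV

# Synthetic binary-classification data (placeholder; substitute your own)
rng = np.random.default_rng(0)
X_train = rng.normal(size=(200, 10)).astype("float32")
y_train = (X_train[:, 0] > 0).astype("int64")

def create_model(units=50, optimizer='adam'):
    # Arguments of this function are set via model__units / model__optimizer
    model = keras.Sequential([
        keras.Input(shape=(10,)),  # 10 input features
        keras.layers.Dense(units, activation='relu'),
        keras.layers.Dense(1, activation='sigmoid')
    ])
    model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])
    return model

# Create classifier; model__* parameters are routed to create_model
clf = KerasClassifier(
    model=create_model,
    epochs=10,
    batch_size=32,
    verbose=0
)

# Use with GridSearchCV: 3 x 2 x 3 = 18 candidates, each fit on 3 folds
param_grid = {
    'model__units': [25, 50, 100],
    'model__optimizer': ['adam', 'sgd'],
    'epochs': [5, 10, 15]
}

grid = GridSearchCV(clf, param_grid, cv=3, scoring='accuracy')
grid.fit(X_train, y_train)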

Warm Start Training

from scikeras.wrappers import KerasRegressor

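# Enable warm start to preserve model weights between fit calls.
# create_model here must build a regression model (e.g. a linear output
# layer compiled with loss='mse'), not the classifier defined above.
reg = KerasRegressor(
    model=create_model,
    epochs=10,
    warm_start=True
)

# Initial training
reg.fit(X_train, y_train)

# Continue training from the previous weights instead of rebuilding the model
reg.set_params(epochs=5)  # Train for 5 more epochs
reg.fit(X_train, y_train)  # Reuses the weights learned in the first fit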

Custom Callbacks

from scikeras.wrappers import KerasClassifier
import keras

# Define custom callbacks
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=5, restore_best_weights=True
)

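# min_lr must be below the optimizer's starting learning rate
# (Adam defaults to 0.001), otherwise the rate can never be reduced
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss', factor=0.2, patience=3, min_lr=1e-5
)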

clf = KerasClassifier(
    model=create_model,
    epochs=100,
    validation_split=0.2,
    callbacks=[early_stopping, reduce_lr]
)

clf.fit(X_train, y_train)

Parameter Routing

SciKeras routes parameters to nested components using double-underscore notation, in the same style as scikit-learn's set_params. The prefix before the double underscore names the destination, so model creation, compilation, and training can each be configured from the wrapper.

Routing Targets

Parameters can be routed to different destinations:

  • model__*: Parameters passed to the model building function
  • compile__*: Parameters passed to model.compile()
  • fit__*: Parameters passed to model.fit()
  • predict__*: Parameters passed to model.predict()
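
The prefix convention listed above can be pictured as splitting each key at its first double underscore. The following is a simplified sketch of that idea, not SciKeras's actual implementation. The helper split_routed_params and the constant ROUTING_PREFIXES are hypothetical names introduced for illustration.

```python
# Destinations taken from the list of routing targets
ROUTING_PREFIXES = ("model", "compile", "fit", "predict")

def split_routed_params(params):
    """Hypothetical helper: group 'prefix__name' keys by destination."""
    routed = {prefix: {} for prefix in ROUTING_PREFIXES}
    direct = {}
    for key, value in params.items():
        prefix, sep, rest = key.partition("__")
        if sep and prefix in routed:
            routed[prefix][rest] = value  # e.g. model__units -> routed["model"]["units"]
        else:
            direct[key] = value           # unprefixed keys stay on the wrapper
    return routed, direct

routed, direct = split_routed_params(
    {"model__units": 100, "fit__validation_split": 0.2, "epochs": 50}
)
```

Here routed["model"] holds the arguments for the model building function, routed["fit"] holds those for model.fit(), and direct holds ordinary wrapper parameters such as epochs.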

Examples

# Route parameters to the model building function, compile(), and fit()
clf = KerasClassifier(model=create_model)
clf.set_params(
    model__units=100,            # Passed to create_model(units=100)
    model__dropout_rate=0.2,     # Passed to create_model(dropout_rate=0.2); create_model must accept this argument
    compile__optimizer='adam',   # Passed to model.compile(optimizer='adam')
    compile__loss='binary_crossentropy',  # Passed to model.compile(loss=...)
    fit__validation_split=0.2,   # Passed to fit(validation_split=0.2)
    fit__callbacks=[early_stopping],  # Passed to fit(callbacks=...); early_stopping as defined under Custom Callbacks
    epochs=50                    # Direct parameter to wrapper
)

Nested Routing

Parameters can be routed to nested objects within the routed target:

# Route to optimizer parameters within compile
clf.set_params(
    compile__optimizer__learning_rate=0.001,  # optimizer.learning_rate = 0.001
    compile__optimizer__beta_1=0.9,           # optimizer.beta_1 = 0.9
)
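
One way to read nested keys like these is that each double underscore descends one level into the target object. The sketch below expands flat keys into nested dictionaries to show that reading. It is an illustration only and makes no claim about how SciKeras resolves these keys internally. The helper unflatten_sketch is a hypothetical name.

```python
def unflatten_sketch(params):
    """Hypothetical helper: turn 'a__b__c' keys into nested dictionaries."""
    nested = {}
    for key, value in params.items():
        parts = key.split("__")
        node = nested
        for part in parts[:-1]:
            node = node.setdefault(part, {})  # descend, creating levels as needed
        node[parts[-1]] = value               # last part is the attribute name
    return nested

nested = unflatten_sketch({
    "compile__optimizer__learning_rate": 0.001,
    "compile__optimizer__beta_1": 0.9,
})
```

Both keys end up under nested["compile"]["optimizer"], which mirrors the idea that they configure the optimizer used at compile time.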

Types

from typing import Callable, List, Type, Union

import keras

# Model building function signature
ModelBuildingFunction = Callable[..., keras.Model]

# Supported parameter types
OptimizerType = Union[str, keras.optimizers.Optimizer, Type[keras.optimizers.Optimizer]]
LossType = Union[str, keras.losses.Loss, Type[keras.losses.Loss], Callable, None]
MetricsType = Union[List[Union[str, keras.metrics.Metric]], None]
CallbacksType = Union[List[keras.callbacks.Callback], None]

Install with Tessl CLI

npx tessl i tessl/pypi-scikeras
