
**tessl/pypi-scikeras**: Scikit-Learn API wrapper for Keras models enabling seamless integration of deep learning into scikit-learn workflows.

- **Workspace**: tessl
- **Visibility**: Public
- **Describes**: `pkg:pypi/scikeras@0.13.x`
- **Install**: `npx @tessl/cli install tessl/pypi-scikeras@0.13.0`

# SciKeras


A Python library providing Scikit-Learn compatible wrappers for Keras models, enabling seamless integration of deep learning models into scikit-learn workflows. SciKeras serves as a modern replacement for the deprecated `tf.keras.wrappers.scikit_learn`.


## Package Information

- **Package Name**: scikeras
- **Package Type**: pypi
- **Language**: Python
- **Installation**: `pip install scikeras[tensorflow]`, or `pip install scikeras` (requires a separate TensorFlow installation)

## Core Imports

```python
from scikeras.wrappers import KerasClassifier, KerasRegressor
```

For advanced usage:

```python
from scikeras.wrappers import BaseWrapper
from scikeras.utils import loss_name, metric_name
from scikeras.utils.transformers import ClassifierLabelEncoder, TargetReshaper
from scikeras.utils.random_state import tensorflow_random_state
```

## Basic Usage


### Classification Example

```python
from scikeras.wrappers import KerasClassifier
import keras
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score

# Define a model building function
def create_classifier():
    model = keras.Sequential([
        # Declare the 20 input features with keras.Input (preferred over input_dim in Keras 3)
        keras.Input(shape=(20,)),
        keras.layers.Dense(100, activation='relu'),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(50, activation='relu'),
        keras.layers.Dense(1, activation='sigmoid')
    ])
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

# Create classifier wrapper
clf = KerasClassifier(
    model=create_classifier,
    epochs=10,
    batch_size=32,
    verbose=0
)

# Use like any scikit-learn classifier
X, y = make_classification(n_samples=1000, n_features=20)
scores = cross_val_score(clf, X, y, cv=3)
print(f"Accuracy: {scores.mean():.3f} (+/- {scores.std() * 2:.3f})")
```

### Regression Example

```python
from scikeras.wrappers import KerasRegressor
import keras
from sklearn.datasets import make_regression
from sklearn.model_selection import cross_val_score

# Define a model building function
def create_regressor():
    model = keras.Sequential([
        # Declare the 10 input features with keras.Input (preferred over input_dim in Keras 3)
        keras.Input(shape=(10,)),
        keras.layers.Dense(64, activation='relu'),
        keras.layers.Dense(32, activation='relu'),
        keras.layers.Dense(1)
    ])
    model.compile(optimizer='adam', loss='mse', metrics=['mae'])
    return model

# Create regressor wrapper
reg = KerasRegressor(
    model=create_regressor,
    epochs=50,
    batch_size=32,
    verbose=0
)

# Use like any scikit-learn regressor
X, y = make_regression(n_samples=1000, n_features=10)
scores = cross_val_score(reg, X, y, cv=3, scoring='neg_mean_squared_error')
print(f"MSE: {-scores.mean():.3f} (+/- {scores.std() * 2:.3f})")
```

## Architecture

SciKeras bridges Keras/TensorFlow and scikit-learn through wrapper classes that implement the scikit-learn estimator interface:

- **BaseWrapper**: Base class providing the core scikit-learn compatibility (`fit`, `predict`, `score`, parameter handling)
- **KerasClassifier**: Classification wrapper with probability prediction support (`predict_proba`)
- **KerasRegressor**: Regression wrapper whose `score` method reports R²
- **Data Transformers**: Automatic preprocessing for different target types
- **Parameter Routing**: Double-underscore notation (for example `optimizer__learning_rate`) for nested configuration
- **Serialization Support**: Pickle/joblib compatibility through custom reducers

This design lets Keras models work with scikit-learn's ecosystem, including `GridSearchCV`, `Pipeline`, cross-validation, and other standard scikit-learn tools.

## Capabilities

### Wrapper Classes

Core wrapper classes that provide scikit-learn compatibility for Keras models, supporting both classification and regression tasks with automatic data preprocessing.

```python { .api }
class BaseWrapper(BaseEstimator):
    def __init__(self, model=None, *, optimizer='rmsprop', loss=None,
                 metrics=None, batch_size=None, validation_batch_size=None,
                 verbose=1, callbacks=None, validation_split=0.0,
                 shuffle=True, run_eagerly=None, epochs=1, **kwargs): ...
    def fit(self, X, y, *, sample_weight=None, **kwargs): ...
    def predict(self, X, **kwargs): ...
    def score(self, X, y, *, sample_weight=None): ...

class KerasClassifier(BaseWrapper, ClassifierMixin):
    def predict_proba(self, X, **kwargs): ...

class KerasRegressor(BaseWrapper, RegressorMixin): ...
```

[Wrapper Classes](./wrappers.md)

### Utility Functions

Helper functions for normalizing Keras loss and metric names to standardized formats compatible with string-based configuration.

```python { .api }
def loss_name(loss):
    """
    Retrieve standardized loss function name.

    Args:
        loss: Union[str, keras.losses.Loss, Callable] - Loss function identifier

    Returns:
        str: Standardized loss name in snake_case
    """

def metric_name(metric):
    """
    Retrieve standardized metric function name.

    Args:
        metric: Union[str, keras.metrics.Metric, Callable] - Metric function identifier

    Returns:
        str: Standardized metric name
    """
```

[Utility Functions](./utils.md)

### Data Transformers

Transformer classes for preprocessing targets and features to ensure compatibility between scikit-learn and Keras data formats.

```python { .api }
class TargetReshaper(BaseEstimator, TransformerMixin):
    def fit(self, y): ...
    def transform(self, y): ...
    def inverse_transform(self, y): ...

class ClassifierLabelEncoder(BaseEstimator, TransformerMixin):
    def fit(self, y): ...
    def transform(self, y): ...
    def inverse_transform(self, y_transformed, return_proba=False): ...

class RegressorTargetEncoder(BaseEstimator, TransformerMixin):
    def fit(self, y): ...
    def transform(self, y): ...
    def inverse_transform(self, y): ...
```

[Data Transformers](./transformers.md)

### Random State Management

Context manager for ensuring reproducible results across Python, NumPy, and TensorFlow random number generators.

```python { .api }
@contextmanager
def tensorflow_random_state(seed):
    """
    Context manager for reproducible random state.

    Args:
        seed (int): Random seed for reproducibility

    Yields:
        None: Context for reproducible operations
    """
```

[Random State Management](./random-state.md)