
# Wrapper Classes

Core wrapper classes that provide scikit-learn compatibility for Keras models. These classes implement the scikit-learn estimator interface, so Keras models can be used with scikit-learn tools such as `GridSearchCV`, `Pipeline`, and the cross-validation utilities.

## Capabilities

### BaseWrapper

Abstract base class that implements the core scikit-learn estimator API for Keras models. It provides the functionality shared by the classification and regression wrappers.

```python { .api }
class BaseWrapper(BaseEstimator):
    def __init__(
        self,
        model=None,
        *,
        build_fn=None,
        warm_start=False,
        random_state=None,
        optimizer='rmsprop',
        loss=None,
        metrics=None,
        batch_size=None,
        validation_batch_size=None,
        verbose=1,
        callbacks=None,
        validation_split=0.0,
        validation_steps=None,
        validation_freq=1,
        shuffle=True,
        run_eagerly=None,
        epochs=1,
        initial_epoch=0,
        **kwargs
    ):
        """
        Initialize BaseWrapper.

        Args:
            model: Union[None, Callable[..., keras.Model], keras.Model] - Keras model, or a callable that returns a compiled model
            build_fn: Union[None, Callable[..., keras.Model], keras.Model] - Deprecated alias for the model parameter
            warm_start: bool - Whether to preserve model parameters between calls to fit (the epoch counter is still reset)
            random_state: Union[int, np.random.RandomState, None] - Random seed for reproducibility
            optimizer: Union[str, keras.optimizers.Optimizer, Type[keras.optimizers.Optimizer]] - Optimizer used for training
            loss: Union[str, keras.losses.Loss, Type[keras.losses.Loss], Callable, None] - Loss function
            metrics: List of metrics to monitor during training
            batch_size: Union[int, None] - Number of samples per gradient update
            validation_batch_size: Union[int, None] - Batch size for validation
            verbose: int - Verbosity level (0=silent, 1=progress bar, 2=one line per epoch)
            callbacks: List of Keras callbacks
            validation_split: float - Fraction of the training data to use for validation
            validation_steps: Union[int, None] - Number of steps to draw from the validation generator
            validation_freq: int - Run validation only every N epochs
            shuffle: bool - Whether to shuffle the training data
            run_eagerly: Union[bool, None] - Whether to run in eager mode
            epochs: int - Number of training epochs
            initial_epoch: int - Epoch at which to start training
            **kwargs: Additional parameters, such as routed model__* arguments for the model building function
        """

    def fit(self, X, y, *, sample_weight=None, **kwargs):
        """
        Train the Keras model.

        Args:
            X: array-like of shape (n_samples, n_features) - Training data
            y: array-like of shape (n_samples,) or (n_samples, n_outputs) - Target values
            sample_weight: array-like of shape (n_samples,), optional - Sample weights
            **kwargs: Additional arguments passed to model.fit()

        Returns:
            self: Fitted estimator
        """

    def partial_fit(self, X, y, *, sample_weight=None, **kwargs):
        """
        Train the model for a single epoch.

        Args:
            X: array-like of shape (n_samples, n_features) - Training data
            y: array-like of shape (n_samples,) or (n_samples, n_outputs) - Target values
            sample_weight: array-like of shape (n_samples,), optional - Sample weights
            **kwargs: Additional arguments passed to model.fit()

        Returns:
            self: Fitted estimator
        """

    def predict(self, X, **kwargs):
        """
        Make predictions using the trained model.

        Args:
            X: array-like of shape (n_samples, n_features) - Input data
            **kwargs: Additional arguments passed to model.predict()

        Returns:
            array-like: Predictions
        """

    def score(self, X, y, *, sample_weight=None):
        """
        Return the score of the model on the given test data.

        Args:
            X: array-like of shape (n_samples, n_features) - Test data
            y: array-like of shape (n_samples,) or (n_samples, n_outputs) - True values
            sample_weight: array-like of shape (n_samples,), optional - Sample weights

        Returns:
            float: Model score
        """

    def initialize(self, X, y=None):
        """
        Initialize the model without training it.

        Args:
            X: array-like of shape (n_samples, n_features) - Sample data used for initialization
            y: array-like, optional - Sample targets used for initialization

        Returns:
            self: Initialized estimator
        """

    @property
    def current_epoch(self):
        """Get the current training epoch."""

    @property
    def initialized_(self):
        """Whether the model has been initialized."""

    @property
    def target_encoder(self):
        """Get the target transformation pipeline."""

    @property
    def feature_encoder(self):
        """Get the feature transformation pipeline."""

    @property
    def model_(self):
        """Get the instantiated and compiled Keras Model."""

    @property
    def history_(self):
        """Get the training history dictionary."""

    @property
    def n_outputs_expected_(self):
        """Get the expected number of outputs."""

    @property
    def target_type_(self):
        """Get the target type string."""

    @property
    def classes_(self):
        """Get the class labels (classification only)."""

    @property
    def n_classes_(self):
        """Get the number of classes (classification only)."""

    @property
    def X_shape_(self):
        """Get the input data shape seen during fit."""

    @property
    def y_shape_(self):
        """Get the target data shape seen during fit."""

    @property
    def X_dtype_(self):
        """Get the input data dtype seen during fit."""

    @property
    def y_dtype_(self):
        """Get the target data dtype seen during fit."""

    @property
    def n_features_in_(self):
        """Get the number of features seen during fit."""
```
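
After `fit`, `history_` maps each metric name (for example `loss` or `val_loss`) to a list holding one value per epoch, in the same layout as Keras's `History.history`. The sketch below shows one way to read it; the loss values are hypothetical placeholders rather than output from a real training run, and `best_epoch` is an illustrative helper, not part of SciKeras.

```python
# Hypothetical history_ contents: metric name -> one value per epoch
history = {
    "loss": [0.69, 0.52, 0.41, 0.37, 0.36],
    "val_loss": [0.70, 0.55, 0.47, 0.49, 0.51],
}


def best_epoch(history: dict, monitor: str = "val_loss") -> int:
    """Return the 0-based index of the epoch with the lowest monitored value."""
    values = history[monitor]
    return min(range(len(values)), key=values.__getitem__)


epoch = best_epoch(history)
print(f"best epoch: {epoch}, val_loss: {history['val_loss'][epoch]}")
# prints "best epoch: 2, val_loss: 0.47"
```

With a fitted wrapper, the same helper could be called as `best_epoch(reg.history_)`.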

### KerasClassifier


Scikit-learn compatible classifier wrapper for Keras models. Supports binary and multiclass classification with probability predictions.
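
Conceptually, a predicted label is the class with the highest probability in the corresponding `predict_proba` row, looked up in `classes_`. The plain-Python sketch below shows that mapping with made-up probabilities. It illustrates the idea only; SciKeras itself decodes model outputs through its target encoder, and `labels_from_proba` is a hypothetical helper.

```python
def labels_from_proba(proba, classes):
    """Map each row of class probabilities to the label with the highest probability."""
    return [classes[max(range(len(row)), key=row.__getitem__)] for row in proba]


classes_ = ["bird", "cat", "dog"]  # hypothetical fitted classes_ (sorted, as in scikit-learn)
proba = [
    [0.1, 0.7, 0.2],
    [0.6, 0.3, 0.1],
]  # shape (n_samples, n_classes)
print(labels_from_proba(proba, classes_))  # ['cat', 'bird']
```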

```python { .api }
class KerasClassifier(BaseWrapper, ClassifierMixin):
    def __init__(self, class_weight=None, **kwargs):
        """
        Initialize KerasClassifier.

        Args:
            class_weight: dict or 'balanced', optional - Weights used to balance classes
            **kwargs: All arguments from BaseWrapper
        """

    def fit(self, X, y, *, sample_weight=None, **kwargs):
        """
        Train the classifier.

        Args:
            X: array-like of shape (n_samples, n_features) - Training data
            y: array-like of shape (n_samples,) - Target class labels
            sample_weight: array-like of shape (n_samples,), optional - Sample weights
            **kwargs: Additional arguments passed to model.fit()

        Returns:
            self: Fitted classifier
        """

    def partial_fit(self, X, y, *, classes=None, sample_weight=None, **kwargs):
        """
        Train the classifier for a single epoch.

        Args:
            X: array-like of shape (n_samples, n_features) - Training data
            y: array-like of shape (n_samples,) - Target class labels
            classes: array-like of shape (n_classes,), optional - List of all possible classes
            sample_weight: array-like of shape (n_samples,), optional - Sample weights
            **kwargs: Additional arguments passed to model.fit()

        Returns:
            self: Fitted classifier
        """

    def predict_proba(self, X, **kwargs):
        """
        Predict class probabilities.

        Args:
            X: array-like of shape (n_samples, n_features) - Input data
            **kwargs: Additional arguments passed to model.predict()

        Returns:
            array-like of shape (n_samples, n_classes): Class probabilities
        """

    @property
    def classes_(self):
        """Get the class labels."""

    @property
    def n_classes_(self):
        """Get the number of classes."""
```
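
The `'balanced'` option for `class_weight` follows scikit-learn's convention (see `sklearn.utils.class_weight.compute_class_weight`): each class receives the weight `n_samples / (n_classes * count(class))`, so rarer classes count more. The sketch below computes that formula in plain Python, assuming SciKeras delegates `'balanced'` to scikit-learn's definition; `balanced_class_weights` is a hypothetical helper.

```python
from collections import Counter


def balanced_class_weights(y):
    """scikit-learn style 'balanced' weights: n_samples / (n_classes * count)."""
    counts = Counter(y)
    n_samples, n_classes = len(y), len(counts)
    return {cls: n_samples / (n_classes * n) for cls, n in sorted(counts.items())}


y = [0, 0, 0, 0, 0, 0, 1, 1]  # imbalanced: six 0s, two 1s
print(balanced_class_weights(y))  # {0: 0.6666666666666666, 1: 2.0}
```

Each class contributes the same total weight (6 × 2/3 = 2 × 2.0 = 4), which is the point of the scheme.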


### KerasRegressor


Scikit-learn compatible regressor wrapper for Keras models. Uses R² score as the default scoring metric.

```python { .api }
class KerasRegressor(BaseWrapper, RegressorMixin):
    def __init__(self, **kwargs):
        """
        Initialize KerasRegressor.

        Args:
            **kwargs: All arguments from BaseWrapper
        """
```
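
The R² score used by `KerasRegressor.score` is the coefficient of determination, `R² = 1 - SS_res / SS_tot`, where `SS_res = Σ(y_i - ŷ_i)²` and `SS_tot = Σ(y_i - ȳ)²`. A plain-Python version of the formula, shown for illustration only (scikit-learn's `r2_score` is the reference implementation):

```python
def r2(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1 - ss_res / ss_tot


y_true = [3.0, -0.5, 2.0, 7.0]
y_pred = [2.5, 0.0, 2.0, 8.0]
print(round(r2(y_true, y_pred), 4))  # 0.9486
```

A perfect prediction scores 1.0, and always predicting the mean of `y_true` scores 0.0; worse models can go negative.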

## Usage Examples


### Basic Classification with Grid Search

```python
import keras
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from scikeras.wrappers import KerasClassifier

# Synthetic binary classification data with 10 features
X_train, y_train = make_classification(n_samples=300, n_features=10, random_state=0)

def create_model(units=50, optimizer='adam'):
    model = keras.Sequential([
        keras.Input(shape=(10,)),
        keras.layers.Dense(units, activation='relu'),
        keras.layers.Dense(1, activation='sigmoid')
    ])
    model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])
    return model

# Create the classifier; model__* parameters are routed to create_model
clf = KerasClassifier(
    model=create_model,
    epochs=10,
    batch_size=32,
    verbose=0
)

# Use with GridSearchCV
param_grid = {
    'model__units': [25, 50, 100],
    'model__optimizer': ['adam', 'sgd'],
    'epochs': [5, 10, 15]
}

grid = GridSearchCV(clf, param_grid, cv=3, scoring='accuracy')
grid.fit(X_train, y_train)
print(grid.best_params_)
```
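
`GridSearchCV` trains one model per parameter combination per fold. For the grid above that is 3 × 2 × 3 = 18 combinations and, with `cv=3`, 54 fits, plus one final refit on the full training data because `refit=True` by default. The combinations can be enumerated with `itertools.product`, similar to what `sklearn.model_selection.ParameterGrid` produces (scikit-learn may order them differently):

```python
from itertools import product

param_grid = {
    'model__units': [25, 50, 100],
    'model__optimizer': ['adam', 'sgd'],
    'epochs': [5, 10, 15],
}

# Every combination of values, one dict per candidate
keys = list(param_grid)
candidates = [dict(zip(keys, values)) for values in product(*param_grid.values())]

cv = 3
print(len(candidates))       # 18
print(len(candidates) * cv)  # 54 model fits during the search
```

Because every fit retrains a Keras model, adding values to the grid multiplies training time quickly.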

### Warm Start Training

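```python
import keras
from sklearn.datasets import make_regression
from scikeras.wrappers import KerasRegressor

X_train, y_train = make_regression(n_samples=200, n_features=10, random_state=0)

def create_reg_model():
    # Regression needs a linear output unit and a regression loss
    model = keras.Sequential([
        keras.Input(shape=(10,)),
        keras.layers.Dense(32, activation='relu'),
        keras.layers.Dense(1)
    ])
    model.compile(optimizer='adam', loss='mse')
    return model

# Enable warm start to preserve model weights between fit calls
reg = KerasRegressor(
    model=create_reg_model,
    epochs=10,
    warm_start=True,
    verbose=0
)

# Initial training
reg.fit(X_train, y_train)

# Train for 5 more epochs starting from the learned weights
# (warm_start keeps the weights, but the epoch counter restarts at 0)
reg.set_params(epochs=5)
reg.fit(X_train, y_train)
```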

### Custom Callbacks

```python
import keras
from scikeras.wrappers import KerasClassifier

# Assumes create_model, X_train and y_train are defined as in the grid search example

# Define custom callbacks
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=5, restore_best_weights=True
)

# min_lr must be below the starting learning rate (Keras's Adam and RMSprop
# default to 0.001); otherwise the learning rate can never be reduced
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss', factor=0.2, patience=3, min_lr=1e-5
)

clf = KerasClassifier(
    model=create_model,
    epochs=100,
    validation_split=0.2,
    callbacks=[early_stopping, reduce_lr]
)

clf.fit(X_train, y_train)
```
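
`ReduceLROnPlateau` multiplies the learning rate by `factor` each time the monitored metric stalls for `patience` epochs, but never lets it fall below `min_lr`. The sketch below reproduces only that arithmetic (not the plateau detection) to show why `min_lr` has to sit below the starting learning rate; `lr_after_reductions` is a hypothetical helper.

```python
def lr_after_reductions(initial_lr, factor, min_lr, n_reductions):
    """Learning rate after n plateau-triggered reductions, floored at min_lr."""
    lr = initial_lr
    for _ in range(n_reductions):
        lr = max(lr * factor, min_lr)
    return lr


# Starting from 0.001 with factor=0.2:
print(lr_after_reductions(0.001, 0.2, 1e-5, 2))   # about 4e-05
print(lr_after_reductions(0.001, 0.2, 0.001, 2))  # 0.001: min_lr blocks every reduction
```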

### Parameter Routing

SciKeras routes parameters to nested components using scikit-learn's double-underscore (`prefix__name`) convention. This gives fine-grained control over model creation, compilation, training, and prediction, and lets tools such as `GridSearchCV` tune any of these parameters.

#### Routing Targets

Parameters can be routed to different destinations:

- `model__*`: Parameters passed to the model building function
- `compile__*`: Parameters passed to `model.compile()`
- `fit__*`: Parameters passed to `model.fit()`
- `predict__*`: Parameters passed to `model.predict()`
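
The prefix convention can be illustrated with a small, hypothetical `route` helper that selects the parameters for one destination and strips the prefix. This is a simplified sketch of the idea, not SciKeras's own routing code, which also handles unprefixed defaults and precedence between them.

```python
def route(params: dict, destination: str) -> dict:
    """Hypothetical helper: keep `destination__*` keys and strip the prefix."""
    prefix = destination + "__"
    return {key[len(prefix):]: value for key, value in params.items() if key.startswith(prefix)}


params = {
    "model__units": 100,
    "fit__validation_split": 0.2,
    "predict__batch_size": 256,
    "epochs": 50,
}
print(route(params, "model"))    # {'units': 100}
print(route(params, "fit"))      # {'validation_split': 0.2}
print(route(params, "compile"))  # {}
```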


#### Examples

```python
import keras
from scikeras.wrappers import KerasClassifier

early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)

# Route parameters to the model building function, compile() and fit()
clf = KerasClassifier(model=create_model)
clf.set_params(
    model__units=100,                     # Passed to create_model(units=100)
    model__dropout_rate=0.2,              # Passed to create_model(dropout_rate=0.2); create_model must accept it
    compile__optimizer='adam',            # Passed to model.compile(optimizer='adam')
    compile__loss='binary_crossentropy',  # Passed to model.compile(loss=...)
    fit__validation_split=0.2,            # Passed to model.fit(validation_split=0.2)
    fit__callbacks=[early_stop],          # Passed to model.fit(callbacks=...)
    epochs=50                             # Direct parameter of the wrapper
)
```

#### Nested Routing

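Parameters can also be routed to objects that SciKeras builds itself. For example, `optimizer__*` parameters are passed to the constructor of the optimizer named by `optimizer`:

```python
# SciKeras resolves 'adam' to an optimizer class and constructs it as
# Adam(learning_rate=0.001, beta_1=0.9). These settings are only used when
# SciKeras builds the optimizer, e.g. when the model building function
# returns an uncompiled model.
clf.set_params(
    optimizer='adam',
    optimizer__learning_rate=0.001,  # optimizer constructor argument
    optimizer__beta_1=0.9,           # optimizer constructor argument
)
```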
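
Conceptually, nested routing splits keys on the double underscore once more: a key such as `optimizer__learning_rate` names the `learning_rate` argument of whichever optimizer is being built. The hypothetical sketch below mimics that resolution with a stand-in optimizer class so it runs without Keras; `StandInAdam`, `REGISTRY`, and `build_nested` are illustrative names, not SciKeras APIs.

```python
class StandInAdam:
    """Hypothetical stand-in for an optimizer class such as keras.optimizers.Adam."""

    def __init__(self, learning_rate=0.001, beta_1=0.9):
        self.learning_rate = learning_rate
        self.beta_1 = beta_1


REGISTRY = {"adam": StandInAdam}  # hypothetical name -> class lookup


def build_nested(params: dict, name: str):
    """Resolve params[name] to a class and pass params[f'{name}__*'] to its constructor."""
    prefix = name + "__"
    kwargs = {k[len(prefix):]: v for k, v in params.items() if k.startswith(prefix)}
    cls = params[name]
    if isinstance(cls, str):
        cls = REGISTRY[cls]
    return cls(**kwargs)


opt = build_nested(
    {"optimizer": "adam", "optimizer__learning_rate": 0.01, "optimizer__beta_1": 0.95},
    "optimizer",
)
print(opt.learning_rate, opt.beta_1)  # 0.01 0.95
```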


## Types

```python { .api }
from typing import Callable, List, Type, Union

import keras

# Model building function signature
ModelBuildingFunction = Callable[..., keras.Model]

# Supported parameter types
OptimizerType = Union[str, keras.optimizers.Optimizer, Type[keras.optimizers.Optimizer]]
LossType = Union[str, keras.losses.Loss, Type[keras.losses.Loss], Callable, None]
MetricsType = Union[List[Union[str, keras.metrics.Metric]], None]
CallbacksType = Union[List[keras.callbacks.Callback], None]
```