Scikit-Learn API wrapper for Keras models enabling seamless integration of deep learning into scikit-learn workflows.
---
Context manager for ensuring reproducible results across Python, NumPy, and TensorFlow random number generators. This utility enables deterministic training and evaluation for scientific reproducibility and debugging.
Context manager that sets and restores random state across all major random number generators used in machine learning workflows.
@contextmanager
def tensorflow_random_state(seed):
    """
    Context manager for reproducible random state across all generators.

    Args:
        seed (int): Random seed value for reproducibility

    Yields:
        None: Context for reproducible operations

    Note:
        This context manager:
        - Sets Python's random module seed
        - Sets NumPy's random seed
        - Sets TensorFlow's random seed
        - Enables TensorFlow's deterministic operations
        - Restores all original states when exiting
    """
    # NOTE(review): the implementation body is not shown in this documentation
    # excerpt -- see scikeras.utils.random_state for the actual source.

from scikeras.utils.random_state import tensorflow_random_state
from scikeras.wrappers import KerasClassifier
import keras
import numpy as np


def create_model():
    """Build a small binary classifier for the reproducibility demo."""
    model = keras.Sequential([
        keras.layers.Dense(50, activation='relu', input_dim=10),
        keras.layers.Dense(1, activation='sigmoid')
    ])
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model


# Generate sample data
X = np.random.random((100, 10))
y = np.random.randint(0, 2, 100)

# Train with reproducible results
with tensorflow_random_state(42):
    clf = KerasClassifier(model=create_model, epochs=10, verbose=0)
    clf.fit(X, y)
    predictions_1 = clf.predict(X)

# Train again with same seed - should get identical results
with tensorflow_random_state(42):
    clf2 = KerasClassifier(model=create_model, epochs=10, verbose=0)
    clf2.fit(X, y)
    predictions_2 = clf2.predict(X)

# Verify reproducibility
print(f"Predictions match: {np.array_equal(predictions_1, predictions_2)}")
from scikeras.utils.random_state import tensorflow_random_state
from scikeras.wrappers import KerasRegressor
from sklearn.model_selection import cross_val_score
import keras
import numpy as np


def create_regressor():
    """Build a small regression network for the cross-validation demo."""
    model = keras.Sequential([
        keras.layers.Dense(64, activation='relu', input_dim=5),
        keras.layers.Dense(1)
    ])
    model.compile(optimizer='adam', loss='mse')
    return model


# Generate sample data
X = np.random.random((200, 5))
y = np.random.random(200)

# Reproducible cross-validation
with tensorflow_random_state(123):
    reg = KerasRegressor(model=create_regressor, epochs=20, verbose=0)
    scores_1 = cross_val_score(reg, X, y, cv=5, scoring='neg_mean_squared_error')

# Repeat with same seed
with tensorflow_random_state(123):
    reg2 = KerasRegressor(model=create_regressor, epochs=20, verbose=0)
    scores_2 = cross_val_score(reg2, X, y, cv=5, scoring='neg_mean_squared_error')

print(f"CV scores match: {np.allclose(scores_1, scores_2)}")
print(f"First run: {scores_1}")
print(f"Second run: {scores_2}")
from scikeras.utils.random_state import tensorflow_random_state
from scikeras.wrappers import KerasClassifier
from sklearn.model_selection import GridSearchCV
import keras
import numpy as np


def create_model(units=50, dropout_rate=0.2):
    """Build a tunable binary classifier for the grid-search demo."""
    model = keras.Sequential([
        keras.layers.Dense(units, activation='relu', input_dim=8),
        keras.layers.Dropout(dropout_rate),
        keras.layers.Dense(1, activation='sigmoid')
    ])
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model


# Generate sample data
X = np.random.random((300, 8))
y = np.random.randint(0, 2, 300)

# Reproducible grid search
param_grid = {
    'model__units': [25, 50],
    'model__dropout_rate': [0.1, 0.3],
    'epochs': [5, 10]
}

with tensorflow_random_state(456):
    clf = KerasClassifier(model=create_model, verbose=0)
    grid_search = GridSearchCV(clf, param_grid, cv=3, scoring='accuracy')
    grid_search.fit(X, y)
    best_score_1 = grid_search.best_score_
    best_params_1 = grid_search.best_params_

# Repeat search with same seed
with tensorflow_random_state(456):
    clf2 = KerasClassifier(model=create_model, verbose=0)
    grid_search2 = GridSearchCV(clf2, param_grid, cv=3, scoring='accuracy')
    grid_search2.fit(X, y)
    best_score_2 = grid_search2.best_score_
    best_params_2 = grid_search2.best_params_

print(f"Best scores match: {best_score_1 == best_score_2}")
print(f"Best params match: {best_params_1 == best_params_2}")
print(f"Best parameters: {best_params_1}")
print(f"Best score: {best_score_1:.4f}")
from scikeras.utils.random_state import tensorflow_random_state
from scikeras.wrappers import KerasClassifier
import keras
import numpy as np


def create_unstable_model():
    """Model that might have training instability."""
    model = keras.Sequential([
        keras.layers.Dense(100, activation='relu', input_dim=20),
        keras.layers.Dense(100, activation='relu'),
        keras.layers.Dense(1, activation='sigmoid')
    ])
    # High learning rate might cause instability
    model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.1),
                  loss='binary_crossentropy', metrics=['accuracy'])
    return model


# Generate challenging dataset
X = np.random.random((500, 20))
y = (X.sum(axis=1) > 10).astype(int)

# Reproduce training behavior for debugging
debug_seed = 789
with tensorflow_random_state(debug_seed):
    clf = KerasClassifier(model=create_unstable_model, epochs=50, verbose=1)
    try:
        clf.fit(X, y)
        final_score = clf.score(X, y)
        print(f"Training completed. Final score: {final_score:.4f}")
    except Exception as e:
        print(f"Training failed with error: {e}")

# Reproduce exact same behavior for investigation
print("\nReproducing the same training run...")
with tensorflow_random_state(debug_seed):
    clf2 = KerasClassifier(model=create_unstable_model, epochs=50, verbose=1)
    try:
        clf2.fit(X, y)
        final_score2 = clf2.score(X, y)
        print(f"Training completed. Final score: {final_score2:.4f}")
    except Exception as e:
        print(f"Training failed with error: {e}")
from scikeras.utils.random_state import tensorflow_random_state
from scikeras.wrappers import KerasClassifier
import keras
import numpy as np
def create_model():
model = keras.Sequential([
keras.layers.Dense(30, activation='relu', input_dim=10),
keras.layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
return model
# Generate sample data
X = np.random.random((400, 10))
y = np.random.randint(0, 2, 400)
# Train ensemble with different seeds but reproducible results
ensemble_models = []
ensemble_seeds = [100, 200, 300, 400, 500]
for i, seed in enumerate(ensemble_seeds):
print(f"Training ensemble member {i+1} with seed {seed}")
with tensorflow_random_state(seed):
clf = KerasClassifier(model=create_model, epochs=20, verbose=0)
clf.fit(X, y)
ensemble_models.append(clf)
# Make ensemble predictions
ensemble_predictions = []
for model in ensemble_models:
with tensorflow_random_state(42): # Same seed for prediction consistency
pred = model.predict_proba(X[:10])[:, 1] # Get positive class probabilities
ensemble_predictions.append(pred)
# Average predictions
ensemble_avg = np.mean(ensemble_predictions, axis=0)
print(f"\\nEnsemble predictions for first 10 samples:")
print(f"Individual model predictions:")
for i, pred in enumerate(ensemble_predictions):
print(f"Model {i+1}: {pred}")
print(f"Ensemble average: {ensemble_avg}")The context manager preserves and restores:
- `random.getstate()` / `random.setstate()` (Python's built-in generator)
- `np.random.get_state()` / `np.random.set_state()` (NumPy's global generator)
- `tf.random.set_seed()` (TensorFlow's global seed)
- `tf.config.experimental.enable_op_determinism()` (deterministic TF ops)
- the `TF_DETERMINISTIC_OPS` environment variable

The function handles both TensorFlow 2.x installations and environments where TensorFlow is not available:
# When TensorFlow is available
with tensorflow_random_state(42):
# Full deterministic behavior
pass
# When TensorFlow is not installed
with tensorflow_random_state(42):
# Still sets Python and NumPy seeds
# TensorFlow operations are no-ops
passEnabling deterministic operations may impact performance:
The context manager is not thread-safe. Use separate seeds for concurrent training:
import threading
from concurrent.futures import ThreadPoolExecutor
def train_with_seed(seed):
with tensorflow_random_state(seed):
# Training code here
pass
# Use different seeds for parallel training
with ThreadPoolExecutor() as executor:
futures = [executor.submit(train_with_seed, seed)
for seed in [100, 200, 300]]Install with Tessl CLI
npx tessl i tessl/pypi-scikeras