tessl/pypi-scikeras

Scikit-Learn API wrapper for Keras models enabling seamless integration of deep learning into scikit-learn workflows.

Utility Functions

Helper functions for normalizing Keras loss and metric names to standardized formats. These utilities enable consistent string-based configuration and improve compatibility between different naming conventions.

Capabilities

Loss Name Normalization

Standardizes loss function names to snake_case format regardless of input format (string, class, instance, or callable).

def loss_name(loss):
    """
    Retrieve standardized loss function name in snake_case format.
    
    Args:
        loss: Union[str, keras.losses.Loss, Callable] - Loss function identifier
            Can be:
            - String shorthand (e.g., "mse", "binary_crossentropy")
            - String class name (e.g., "BinaryCrossentropy")
            - Loss class (e.g., keras.losses.BinaryCrossentropy)
            - Loss instance (e.g., keras.losses.BinaryCrossentropy())
            - Callable loss function
    
    Returns:
        str: Standardized loss name in snake_case format
    
    Raises:
        TypeError: If loss is not a string, a callable, or a keras.losses.Loss class or instance
    """

Metric Name Normalization

Standardizes metric function names for consistent identification and configuration.

def metric_name(metric):
    """
    Retrieve standardized metric function name.
    
    Args:
        metric: Union[str, keras.metrics.Metric, Callable] - Metric function identifier
            Can be:
            - String shorthand (e.g., "acc", "accuracy")
            - String class name (e.g., "BinaryAccuracy")
            - Metric class (e.g., keras.metrics.BinaryAccuracy)
            - Metric instance (e.g., keras.metrics.BinaryAccuracy())
            - Callable metric function
    
    Returns:
        str: Standardized metric name
    
    Raises:
        TypeError: If metric is not a string, a callable, or a keras.metrics.Metric class or instance
    """

Usage Examples

Loss Name Standardization

from scikeras.utils import loss_name
import keras.losses as losses

# String inputs
print(loss_name("mse"))                    # Output: "mean_squared_error"
print(loss_name("binary_crossentropy"))   # Output: "binary_crossentropy"
print(loss_name("BinaryCrossentropy"))    # Output: "binary_crossentropy"

# Class inputs
print(loss_name(losses.BinaryCrossentropy))  # Output: "binary_crossentropy"
print(loss_name(losses.MeanSquaredError))    # Output: "mean_squared_error"

# Instance inputs
bce_loss = losses.BinaryCrossentropy()
print(loss_name(bce_loss))                   # Output: "binary_crossentropy"

# Function inputs
print(loss_name(losses.binary_crossentropy)) # Output: "binary_crossentropy"

Metric Name Standardization

from scikeras.utils import metric_name
import keras.metrics as metrics

# String inputs
print(metric_name("acc"))                  # Output: "accuracy"
print(metric_name("accuracy"))             # Output: "accuracy"
print(metric_name("BinaryAccuracy"))       # Output: "BinaryAccuracy"

# Class inputs
print(metric_name(metrics.BinaryAccuracy))   # Output: "BinaryAccuracy"
print(metric_name(metrics.Precision))        # Output: "Precision"

# Instance inputs
acc_metric = metrics.BinaryAccuracy()
print(metric_name(acc_metric))               # Output: "BinaryAccuracy"

# Function inputs
print(metric_name(metrics.accuracy))         # Output: "accuracy"

Configuration Validation

import keras
from scikeras.utils import loss_name, metric_name

def validate_config(loss, metric_specs):
    """Validate and normalize loss and metrics configuration."""
    try:
        normalized_loss = loss_name(loss)
        normalized_metrics = [metric_name(m) for m in metric_specs]
    except (TypeError, ValueError) as e:
        # TypeError: unsupported input type; ValueError: a name Keras does not recognize
        print(f"Configuration error: {e}")
        return False
    print(f"Loss: {normalized_loss}")
    print(f"Metrics: {normalized_metrics}")
    return True

# Validate different configurations
configs = [
    ("binary_crossentropy", ["accuracy", "precision"]),
    ("BinaryCrossentropy", ["acc", "Precision"]),
    (keras.losses.BinaryCrossentropy(), [keras.metrics.Accuracy(), keras.metrics.Precision()]),
]

# Use a distinct loop name so the keras.metrics module is not shadowed
for loss, metric_specs in configs:
    print(f"\nValidating: loss={loss}, metrics={metric_specs}")
    validate_config(loss, metric_specs)

Dynamic Model Configuration

from scikeras.utils import loss_name, metric_name
import keras

def create_configurable_model(loss_config, metrics_config):
    """Create model with normalized loss and metrics."""
    model = keras.Sequential([
        keras.Input(shape=(10,)),  # declare the input shape explicitly
        keras.layers.Dense(64, activation='relu'),
        keras.layers.Dense(1, activation='sigmoid')
    ])
    
    # Normalize configurations
    normalized_loss = loss_name(loss_config)
    normalized_metrics = [metric_name(m) for m in metrics_config]
    
    model.compile(
        optimizer='adam',
        loss=normalized_loss,
        metrics=normalized_metrics
    )
    
    return model

# Create models with different configuration formats
model1 = create_configurable_model("bce", ["acc"])
model2 = create_configurable_model("BinaryCrossentropy", ["accuracy", "precision"])
model3 = create_configurable_model(
    keras.losses.BinaryCrossentropy(),
    [keras.metrics.Accuracy(), keras.metrics.Precision()]
)

Implementation Notes

CamelCase to snake_case Conversion

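loss_name converts CamelCase class names to snake_case (metric_name, by contrast, returns metric class names unchanged, e.g. BinaryAccuracy, as in the examples above):

  • BinaryCrossentropy → binary_crossentropy
  • MeanSquaredError → mean_squared_error
  • CategoricalCrossentropy → categorical_crossentropy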
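The mapping above can be reproduced with a common two-pass regular-expression idiom. The sketch below is an illustration only, not scikeras' own converter, and camel_to_snake is a hypothetical name. How the library treats acronyms such as KL is not stated here, so the kl_divergence result reflects only this sketch's choice.

```python
import re


def camel_to_snake(name: str) -> str:
    """Hypothetical CamelCase -> snake_case converter (illustration only)."""
    # Pass 1: split before a capital that starts a lowercase run
    # ("KLDivergence" -> "KL_Divergence").
    name = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", name)
    # Pass 2: split between a lowercase letter or digit and a capital, catching
    # boundaries pass 1 skipped because it had consumed the preceding letter
    # ("Mean_SquaredError" -> "Mean_Squared_Error").
    name = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", name)
    return name.lower()


for cls_name in ["BinaryCrossentropy", "MeanSquaredError", "CategoricalCrossentropy", "KLDivergence"]:
    print(cls_name, "->", camel_to_snake(cls_name))
```

Names that are already snake_case pass through unchanged, so the function is safe to apply to string inputs of either style.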

Error Handling

Both functions raise TypeError with descriptive messages for invalid inputs:

from scikeras.utils import loss_name

try:
    loss_name(123)  # Invalid type
except TypeError as e:
    print(e)  # "loss must be a string, a function, an instance of keras.losses.Loss..."

Keras Compatibility

The utilities work with both Keras 2.x and 3.x naming conventions and automatically handle version differences in the underlying Keras API.

Install with Tessl CLI

npx tessl i tessl/pypi-scikeras