# Utility Functions

Helper functions for normalizing Keras loss and metric names to standardized formats. These utilities enable consistent string-based configuration and improve compatibility between different naming conventions.

## Capabilities

### Loss Name Normalization

Standardizes loss function names to snake_case regardless of input form (string, class, instance, or function).

```python { .api }
def loss_name(loss):
    """
    Retrieve standardized loss function name in snake_case format.

    Args:
        loss: Union[str, keras.losses.Loss, Callable] - Loss function identifier
            Can be:
            - String shorthand (e.g., "mse", "binary_crossentropy")
            - String class name (e.g., "BinaryCrossentropy")
            - Loss class (e.g., keras.losses.BinaryCrossentropy)
            - Loss instance (e.g., keras.losses.BinaryCrossentropy())
            - Callable loss function

    Returns:
        str: Standardized loss name in snake_case format

    Raises:
        TypeError: If loss is not a valid type
    """
```

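### Metric Name Normalization

Standardizes metric names for consistent identification and configuration. Unlike `loss_name`, it returns class names such as `BinaryAccuracy` unchanged rather than converting them to snake_case.
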
```python { .api }
def metric_name(metric):
    """
    Retrieve standardized metric function name.

    Args:
        metric: Union[str, keras.metrics.Metric, Callable] - Metric function identifier
            Can be:
            - String shorthand (e.g., "acc", "accuracy")
            - String class name (e.g., "BinaryAccuracy")
            - Metric class (e.g., keras.metrics.BinaryAccuracy)
            - Metric instance (e.g., keras.metrics.BinaryAccuracy())
            - Callable metric function

    Returns:
        str: Standardized metric name

    Raises:
        TypeError: If metric is not a valid type
    """
```

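Both functions accept the same kinds of identifier. As a rough illustration of how a name could be pulled out of each kind, here is a minimal sketch in plain Python. `raw_name`, `DummyLoss`, and `dummy_loss_fn` are hypothetical stand-ins, not part of scikeras or Keras, and the sketch does not resolve shorthands such as `"mse"` or apply any case conversion the way the real functions do.

```python
from typing import Any


class DummyLoss:
    """Hypothetical stand-in for a Keras loss class."""

    def __call__(self, y_true, y_pred):
        return 0.0


def dummy_loss_fn(y_true, y_pred):
    """Hypothetical stand-in for a functional loss."""
    return 0.0


def raw_name(identifier: Any) -> str:
    """Return the un-normalized name behind a string, class, function, or instance."""
    if isinstance(identifier, str):
        return identifier
    # Classes and plain functions both expose __name__
    if isinstance(identifier, type) or hasattr(identifier, "__name__"):
        return identifier.__name__
    # Other objects only qualify if they are callable, like loss or metric instances
    if callable(identifier):
        return type(identifier).__name__
    raise TypeError(
        f"expected a string, class, or callable, got {type(identifier).__name__}"
    )


print(raw_name("mse"))
print(raw_name(DummyLoss))
print(raw_name(DummyLoss()))
print(raw_name(dummy_loss_fn))
```

Checking for `__name__` before falling back to the instance's type is one way to treat classes and functions uniformly; anything that is neither a string nor callable ends in a `TypeError`, mirroring the `Raises` sections above.
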
## Usage Examples

### Loss Name Standardization

```python
from scikeras.utils import loss_name
import keras.losses as losses

# String inputs
print(loss_name("mse"))  # Output: "mean_squared_error"
print(loss_name("binary_crossentropy"))  # Output: "binary_crossentropy"
print(loss_name("BinaryCrossentropy"))  # Output: "binary_crossentropy"

# Class inputs
print(loss_name(losses.BinaryCrossentropy))  # Output: "binary_crossentropy"
print(loss_name(losses.MeanSquaredError))  # Output: "mean_squared_error"

# Instance inputs
bce_loss = losses.BinaryCrossentropy()
print(loss_name(bce_loss))  # Output: "binary_crossentropy"

# Function inputs
print(loss_name(losses.binary_crossentropy))  # Output: "binary_crossentropy"
```

### Metric Name Standardization

```python
from scikeras.utils import metric_name
import keras.metrics as metrics

# String inputs
print(metric_name("acc"))  # Output: "accuracy"
print(metric_name("accuracy"))  # Output: "accuracy"
print(metric_name("BinaryAccuracy"))  # Output: "BinaryAccuracy"

# Class inputs
print(metric_name(metrics.BinaryAccuracy))  # Output: "BinaryAccuracy"
print(metric_name(metrics.Precision))  # Output: "Precision"

# Instance inputs
acc_metric = metrics.BinaryAccuracy()
print(metric_name(acc_metric))  # Output: "BinaryAccuracy"

# Function inputs
print(metric_name(metrics.accuracy))  # Output: "accuracy"
```

### Configuration Validation

```python
from scikeras.utils import loss_name, metric_name
import keras.losses as losses
import keras.metrics as metrics

def validate_config(loss, metric_list):
    """Validate and normalize loss and metrics configuration."""
    try:
        normalized_loss = loss_name(loss)
        normalized_metrics = [metric_name(m) for m in metric_list]
        print(f"Loss: {normalized_loss}")
        print(f"Metrics: {normalized_metrics}")
        return True
    except TypeError as e:
        print(f"Configuration error: {e}")
        return False

# Validate different configurations
configs = [
    ("binary_crossentropy", ["accuracy", "precision"]),
    ("BinaryCrossentropy", ["acc", "Precision"]),
    (losses.BinaryCrossentropy(), [metrics.Accuracy(), metrics.Precision()]),
]

# Distinct loop names keep the imported `metrics` module from being shadowed
for loss_cfg, metrics_cfg in configs:
    print(f"\nValidating: loss={loss_cfg}, metrics={metrics_cfg}")
    validate_config(loss_cfg, metrics_cfg)
```

### Dynamic Model Configuration

```python
from scikeras.utils import loss_name, metric_name
import keras

def create_configurable_model(loss_config, metrics_config):
    """Create model with normalized loss and metrics."""
    model = keras.Sequential([
        keras.Input(shape=(10,)),
        keras.layers.Dense(64, activation='relu'),
        keras.layers.Dense(1, activation='sigmoid')
    ])

    # Normalize configurations to plain names.
    # Note: compiling from names discards any arguments set on an instance.
    normalized_loss = loss_name(loss_config)
    normalized_metrics = [metric_name(m) for m in metrics_config]

    model.compile(
        optimizer='adam',
        loss=normalized_loss,
        metrics=normalized_metrics
    )

    return model

# Create models with different configuration formats
model1 = create_configurable_model("bce", ["acc"])
model2 = create_configurable_model("BinaryCrossentropy", ["accuracy", "precision"])
model3 = create_configurable_model(
    keras.losses.BinaryCrossentropy(),
    [keras.metrics.BinaryAccuracy(), keras.metrics.Precision()]
)
```

## Implementation Notes

### CamelCase to snake_case Conversion

`loss_name` automatically converts CamelCase class names to snake_case:

- `BinaryCrossentropy` → `binary_crossentropy`
- `MeanSquaredError` → `mean_squared_error`
- `CategoricalCrossentropy` → `categorical_crossentropy`

`metric_name` does not apply this conversion to class names: as the metric usage examples show, `BinaryAccuracy` is returned unchanged.

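The CamelCase to snake_case mapping listed here can be reproduced with two regular-expression passes. The following is a minimal sketch using only the standard library; `camel_to_snake` is a hypothetical helper written for illustration and is not claimed to be the routine scikeras uses internally.

```python
import re

# Hypothetical helper for illustration; not scikeras's actual implementation.
_WORD_START = re.compile(r"(.)([A-Z][a-z]+)")
_LOWER_UPPER = re.compile(r"([a-z0-9])([A-Z])")


def camel_to_snake(name: str) -> str:
    """Convert a CamelCase identifier to snake_case."""
    # Pass 1: put "_" before a capitalized word that follows any character
    partial = _WORD_START.sub(r"\1_\2", name)
    # Pass 2: put "_" between a lowercase letter or digit and an uppercase letter
    return _LOWER_UPPER.sub(r"\1_\2", partial).lower()


for class_name in ("BinaryCrossentropy", "MeanSquaredError", "CategoricalCrossentropy"):
    print(class_name, "->", camel_to_snake(class_name))
```

Names that are already snake_case contain no uppercase letters, so both passes leave them untouched.
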
### Error Handling

Both functions raise `TypeError` with descriptive messages for invalid inputs:

```python
from scikeras.utils import loss_name

try:
    loss_name(123)  # Invalid type
except TypeError as e:
    print(e)  # "loss must be a string, a function, an instance of keras.losses.Loss..."
```

### Keras Compatibility

The utilities work with both Keras 2.x and 3.x naming conventions and automatically handle version differences in the underlying Keras API.