# Utility Functions

Helper functions for normalizing Keras loss and metric names to standardized formats. These utilities enable consistent string-based configuration and improve compatibility between different naming conventions.

## Capabilities

### Loss Name Normalization

Standardizes loss function names to snake_case regardless of how the loss is specified (string, class, instance, or callable).

```python { .api }
def loss_name(loss):
    """
    Retrieve standardized loss function name in snake_case format.

    Args:
        loss: Union[str, keras.losses.Loss, Callable] - Loss function identifier
            Can be:
            - String shorthand (e.g., "mse", "binary_crossentropy")
            - String class name (e.g., "BinaryCrossentropy")
            - Loss class (e.g., keras.losses.BinaryCrossentropy)
            - Loss instance (e.g., keras.losses.BinaryCrossentropy())
            - Callable loss function

    Returns:
        str: Standardized loss name in snake_case format

    Raises:
        TypeError: If loss is not a valid type
    """
```
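To make the accepted input kinds concrete, the following is a minimal sketch of how a normalizer of this shape could dispatch on its argument. It is not the SciKeras implementation: `FakeLoss`, `BinaryCrossentropy`, `ALIASES`, and `sketch_loss_name` are hypothetical stand-ins defined here so the example runs without Keras installed, and the real function delegates name lookup to Keras.

```python
import inspect
import re

# Hypothetical alias table (assumption); the real lookup is done by Keras.
ALIASES = {"mse": "mean_squared_error"}


class FakeLoss:
    """Stand-in for keras.losses.Loss, for illustration only."""


class BinaryCrossentropy(FakeLoss):
    pass


def _camel_to_snake(name):
    # Insert "_" at each lowercase/digit-to-uppercase boundary, then lowercase.
    return re.sub(r"(?<=[a-z0-9])(?=[A-Z])", "_", name).lower()


def sketch_loss_name(loss):
    """Return a snake_case name for a string, class, instance, or callable."""
    if inspect.isclass(loss):
        loss = loss()  # treat a class like an instance of that class
    if isinstance(loss, str):
        return ALIASES.get(loss, _camel_to_snake(loss))
    if isinstance(loss, FakeLoss):
        return _camel_to_snake(type(loss).__name__)
    if callable(loss):
        return loss.__name__
    raise TypeError(
        "loss must be a string, a callable, or a FakeLoss class or instance"
    )
```

Under these assumptions, `sketch_loss_name("BinaryCrossentropy")`, `sketch_loss_name(BinaryCrossentropy)`, and `sketch_loss_name(BinaryCrossentropy())` all return the same string, while `sketch_loss_name(123)` raises `TypeError`.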

### Metric Name Normalization

Standardizes metric function names for consistent identification and configuration.

```python { .api }
def metric_name(metric):
    """
    Retrieve standardized metric function name.

    Args:
        metric: Union[str, keras.metrics.Metric, Callable] - Metric function identifier
            Can be:
            - String shorthand (e.g., "acc", "accuracy")
            - String class name (e.g., "BinaryAccuracy")
            - Metric class (e.g., keras.metrics.BinaryAccuracy)
            - Metric instance (e.g., keras.metrics.BinaryAccuracy())
            - Callable metric function

    Returns:
        str: Standardized metric name

    Raises:
        TypeError: If metric is not a valid type
    """
```
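One use of consistent identification is detecting duplicate entries in a metrics list once names are normalized. The sketch below assumes a normalizer that maps aliases to one canonical name; `dedupe_by_normalized_name` and `demo_normalize` are hypothetical helpers written for this illustration, and in practice `metric_name` would be passed as the `normalize` argument.

```python
def dedupe_by_normalized_name(identifiers, normalize):
    """Map each normalized name to the first identifier that produced it."""
    seen = {}
    for ident in identifiers:
        # setdefault keeps the first identifier and preserves insertion order.
        seen.setdefault(normalize(ident), ident)
    return seen


def demo_normalize(name):
    # Hypothetical stand-in normalizer (assumption): treats "acc" as an
    # alias of "accuracy" and returns every other string unchanged.
    return {"acc": "accuracy"}.get(name, name)


unique = dedupe_by_normalized_name(["acc", "accuracy", "precision"], demo_normalize)
print(unique)  # {'accuracy': 'acc', 'precision': 'precision'}
```

Passing the normalizer as a parameter keeps the helper independent of any particular Keras version.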

## Usage Examples

### Loss Name Standardization

```python
from scikeras.utils import loss_name
import keras.losses as losses

# String inputs
print(loss_name("mse"))                  # Output: "mean_squared_error"
print(loss_name("binary_crossentropy"))  # Output: "binary_crossentropy"
print(loss_name("BinaryCrossentropy"))   # Output: "binary_crossentropy"

# Class inputs
print(loss_name(losses.BinaryCrossentropy))  # Output: "binary_crossentropy"
print(loss_name(losses.MeanSquaredError))    # Output: "mean_squared_error"

# Instance inputs
bce_loss = losses.BinaryCrossentropy()
print(loss_name(bce_loss))  # Output: "binary_crossentropy"

# Function inputs
print(loss_name(losses.binary_crossentropy))  # Output: "binary_crossentropy"
```

### Metric Name Standardization

```python
from scikeras.utils import metric_name
import keras.metrics as metrics

# String inputs
print(metric_name("acc"))             # Output: "accuracy"
print(metric_name("accuracy"))        # Output: "accuracy"
print(metric_name("BinaryAccuracy"))  # Output: "BinaryAccuracy"

# Class inputs
print(metric_name(metrics.BinaryAccuracy))  # Output: "BinaryAccuracy"
print(metric_name(metrics.Precision))       # Output: "Precision"

# Instance inputs
acc_metric = metrics.BinaryAccuracy()
print(metric_name(acc_metric))  # Output: "BinaryAccuracy"

# Function inputs
print(metric_name(metrics.accuracy))  # Output: "accuracy"
```

### Configuration Validation

```python
from scikeras.utils import loss_name, metric_name
import keras.losses as losses
import keras.metrics as metrics

def validate_config(loss, metric_list):
    """Validate and normalize loss and metrics configuration."""
    try:
        normalized_loss = loss_name(loss)
        normalized_metrics = [metric_name(m) for m in metric_list]
        print(f"Loss: {normalized_loss}")
        print(f"Metrics: {normalized_metrics}")
        return True
    except (TypeError, ValueError) as e:
        # TypeError: unsupported input type.
        # ValueError: Keras may raise this for a string it does not recognize.
        print(f"Configuration error: {e}")
        return False

# Validate different configurations
configs = [
    ("binary_crossentropy", ["accuracy", "precision"]),
    ("BinaryCrossentropy", ["acc", "BinaryPrecision"]),
    (losses.BinaryCrossentropy(), [metrics.Accuracy(), metrics.Precision()]),
]

# Loop names differ from the imported modules so that nothing is shadowed
for loss, metric_list in configs:
    print(f"\nValidating: loss={loss}, metrics={metric_list}")
    validate_config(loss, metric_list)
```

### Dynamic Model Configuration

```python
from scikeras.utils import loss_name, metric_name
import keras

def create_configurable_model(loss_config, metrics_config):
    """Create model with normalized loss and metrics."""
    model = keras.Sequential([
        # An explicit Input layer replaces the older input_dim=10 argument
        keras.Input(shape=(10,)),
        keras.layers.Dense(64, activation='relu'),
        keras.layers.Dense(1, activation='sigmoid')
    ])

    # Normalize configurations
    normalized_loss = loss_name(loss_config)
    normalized_metrics = [metric_name(m) for m in metrics_config]

    model.compile(
        optimizer='adam',
        loss=normalized_loss,
        metrics=normalized_metrics
    )

    return model

# Create models with different configuration formats
model1 = create_configurable_model("bce", ["acc"])
model2 = create_configurable_model("BinaryCrossentropy", ["accuracy", "precision"])
model3 = create_configurable_model(
    keras.losses.BinaryCrossentropy(),
    [keras.metrics.Accuracy(), keras.metrics.Precision()]
)
```

## Implementation Notes

### CamelCase to snake_case Conversion

The utilities automatically convert CamelCase class names to snake_case:

- `BinaryCrossentropy` → `binary_crossentropy`
- `MeanSquaredError` → `mean_squared_error`
- `CategoricalCrossentropy` → `categorical_crossentropy`
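A conversion like the one listed above can be sketched with two regular-expression passes. This is a common general-purpose recipe shown for illustration, not necessarily the code SciKeras uses; the function name `camel_to_snake` is an assumption.

```python
import re


def camel_to_snake(name: str) -> str:
    """Convert a CamelCase identifier to snake_case."""
    # Pass 1: put "_" before a capitalized word that follows another character.
    partial = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", name)
    # Pass 2: split any remaining lowercase/digit-to-uppercase boundary.
    return re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", partial).lower()


print(camel_to_snake("BinaryCrossentropy"))       # binary_crossentropy
print(camel_to_snake("MeanSquaredError"))         # mean_squared_error
print(camel_to_snake("CategoricalCrossentropy"))  # categorical_crossentropy
```

Strings that are already snake_case contain no uppercase letters, so they pass through unchanged.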

### Error Handling

Both functions raise `TypeError` with descriptive messages for invalid inputs:

```python
from scikeras.utils import loss_name

try:
    loss_name(123)  # Invalid type
except TypeError as e:
    print(e)  # "loss must be a string, a function, an instance of keras.losses.Loss..."
```

### Keras Compatibility

The utilities work with both Keras 2.x and 3.x naming conventions and automatically handle version differences in the underlying Keras API.