Multi-backend deep learning framework providing a unified API for building and training neural networks across JAX, TensorFlow, PyTorch, and OpenVINO backends
—
Comprehensive collection of loss functions for training neural networks and metrics for evaluation, covering classification, regression, and specialized tasks with both class-based and function-based APIs.
Loss functions designed for classification tasks including binary, multiclass, and specialized classification scenarios.
class BinaryCrossentropy:
"""
Binary cross-entropy loss for binary classification.
Args:
from_logits (bool): Whether input is logits or probabilities
label_smoothing (float): Label smoothing factor
axis (int): Axis along which to compute loss
reduction (str): Type of reduction to apply
name (str): Name of the loss
"""
def __init__(self, from_logits=False, label_smoothing=0.0, axis=-1, **kwargs): ...
class CategoricalCrossentropy:
"""
Categorical cross-entropy loss for multiclass classification.
Args:
from_logits (bool): Whether input is logits or probabilities
label_smoothing (float): Label smoothing factor
axis (int): Axis along which to compute loss
"""
def __init__(self, from_logits=False, label_smoothing=0.0, axis=-1, **kwargs): ...
class SparseCategoricalCrossentropy:
"""
Sparse categorical cross-entropy for integer labels.
Args:
from_logits (bool): Whether input is logits or probabilities
ignore_class (int, optional): Class index to ignore
axis (int): Axis along which to compute loss
"""
def __init__(self, from_logits=False, ignore_class=None, axis=-1, **kwargs): ...
class BinaryFocalCrossentropy:
"""
Binary focal loss for addressing class imbalance.
Args:
alpha (float): Weighting factor for rare class
gamma (float): Focusing parameter
from_logits (bool): Whether input is logits or probabilities
label_smoothing (float): Label smoothing factor
"""
def __init__(self, alpha=0.25, gamma=2.0, from_logits=False, label_smoothing=0.0, **kwargs): ...
class CategoricalFocalCrossentropy:
"""
Categorical focal loss for multiclass imbalanced datasets.
Args:
alpha (float): Weighting factor
gamma (float): Focusing parameter
from_logits (bool): Whether input is logits or probabilities
"""
def __init__(self, alpha=0.25, gamma=2.0, from_logits=False, **kwargs): ...
# Function equivalents
def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0.0, axis=-1): ...
def categorical_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0.0, axis=-1): ...
def sparse_categorical_crossentropy(y_true, y_pred, from_logits=False, ignore_class=None, axis=-1): ...

Loss functions for continuous value prediction tasks with various robustness properties.
class MeanSquaredError:
"""
Mean squared error loss for regression.
Args:
reduction (str): Type of reduction to apply
name (str): Name of the loss
"""
def __init__(self, **kwargs): ...
class MeanAbsoluteError:
"""
Mean absolute error loss for regression.
Args:
reduction (str): Type of reduction to apply
name (str): Name of the loss
"""
def __init__(self, **kwargs): ...
class MeanAbsolutePercentageError:
"""
Mean absolute percentage error for regression.
Args:
reduction (str): Type of reduction to apply
name (str): Name of the loss
"""
def __init__(self, **kwargs): ...
class MeanSquaredLogarithmicError:
"""
Mean squared logarithmic error for regression.
Args:
reduction (str): Type of reduction to apply
name (str): Name of the loss
"""
def __init__(self, **kwargs): ...
class Huber:
"""
Huber loss for robust regression.
Args:
delta (float): Point where loss changes from quadratic to linear
reduction (str): Type of reduction to apply
name (str): Name of the loss
"""
def __init__(self, delta=1.0, **kwargs): ...
class LogCosh:
"""
Log-cosh loss for regression.
Args:
reduction (str): Type of reduction to apply
name (str): Name of the loss
"""
def __init__(self, **kwargs): ...
# Function equivalents
def mean_squared_error(y_true, y_pred): ...
def mean_absolute_error(y_true, y_pred): ...
def mean_absolute_percentage_error(y_true, y_pred): ...
def huber(y_true, y_pred, delta=1.0): ...

Loss functions for specific tasks including ranking, sequence modeling, and segmentation.
class Hinge:
"""
Hinge loss for maximum-margin classification.
Args:
reduction (str): Type of reduction to apply
name (str): Name of the loss
"""
def __init__(self, **kwargs): ...
class SquaredHinge:
"""Squared hinge loss for maximum-margin classification."""
def __init__(self, **kwargs): ...
class CategoricalHinge:
"""Categorical hinge loss for multiclass classification."""
def __init__(self, **kwargs): ...
class KLDivergence:
"""
Kullback-Leibler divergence loss.
Args:
reduction (str): Type of reduction to apply
name (str): Name of the loss
"""
def __init__(self, **kwargs): ...
class Poisson:
"""
Poisson loss for count data.
Args:
reduction (str): Type of reduction to apply
name (str): Name of the loss
"""
def __init__(self, **kwargs): ...
class CosineSimilarity:
"""
Cosine similarity loss.
Args:
axis (int): Axis along which to compute cosine similarity
reduction (str): Type of reduction to apply
name (str): Name of the loss
"""
def __init__(self, axis=-1, **kwargs): ...
class Dice:
"""
Dice loss for segmentation tasks.
Args:
axis (int or tuple, optional): Axis to compute dice over
reduction (str): Type of reduction to apply
name (str): Name of the loss
"""
def __init__(self, axis=None, **kwargs): ...
class Tversky:
"""
Tversky loss for segmentation with adjustable precision/recall balance.
Args:
alpha (float): Weight for false positives
beta (float): Weight for false negatives
axis (int or tuple, optional): Axis to compute over
"""
def __init__(self, alpha=0.5, beta=0.5, axis=None, **kwargs): ...
class CTC:
"""
Connectionist Temporal Classification loss for sequence labeling.
Args:
logits_time_major (bool): Whether logits are time-major
blank_index (int, optional): Index of blank label
reduction (str): Type of reduction to apply
"""
def __init__(self, logits_time_major=False, blank_index=None, **kwargs): ...

Metrics for evaluating classification model performance including accuracy variants and confusion matrix based metrics.
class Accuracy:
"""
Generic accuracy metric.
Args:
name (str): Name of the metric
dtype (str): Data type for metric computation
"""
def __init__(self, name='accuracy', dtype=None, **kwargs): ...
class BinaryAccuracy:
"""
Binary classification accuracy.
Args:
threshold (float): Decision threshold
name (str): Name of the metric
dtype (str): Data type for metric computation
"""
def __init__(self, threshold=0.5, name='binary_accuracy', dtype=None, **kwargs): ...
class CategoricalAccuracy:
"""
Categorical accuracy for one-hot encoded labels.
Args:
name (str): Name of the metric
dtype (str): Data type for metric computation
"""
def __init__(self, name='categorical_accuracy', dtype=None, **kwargs): ...
class SparseCategoricalAccuracy:
"""
Categorical accuracy for integer labels.
Args:
name (str): Name of the metric
dtype (str): Data type for metric computation
"""
def __init__(self, name='sparse_categorical_accuracy', dtype=None, **kwargs): ...
class TopKCategoricalAccuracy:
"""
Top-k categorical accuracy.
Args:
k (int): Number of top predictions to consider
name (str): Name of the metric
dtype (str): Data type for metric computation
"""
def __init__(self, k=5, name='top_k_categorical_accuracy', dtype=None, **kwargs): ...
class Precision:
"""
Precision metric.
Args:
thresholds (list, optional): List of thresholds for multi-threshold precision
top_k (int, optional): Top-k precision
class_id (int, optional): Class to compute precision for
name (str): Name of the metric
dtype (str): Data type for metric computation
"""
def __init__(self, thresholds=None, top_k=None, class_id=None, name='precision', dtype=None, **kwargs): ...
class Recall:
"""
Recall metric.
Args:
thresholds (list, optional): List of thresholds for multi-threshold recall
top_k (int, optional): Top-k recall
class_id (int, optional): Class to compute recall for
name (str): Name of the metric
dtype (str): Data type for metric computation
"""
def __init__(self, thresholds=None, top_k=None, class_id=None, name='recall', dtype=None, **kwargs): ...
class F1Score:
"""
F1 score metric.
Args:
average (str, optional): Averaging strategy ('micro', 'macro', 'weighted', None)
threshold (float, optional): Decision threshold for binary classification
name (str): Name of the metric
dtype (str): Data type for metric computation
"""
def __init__(self, average=None, threshold=None, name='f1_score', dtype=None, **kwargs): ...
class AUC:
"""
Area Under the ROC Curve metric.
Args:
num_thresholds (int): Number of thresholds for ROC curve
curve (str): Type of curve ('ROC' or 'PR')
summation_method (str): Method for approximating AUC
name (str): Name of the metric
dtype (str): Data type for metric computation
"""
def __init__(self, num_thresholds=200, curve='ROC', summation_method='interpolation',
name='auc', dtype=None, **kwargs): ...

Metrics for evaluating regression model performance.
class MeanSquaredError:
"""Mean squared error metric for regression."""
def __init__(self, name='mean_squared_error', dtype=None, **kwargs): ...
class RootMeanSquaredError:
"""Root mean squared error metric for regression."""
def __init__(self, name='root_mean_squared_error', dtype=None, **kwargs): ...
class MeanAbsoluteError:
"""Mean absolute error metric for regression."""
def __init__(self, name='mean_absolute_error', dtype=None, **kwargs): ...
class MeanAbsolutePercentageError:
"""Mean absolute percentage error metric for regression."""
def __init__(self, name='mean_absolute_percentage_error', dtype=None, **kwargs): ...
class R2Score:
"""
R² (coefficient of determination) metric.
Args:
class_aggregation (str): How to aggregate multiclass R²
num_regressors (int, optional): Number of regressors for adjusted R²
name (str): Name of the metric
dtype (str): Data type for metric computation
"""
def __init__(self, class_aggregation='uniform_average', num_regressors=0,
name='r2_score', dtype=None, **kwargs): ...
class CosineSimilarity:
"""
Cosine similarity metric.
Args:
axis (int): Axis along which to compute cosine similarity
name (str): Name of the metric
dtype (str): Data type for metric computation
"""
def __init__(self, axis=-1, name='cosine_similarity', dtype=None, **kwargs): ...

Metrics for evaluating image segmentation and pixel-wise classification tasks.
class IoU:
"""
Intersection over Union (Jaccard Index) metric.
Args:
num_classes (int): Number of classes
target_class_ids (list, optional): Specific classes to compute IoU for
threshold (float, optional): Threshold for binary predictions
name (str): Name of the metric
dtype (str): Data type for metric computation
"""
def __init__(self, num_classes, target_class_ids=None, threshold=None,
name='iou', dtype=None, **kwargs): ...
class MeanIoU:
"""
Mean Intersection over Union metric.
Args:
num_classes (int): Number of classes
name (str): Name of the metric
dtype (str): Data type for metric computation
"""
def __init__(self, num_classes, name='mean_iou', dtype=None, **kwargs): ...
class BinaryIoU:
"""
Binary Intersection over Union metric.
Args:
target_class_ids (list, optional): Target class IDs
threshold (float): Decision threshold
name (str): Name of the metric
dtype (str): Data type for metric computation
"""
def __init__(self, target_class_ids=None, threshold=0.5, name='binary_iou', dtype=None, **kwargs): ...

Functions for metric and loss management.
# Loss utilities (keras.losses namespace)
def get(identifier):
"""Get loss function by name or return callable."""
def serialize(loss):
"""Serialize loss to JSON-serializable dict."""
def deserialize(config, custom_objects=None):
"""Deserialize loss from config dict."""
# Metric utilities (keras.metrics namespace)
def get(identifier):
"""Get metric by name or return callable."""
def serialize(metric):
"""Serialize metric to JSON-serializable dict."""
def deserialize(config, custom_objects=None):
"""Deserialize metric from config dict."""import keras
from keras import layers, losses, metrics
model = keras.Sequential([
layers.Dense(64, activation='relu', input_shape=(784,)),
layers.Dropout(0.2),
layers.Dense(10, activation='softmax')
])
# Using string identifiers
model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy']
)
# Using class instances for more control
model.compile(
optimizer='adam',
loss=losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=[
metrics.SparseCategoricalAccuracy(),
metrics.TopKCategoricalAccuracy(k=3)
]
)

Multi-output models — assigning per-output losses, metrics, and loss weights:

import keras
from keras import layers, losses, metrics
# Define inputs
inputs = keras.Input(shape=(784,))
x = layers.Dense(64, activation='relu')(inputs)
# Multiple outputs
classification_output = layers.Dense(10, activation='softmax', name='classification')(x)
regression_output = layers.Dense(1, name='regression')(x)
model = keras.Model(inputs=inputs, outputs=[classification_output, regression_output])
# Different losses for different outputs
model.compile(
optimizer='adam',
loss={
'classification': losses.SparseCategoricalCrossentropy(),
'regression': losses.MeanSquaredError()
},
metrics={
'classification': [metrics.SparseCategoricalAccuracy(), metrics.F1Score()],
'regression': [metrics.MeanAbsoluteError(), metrics.R2Score()]
},
loss_weights={'classification': 1.0, 'regression': 0.5}
)

Custom loss functions — writing a loss with the backend-agnostic keras.ops API:

import keras
from keras import ops
def focal_loss(alpha=0.25, gamma=2.0, from_logits=True):
    """Build a binary focal-loss function (focal loss of Lin et al.).

    The original version applied a sigmoid unconditionally even though its
    inline comment said "if logits"; the transform is now controlled by an
    explicit flag whose default preserves the original behavior.

    Args:
        alpha (float): Weighting factor for the positive class.
        gamma (float): Focusing parameter; larger values down-weight
            easy (well-classified) examples more strongly.
        from_logits (bool): If True (default, matching prior behavior),
            ``y_pred`` is treated as logits and passed through a sigmoid.
            Set False when ``y_pred`` already holds probabilities.

    Returns:
        A callable ``loss_fn(y_true, y_pred)`` returning the scalar mean
        focal loss over all elements.
    """
    def loss_fn(y_true, y_pred):
        if from_logits:
            # Map logits to probabilities in (0, 1).
            y_pred = ops.sigmoid(y_pred)
        # pt is the model's probability assigned to the true class.
        pt = ops.where(y_true == 1, y_pred, 1 - y_pred)
        alpha_t = ops.where(y_true == 1, alpha, 1 - alpha)
        # (1 - pt)^gamma down-weights confident, correct predictions.
        focal_weight = alpha_t * ops.power(1 - pt, gamma)
        # 1e-8 guards against log(0) when pt underflows to exactly zero.
        bce = -ops.log(pt + 1e-8)
        return ops.mean(focal_weight * bce)
    return loss_fn
# Use custom loss
model.compile(
optimizer='adam',
loss=focal_loss(alpha=0.25, gamma=2.0),
metrics=['accuracy']
)

Custom metrics — subclassing keras.metrics.Metric:

import keras
from keras import ops
class F2Score(keras.metrics.Metric):
    """Streaming F2 score (F-beta with beta = 2), which favours recall.

    Wraps the built-in ``Precision`` and ``Recall`` metrics for state
    accumulation and combines their results as
    ``(1 + beta^2) * P * R / (beta^2 * P + R)`` with ``beta = 2``.
    """

    def __init__(self, name='f2_score', **kwargs):
        super().__init__(name=name, **kwargs)
        self.precision = keras.metrics.Precision()
        self.recall = keras.metrics.Recall()

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Delegate accumulation to both wrapped metrics.
        for tracker in (self.precision, self.recall):
            tracker.update_state(y_true, y_pred, sample_weight)

    def result(self):
        precision = self.precision.result()
        recall = self.recall.result()
        # beta = 2 gives (1 + beta^2) = 5 and beta^2 = 4; the small
        # epsilon prevents a 0/0 before any positives are seen.
        numerator = 5 * precision * recall
        denominator = 4 * precision + recall + 1e-8
        return numerator / denominator

    def reset_state(self):
        for tracker in (self.precision, self.recall):
            tracker.reset_state()
# Use custom metric
model.compile(
optimizer='adam',
loss='binary_crossentropy',
metrics=[F2Score(), 'accuracy']
)

Install with Tessl CLI
npx tessl i tessl/pypi-keras-nightly