A gradient processing and optimization library in JAX
A comprehensive collection of loss functions for classification, regression, and structured prediction tasks. These functions provide differentiable objectives for training neural networks and other machine learning models. Note that optax losses are typically unreduced: they return one value per example (or per element), so apply jnp.mean or jnp.sum to obtain a scalar training objective.
def l2_loss(predictions, targets=None):
"""
L2 loss (mean squared error).
Args:
predictions: Predicted values
targets: Target values (default: None, uses zeros if not provided)
Returns:
Scalar loss value
"""
def squared_error(predictions, targets):
"""
Squared error loss (alias for l2_loss).
Args:
predictions: Predicted values
targets: Target values
Returns:
Scalar loss value
"""def huber_loss(predictions, targets, delta=1.0):
"""
Huber loss for robust regression.
Args:
predictions: Predicted values
targets: Target values
delta: Threshold for switching between squared and linear loss (default: 1.0)
Returns:
Scalar loss value
"""
def log_cosh(predictions, targets):
"""
Log-cosh loss for robust regression.
Args:
predictions: Predicted values
targets: Target values
Returns:
Scalar loss value
"""def cosine_distance(predictions, targets):
"""
Cosine distance loss.
Args:
predictions: Predicted vectors
targets: Target vectors
Returns:
Scalar loss value
"""
def cosine_similarity(predictions, targets):
"""
Cosine similarity (negative cosine distance).
Args:
predictions: Predicted vectors
targets: Target vectors
Returns:
Scalar similarity value
"""def softmax_cross_entropy(logits, labels, axis=-1):
"""
Softmax cross-entropy loss.
Args:
logits: Predicted logits
labels: One-hot encoded target labels
axis: Axis along which to apply softmax (default: -1)
Returns:
Scalar loss value
"""
def softmax_cross_entropy_with_integer_labels(logits, labels, axis=-1):
"""
Softmax cross-entropy loss with integer labels.
Args:
logits: Predicted logits
labels: Integer target labels
axis: Axis along which to apply softmax (default: -1)
Returns:
Scalar loss value
"""
def safe_softmax_cross_entropy(logits, labels, axis=-1):
"""
Numerically stable softmax cross-entropy loss.
Args:
logits: Predicted logits
labels: One-hot encoded target labels
axis: Axis along which to apply softmax (default: -1)
Returns:
Scalar loss value
"""
def sigmoid_binary_cross_entropy(logits, labels):
"""
Sigmoid binary cross-entropy loss.
Args:
logits: Predicted logits
labels: Binary target labels
Returns:
Scalar loss value
"""
def poly_loss_cross_entropy(logits, labels, epsilon=2.0):
"""
PolyLoss cross-entropy for improved tail learning.
Args:
logits: Predicted logits
labels: One-hot encoded target labels
epsilon: Polynomial coefficient (default: 2.0)
Returns:
Scalar loss value
"""def hinge_loss(scores, labels):
"""
Hinge loss for binary classification.
Args:
scores: Predicted scores
labels: Binary labels (+1 or -1)
Returns:
Scalar loss value
"""
def multiclass_hinge_loss(scores, labels):
"""
Multiclass hinge loss.
Args:
scores: Predicted scores for each class
labels: Integer class labels
Returns:
Scalar loss value
"""
def perceptron_loss(scores, labels):
"""
Perceptron loss for binary classification.
Args:
scores: Predicted scores
labels: Binary labels (+1 or -1)
Returns:
Scalar loss value
"""
def multiclass_perceptron_loss(scores, labels):
"""
Multiclass perceptron loss.
Args:
scores: Predicted scores for each class
labels: Integer class labels
Returns:
Scalar loss value
"""def sigmoid_focal_loss(logits, labels, alpha=0.25, gamma=2.0):
"""
Sigmoid focal loss for addressing class imbalance.
Args:
logits: Predicted logits
labels: Binary target labels
alpha: Weighting factor for rare class (default: 0.25)
gamma: Focusing parameter (default: 2.0)
Returns:
Scalar loss value
"""def ctc_loss(logits, labels, input_lengths, label_lengths, blank=0):
"""
Connectionist Temporal Classification (CTC) loss.
Args:
logits: Predicted logits for each time step
labels: Target sequence labels
input_lengths: Length of each input sequence
label_lengths: Length of each target sequence
blank: Blank token index (default: 0)
Returns:
Scalar loss value
"""
def ctc_loss_with_forward_probs(logits, logit_paddings, labels, label_paddings, blank_id=0):
"""
CTC loss that also returns forward probabilities.
Args:
logits: Predicted logits for each time step
logit_paddings: Padding mask for the logit sequences
labels: Target sequence labels
label_paddings: Padding mask for the label sequences
blank_id: Blank token index (default: 0)
Returns:
Tuple of (loss, forward_probs)
"""
def ranking_softmax_loss(scores, labels):
"""
Ranking loss using softmax for learning to rank tasks.
Args:
scores: Predicted relevance scores
labels: Target relevance labels
Returns:
Scalar loss value
"""
def triplet_margin_loss(anchor, positive, negative, margin=1.0):
"""
Triplet margin loss for metric learning.
Args:
anchor: Anchor embeddings
positive: Positive example embeddings
negative: Negative example embeddings
margin: Margin parameter (default: 1.0)
Returns:
Scalar loss value
"""
def ntxent(query, key, temperature=1.0):
"""
Normalized temperature-scaled cross-entropy loss for contrastive learning.
Args:
query: Query embeddings
key: Key embeddings
temperature: Temperature scaling parameter (default: 1.0)
Returns:
Scalar loss value
"""def kl_divergence(log_predictions, targets):
"""
Kullback-Leibler divergence.
Args:
log_predictions: Log probabilities of predictions
targets: Target probability distributions
Returns:
Scalar divergence value
"""
def kl_divergence_with_log_targets(log_predictions, log_targets):
"""
KL divergence with log-space targets for numerical stability.
Args:
log_predictions: Log probabilities of predictions
log_targets: Log probabilities of targets
Returns:
Scalar divergence value
"""
def convex_kl_divergence(log_predictions, targets):
"""
Convex KL divergence (a formulation that is jointly convex in targets and predictions).
Args:
log_predictions: Log probabilities of predictions
targets: Target probability distributions
Returns:
Scalar divergence value
"""def sparsemax_loss(logits, labels):
"""
Sparsemax loss for sparse probability distributions.
Args:
logits: Predicted logits
labels: Target labels
Returns:
Scalar loss value
"""
def multiclass_sparsemax_loss(logits, labels):
"""
Multiclass sparsemax loss.
Args:
logits: Predicted logits for each class
labels: Integer class labels
Returns:
Scalar loss value
"""def smooth_labels(labels, alpha=0.1):
"""
Apply label smoothing to one-hot labels.
Args:
labels: One-hot encoded labels
alpha: Smoothing parameter (default: 0.1)
Returns:
Smoothed labels
"""
def make_fenchel_young_loss(regularizer):
"""
Create Fenchel-Young loss from convex regularizer.
Args:
regularizer: Convex regularization function
Returns:
Fenchel-Young loss function
"""
import optax
import jax.numpy as jnp
# Predictions and targets
predictions = jnp.array([1.0, 2.0, 3.0])
targets = jnp.array([1.1, 1.9, 3.2])
# Compute losses (optax losses are unreduced; take the mean for a scalar)
mse_loss = optax.l2_loss(predictions, targets).mean()
huber_loss_val = optax.huber_loss(predictions, targets, delta=1.0).mean()

# Multi-class classification
logits = jnp.array([[2.0, 1.0, 0.1], [1.0, 3.0, 0.5]])
one_hot_labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
integer_labels = jnp.array([0, 1])
# Cross-entropy losses (one value per example)
ce_loss = optax.softmax_cross_entropy(logits, one_hot_labels).mean()
ce_int_loss = optax.softmax_cross_entropy_with_integer_labels(logits, integer_labels).mean()
# Binary classification
binary_logits = jnp.array([0.5, -1.2, 2.1])
binary_labels = jnp.array([1.0, 0.0, 1.0])
binary_loss = optax.sigmoid_binary_cross_entropy(binary_logits, binary_labels).mean()

import jax
def compute_loss(params, batch_x, batch_y):
"""Compute loss for a batch."""
predictions = model_fn(params, batch_x)
return optax.softmax_cross_entropy_with_integer_labels(predictions, batch_y)
def train_step(params, opt_state, batch_x, batch_y):
"""Single training step."""
# Compute loss and gradients
loss_val, grads = jax.value_and_grad(compute_loss)(params, batch_x, batch_y)
# Update parameters
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
return params, opt_state, loss_val

def combined_loss(predictions, targets, params):
"""Combine multiple loss terms."""
# Main task loss
task_loss = optax.softmax_cross_entropy(predictions, targets).mean()
# Regularization loss (sum each leaf's elementwise penalty to a scalar)
l2_reg = sum(jnp.sum(optax.l2_loss(p)) for p in jax.tree_util.tree_leaves(params))
# Total loss
return task_loss + 1e-4 * l2_reg
# With label smoothing
smoothed_labels = optax.smooth_labels(one_hot_labels, alpha=0.1)
smooth_loss = optax.softmax_cross_entropy(logits, smoothed_labels).mean()

Install with Tessl CLI
npx tessl i tessl/pypi-optax