Loss Functions

Comprehensive collection of loss functions for classification, regression, and structured prediction tasks. These functions provide differentiable objectives for training neural networks and other machine learning models.

Capabilities

Regression Losses

Mean Squared Error

def l2_loss(predictions, targets=None):
    """
    L2 loss: 0.5 * (predictions - targets)**2.
    
    Args:
        predictions: Predicted values
        targets: Target values (default: None, uses zeros if not provided)
    
    Returns:
        Elementwise loss values (reduce with e.g. jnp.mean)
    """

def squared_error(predictions, targets=None):
    """
    Squared error loss: (predictions - targets)**2. Note that l2_loss
    equals 0.5 * squared_error; it is not a plain alias.
    
    Args:
        predictions: Predicted values
        targets: Target values (default: None, uses zeros if not provided)
    
    Returns:
        Elementwise loss values (reduce with e.g. jnp.mean)
    """

Robust Regression Losses

def huber_loss(predictions, targets, delta=1.0):
    """
    Huber loss for robust regression: quadratic for residuals below
    delta, linear above it.
    
    Args:
        predictions: Predicted values
        targets: Target values
        delta: Threshold for switching between squared and linear loss (default: 1.0)
    
    Returns:
        Elementwise loss values (reduce with e.g. jnp.mean)
    """

def log_cosh(predictions, targets):
    """
    Log-cosh loss for robust regression.
    
    Args:
        predictions: Predicted values
        targets: Target values
    
    Returns:
        Elementwise loss values (reduce with e.g. jnp.mean)
    """

Distance-Based Losses

def cosine_distance(predictions, targets):
    """
    Cosine distance loss: 1 - cosine similarity.
    
    Args:
        predictions: Predicted vectors
        targets: Target vectors
    
    Returns:
        Distance values, reduced over the feature axis
    """

def cosine_similarity(predictions, targets):
    """
    Cosine similarity; cosine_distance is 1 minus this value.
    
    Args:
        predictions: Predicted vectors
        targets: Target vectors
    
    Returns:
        Similarity values, reduced over the feature axis
    """

Classification Losses

Cross-Entropy Losses

def softmax_cross_entropy(logits, labels, axis=-1):
    """
    Softmax cross-entropy loss.
    
    Args:
        logits: Predicted logits
        labels: One-hot encoded labels or label probabilities
        axis: Axis along which to apply softmax (default: -1)
    
    Returns:
        Per-example loss values (reduce with e.g. jnp.mean)
    """

def softmax_cross_entropy_with_integer_labels(logits, labels, axis=-1):
    """
    Softmax cross-entropy loss with integer labels.
    
    Args:
        logits: Predicted logits
        labels: Integer target labels
        axis: Axis along which to apply softmax (default: -1)
    
    Returns:
        Per-example loss values (reduce with e.g. jnp.mean)
    """

def safe_softmax_cross_entropy(logits, labels, axis=-1):
    """
    Numerically stable softmax cross-entropy loss.
    
    Args:
        logits: Predicted logits
        labels: One-hot encoded labels or label probabilities
        axis: Axis along which to apply softmax (default: -1)
    
    Returns:
        Per-example loss values (reduce with e.g. jnp.mean)
    """

def sigmoid_binary_cross_entropy(logits, labels):
    """
    Sigmoid binary cross-entropy loss.
    
    Args:
        logits: Predicted logits
        labels: Binary labels (0 or 1)
    
    Returns:
        Elementwise loss values (reduce with e.g. jnp.mean)
    """

def poly_loss_cross_entropy(logits, labels, epsilon=2.0):
    """
    PolyLoss cross-entropy for improved tail learning.
    
    Args:
        logits: Predicted logits
        labels: One-hot encoded target labels
        epsilon: Polynomial coefficient (default: 2.0)
    
    Returns:
        Per-example loss values (reduce with e.g. jnp.mean)
    """

Margin-Based Losses

def hinge_loss(scores, labels):
    """
    Hinge loss for binary classification.
    
    Args:
        scores: Predicted scores
        labels: Binary labels (+1 or -1)
    
    Returns:
        Elementwise loss values
    """

def multiclass_hinge_loss(scores, labels):
    """
    Multiclass hinge loss.
    
    Args:
        scores: Predicted scores for each class
        labels: Integer class labels
    
    Returns:
        Per-example loss values
    """

def perceptron_loss(scores, labels):
    """
    Perceptron loss for binary classification.
    
    Args:
        scores: Predicted scores
        labels: Binary labels (+1 or -1)
    
    Returns:
        Elementwise loss values
    """

def multiclass_perceptron_loss(scores, labels):
    """
    Multiclass perceptron loss.
    
    Args:
        scores: Predicted scores for each class
        labels: Integer class labels
    
    Returns:
        Per-example loss values
    """

Focal and Sigmoid Losses

def sigmoid_focal_loss(logits, labels, alpha=0.25, gamma=2.0):
    """
    Sigmoid focal loss for addressing class imbalance.
    
    Args:
        logits: Predicted logits
        labels: Binary target labels
        alpha: Weighting factor for the rare class (default: 0.25)
        gamma: Focusing parameter (default: 2.0)
    
    Returns:
        Elementwise loss values
    """

Structured Prediction Losses

Sequence Losses

def ctc_loss(logits, logit_paddings, labels, label_paddings, blank_id=0):
    """
    Connectionist Temporal Classification (CTC) loss.
    
    Args:
        logits: Predicted logits for each time step
        logit_paddings: Padding mask for the logit time steps
        labels: Target label sequences
        label_paddings: Padding mask for the label positions
        blank_id: Blank token index (default: 0)
    
    Returns:
        Per-sequence loss values
    """

def ctc_loss_with_forward_probs(logits, logit_paddings, labels, label_paddings, blank_id=0):
    """
    CTC loss that also returns forward probabilities.
    
    Args:
        logits: Predicted logits for each time step
        logit_paddings: Padding mask for the logit time steps
        labels: Target label sequences
        label_paddings: Padding mask for the label positions
        blank_id: Blank token index (default: 0)
    
    Returns:
        Tuple of per-sequence losses and forward probabilities
    """

Ranking and Contrastive Losses

def ranking_softmax_loss(scores, labels):
    """
    Ranking loss using softmax for learning to rank tasks.
    
    Args:  
        scores: Predicted relevance scores
        labels: Target relevance labels
    
    Returns:
        Scalar loss value
    """

def triplet_margin_loss(anchor, positive, negative, margin=1.0):
    """
    Triplet margin loss for metric learning.
    
    Args:
        anchor: Anchor embeddings
        positive: Positive example embeddings
        negative: Negative example embeddings
        margin: Margin parameter (default: 1.0)
    
    Returns:
        Scalar loss value
    """

def ntxent(query, key, temperature=1.0):
    """
    Normalized temperature-scaled cross-entropy loss for contrastive learning.
    
    Args:
        query: Query embeddings
        key: Key embeddings
        temperature: Temperature scaling parameter (default: 1.0)
    
    Returns:
        Scalar loss value
    """

Divergence and Information-Theoretic Losses

KL Divergence

def kl_divergence(log_predictions, targets):
    """
    Kullback-Leibler divergence.
    
    Args:
        log_predictions: Log probabilities of predictions
        targets: Target probability distributions
    
    Returns:
        Divergence values, summed over the distribution axis
    """

def kl_divergence_with_log_targets(log_predictions, log_targets):
    """
    KL divergence with log-space targets for numerical stability.
    
    Args:
        log_predictions: Log probabilities of predictions
        log_targets: Log probabilities of targets
    
    Returns:
        Divergence values, summed over the distribution axis
    """

def convex_kl_divergence(log_predictions, targets):
    """
    Convex variant of the Kullback-Leibler divergence.
    
    Args:
        log_predictions: Log probabilities of predictions
        targets: Target probability distributions
    
    Returns:
        Divergence values, summed over the distribution axis
    """

Sparsemax and Specialized Losses

Sparsemax Losses

def sparsemax_loss(logits, labels):
    """
    Sparsemax loss for sparse probability distributions.
    
    Args:
        logits: Predicted logits
        labels: Target labels
    
    Returns:
        Scalar loss value
    """

def multiclass_sparsemax_loss(logits, labels):
    """
    Multiclass sparsemax loss.
    
    Args:
        logits: Predicted logits for each class
        labels: Integer class labels
    
    Returns:
        Scalar loss value
    """

Loss Utilities

Label Processing

def smooth_labels(labels, alpha):
    """
    Apply label smoothing to one-hot labels:
    (1 - alpha) * labels + alpha / num_classes.
    
    Args:
        labels: One-hot encoded labels
        alpha: Smoothing parameter (0 = no smoothing, 1 = uniform)
    
    Returns:
        Smoothed label distribution
    """

def make_fenchel_young_loss(regularizer):
    """
    Create Fenchel-Young loss from convex regularizer.
    
    Args:
        regularizer: Convex regularization function
    
    Returns:
        Fenchel-Young loss function
    """

Usage Examples

Basic Regression

import optax
import jax.numpy as jnp

# Predictions and targets
predictions = jnp.array([1.0, 2.0, 3.0])
targets = jnp.array([1.1, 1.9, 3.2])

# Compute losses (elementwise; reduce for a scalar training objective)
mse_loss = optax.l2_loss(predictions, targets).mean()
huber_loss_val = optax.huber_loss(predictions, targets, delta=1.0).mean()

Classification Setup

# Multi-class classification
logits = jnp.array([[2.0, 1.0, 0.1], [1.0, 3.0, 0.5]])
one_hot_labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
integer_labels = jnp.array([0, 1])

# Cross-entropy losses (per-example; average over the batch)
ce_loss = optax.softmax_cross_entropy(logits, one_hot_labels).mean()
ce_int_loss = optax.softmax_cross_entropy_with_integer_labels(logits, integer_labels).mean()

# Binary classification
binary_logits = jnp.array([0.5, -1.2, 2.1])
binary_labels = jnp.array([1.0, 0.0, 1.0])
binary_loss = optax.sigmoid_binary_cross_entropy(binary_logits, binary_labels).mean()

Training Loop Integration

import jax

def compute_loss(params, batch_x, batch_y):
    """Compute the mean loss for a batch."""
    predictions = model_fn(params, batch_x)  # model_fn assumed defined elsewhere
    losses = optax.softmax_cross_entropy_with_integer_labels(predictions, batch_y)
    return losses.mean()  # jax.value_and_grad requires a scalar output

def train_step(params, opt_state, batch_x, batch_y):
    """Single training step (optimizer, e.g. optax.adam(1e-3), created beforehand)."""
    # Compute loss and gradients
    loss_val, grads = jax.value_and_grad(compute_loss)(params, batch_x, batch_y)
    
    # Update parameters
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    
    return params, opt_state, loss_val
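
For context, a sketch of the surrounding setup (model_fn and data_loader are placeholders assumed defined elsewhere):

optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)

for batch_x, batch_y in data_loader:
    params, opt_state, loss_val = train_step(params, opt_state, batch_x, batch_y)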

Advanced Loss Combinations

def combined_loss(predictions, targets, params):
    """Combine multiple loss terms into one scalar."""
    # Main task loss (mean of the per-example losses)
    task_loss = optax.softmax_cross_entropy(predictions, targets).mean()
    
    # L2 regularization over all parameter leaves (targets default to zeros)
    l2_reg = sum(jnp.sum(optax.l2_loss(p)) for p in jax.tree_util.tree_leaves(params))
    
    # Total loss
    return task_loss + 1e-4 * l2_reg

# With label smoothing
smoothed_labels = optax.smooth_labels(one_hot_labels, alpha=0.1)
smooth_loss = optax.softmax_cross_entropy(logits, smoothed_labels)
