
tessl/pypi-optax

A gradient processing and optimization library in JAX


Second-Order Optimization Utilities

Utilities for second-order optimization: the diagonal of the Hessian, the diagonal of the Fisher information matrix, and Hessian-vector products. These quantities describe the curvature of a loss and are used by optimization techniques that rely on curvature information.

Capabilities

Hessian Computations

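```python
def hessian_diag(loss, params, inputs, targets):
    """
    Compute the diagonal of the Hessian of `loss` with respect to `params`.

    Args:
        loss: Loss function with signature loss(params, inputs, targets)
        params: Model parameters at which the Hessian is evaluated
        inputs: Inputs at which `loss` is evaluated
        targets: Targets at which `loss` is evaluated

    Returns:
        1-D array of diagonal Hessian entries, ordered like the flattened params
    """
```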

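To make the quantity concrete, here is a minimal sketch in plain JAX (not optax's own code) of one way to get the Hessian diagonal without writing out the Hessian: take one Hessian-vector product per standard basis vector and keep the matching entry. It assumes a flat 1-D parameter array. The helper name `hessian_diag_via_hvp` and the least-squares `loss` are illustrative only.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Least-squares loss for a linear model: 0.5 * mean((x @ w - y) ** 2).
    return 0.5 * jnp.mean((x @ w - y) ** 2)


def hessian_diag_via_hvp(f, w):
    """Diagonal of the Hessian of f at w, one Hessian-vector product per entry.

    Sketch only: assumes w is a flat 1-D array.
    """
    def hvp(v):
        # Forward-over-reverse: the JVP of grad(f) along v equals H @ v.
        return jax.jvp(jax.grad(f), (w,), (v,))[1]

    basis = jnp.eye(w.size, dtype=w.dtype)
    # H[i, i] = e_i . (H @ e_i); vmap evaluates all basis directions in one batch.
    return jax.vmap(lambda e: jnp.vdot(e, hvp(e)))(basis)


x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 0.0, -1.0])
w = jnp.array([0.5, -0.5])

f = lambda w: loss(w, x, y)
diag = hessian_diag_via_hvp(f, w)

# Dense Hessian, used here only to check the result.
full = jax.hessian(f)(w)
print(diag, jnp.diag(full))
```

This needs one Hessian-vector product per parameter, so its cost grows linearly with the parameter count; the dense `jax.hessian` call is there only as a check.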
Fisher Information

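```python
def fisher_diag(negative_log_likelihood, params, inputs, targets):
    """
    Compute the diagonal of the (observed) Fisher information matrix.

    Args:
        negative_log_likelihood: Negative log-likelihood with signature
            fn(params, inputs, targets)
        params: Model parameters at which the Fisher diagonal is evaluated
        inputs: Inputs at which the negative log-likelihood is evaluated
        targets: Targets at which the negative log-likelihood is evaluated

    Returns:
        1-D array of diagonal Fisher entries, ordered like the flattened params
    """
```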

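The observed Fisher diagonal can be estimated from gradients alone. The sketch below assumes, based on our reading of optax's implementation (treat this as an assumption), that it is the elementwise square of the gradient of the negative log-likelihood over the whole batch. For comparison it also computes the per-example variant, the mean of squared per-example gradients. The logistic-regression `nll` and both helper names are hypothetical.

```python
import jax
import jax.numpy as jnp


def nll(w, x, y):
    # Mean negative log-likelihood of logistic regression, labels y in {0, 1}.
    logits = x @ w
    return jnp.mean(jnp.logaddexp(0.0, logits) - y * logits)


def observed_fisher_diag(nll_fn, w, x, y):
    # Square of the batch gradient: one gradient evaluation, no second derivatives.
    g = jax.grad(nll_fn)(w, x, y)
    return jnp.square(g)


def per_example_fisher_diag(w, x, y):
    # Mean of squared per-example gradients (the usual "empirical Fisher").
    per_ex = jax.vmap(lambda xi, yi: jax.grad(nll)(w, xi[None], yi[None]))(x, y)
    return jnp.mean(jnp.square(per_ex), axis=0)


x = jnp.array([[1.0, 0.5], [-1.0, 2.0], [0.3, -0.7], [2.0, 1.0]])
y = jnp.array([1.0, 0.0, 1.0, 0.0])
w = jnp.zeros(2)

observed = observed_fisher_diag(nll, w, x, y)
empirical = per_example_fisher_diag(w, x, y)
print(observed, empirical)
```

The two estimates differ: squaring the batch gradient gives (mean of g_i)^2, while the per-example form gives the mean of g_i^2, which is never smaller.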
Hessian-Vector Products

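```python
def hvp(loss, v, params, inputs, targets):
    """
    Compute a Hessian-vector product of `loss` without forming the Hessian.

    Args:
        loss: Loss function with signature loss(params, inputs, targets)
        v: Flat vector with the same size as the flattened params
        params: Model parameters at which the Hessian is evaluated
        inputs: Inputs at which `loss` is evaluated
        targets: Targets at which `loss` is evaluated

    Returns:
        The Hessian of `loss` at (params, inputs, targets) multiplied by `v`
    """
```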

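Hessian-vector products are commonly computed forward-over-reverse: differentiate the gradient function along the direction `v` with `jax.jvp`, which yields H·v without materializing H. The sketch below assumes pytree parameters and a flat `v` laid out as `jax.flatten_util.ravel_pytree` flattens them. `hvp_flat` and the toy linear-model `loss` are illustrative names, and the dense Hessian is computed only to check the result.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def loss(params, x, y):
    # Linear model with a weight vector and a scalar bias.
    pred = x @ params["w"] + params["b"]
    return 0.5 * jnp.mean((pred - y) ** 2)


def hvp_flat(loss_fn, v, params, x, y):
    """Hessian-vector product for a flat tangent v of size ravel(params).size.

    Sketch of the forward-over-reverse technique; not optax's code.
    """
    _, unravel = ravel_pytree(params)
    f = lambda p: loss_fn(p, x, y)
    # JVP of grad(f) at params along unravel(v) gives H @ v, shaped like params.
    _, hv = jax.jvp(jax.grad(f), (params,), (unravel(v),))
    return ravel_pytree(hv)[0]


params = {"w": jnp.array([0.5, -0.5]), "b": jnp.array(0.1)}
x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 0.0, -1.0])
v = jnp.array([1.0, 0.0, 2.0])

hv = hvp_flat(loss, v, params, x, y)

# Check against the dense Hessian of the loss as a function of the flat params.
flat, unravel = ravel_pytree(params)
dense_h = jax.hessian(lambda z: loss(unravel(z), x, y))(flat)
print(hv, dense_h @ v)
```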
Usage Examples

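```python
import jax.numpy as jnp
import optax

# Loss with the (params, inputs, targets) signature used by optax.second_order
def quadratic_loss(params, inputs, targets):
    preds = inputs @ params
    return 0.5 * jnp.mean((preds - targets) ** 2)

# Toy data: 3 samples with 2 features each
inputs = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
targets = jnp.array([1.0, 0.0, -1.0])
params = jnp.array([1.0, 2.0])

# Diagonal of the Hessian w.r.t. params, as a flat 1-D array
hess_diag = optax.second_order.hessian_diag(quadratic_loss, params, inputs, targets)

# Hessian-vector product; v has the same size as the flattened params
v = jnp.array([0.1, 0.2])
hvp_result = optax.second_order.hvp(quadratic_loss, v, params, inputs, targets)

# Observed Fisher diagonal, treating quadratic_loss as a negative log-likelihood
fisher = optax.second_order.fisher_diag(quadratic_loss, params, inputs, targets)
```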
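As one illustration of how curvature information can feed an optimizer, the sketch below takes damped diagonal Newton steps, w <- w - g / (diag(H) + damping). This is a hypothetical example, not an optax transformation: `diag_newton_step` and the `damping` value are illustrative, and the dense `jax.hessian` call is reasonable only because the model has two parameters.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Least-squares loss for a linear model.
    return 0.5 * jnp.mean((x @ w - y) ** 2)


def diag_newton_step(w, x, y, damping=1e-3):
    # Gradient and exact Hessian diagonal (dense Hessian is fine for tiny models).
    g = jax.grad(loss)(w, x, y)
    h_diag = jnp.diag(jax.hessian(loss)(w, x, y))
    # Damping keeps the step finite where curvature is close to zero.
    return w - g / (h_diag + damping)


# Inputs chosen so the Hessian is diagonally dominant.
x = jnp.array([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]])
y = jnp.array([1.0, 2.0, 0.5])
w = jnp.zeros(2)

initial_loss = loss(w, x, y)
for _ in range(20):
    w = diag_newton_step(w, x, y)
final_loss = loss(w, x, y)
print(initial_loss, final_loss)
```

Diagonal scaling ignores the off-diagonal Hessian terms, so it can overshoot when parameters are strongly coupled; with a diagonally dominant Hessian, as here, each step reduces this quadratic loss.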

Import

```python
import optax.second_order
# or
from optax.second_order import hessian_diag, fisher_diag, hvp
```

