Optax: a gradient processing and optimization library for JAX

Utilities for second-order optimization methods, including diagonal Hessian computations, the diagonal of the Fisher information matrix, and Hessian-vector products. These functions support optimization techniques that use curvature information.
# API reference: signature stubs for the functions in optax.second_order.
# Each stub body holds only a docstring, so calling a stub defined here
# returns None; the working implementations come from the optax package.
# NOTE(review): confirm these signatures against the installed optax release;
# some releases appear to expose e.g. `hessian_diag(loss, params, inputs,
# targets)` and `hvp(loss, v, params, inputs, targets)` returning arrays
# directly rather than the forms documented here.


def hessian_diag(fun):
    """Build a function that evaluates the diagonal of the Hessian of ``fun``.

    Args:
        fun: The function whose Hessian diagonal is wanted.

    Returns:
        A function that takes parameters and returns the diagonal entries of
        the Hessian of ``fun`` evaluated at those parameters.
    """


def fisher_diag(log_likelihood):
    """Build a function that evaluates the diagonal Fisher information matrix.

    Args:
        log_likelihood: The log-likelihood function from which the Fisher
            information is derived.

    Returns:
        A function that computes the diagonal of the Fisher information
        matrix.
    """


def hvp(fun, primals, tangents):
    """Compute a Hessian-vector product efficiently.

    Args:
        fun: The function whose Hessian is used.
        primals: The point at which the Hessian is evaluated.
        tangents: The vector that the Hessian is multiplied by.

    Returns:
        The product of the Hessian of ``fun`` at ``primals`` with ``tangents``.
    """


# --- Usage example ---
import optax
import jax
import jax.numpy as jnp


def quadratic_loss(params, x, y):
    """Return half the sum of squared residuals of the linear model ``params @ x``.

    Args:
        params: Parameter vector of shape ``(d,)``.
        x: A single feature vector of shape ``(d,)``, or a matrix of shape
            ``(d, n)`` holding one example per column.
        y: Target scalar, or a vector of shape ``(n,)`` matching the columns
            of ``x``.

    Returns:
        The scalar loss ``0.5 * sum((params @ x - y) ** 2)``.
    """
    pred = params @ x
    # Summing keeps the loss scalar even when ``x`` holds several examples;
    # Hessian utilities need a scalar-valued function. For a single example
    # the sum is a no-op, so results match the unsummed form.
    return 0.5 * jnp.sum((pred - y) ** 2)


# Example data: one observation with two features. These names were
# previously used below without ever being defined.
x_data = jnp.array([3.0, 4.0])
y_data = 1.0


def loss_fn(p):
    """Loss as a function of the parameters only, with the data fixed."""
    return quadratic_loss(p, x_data, y_data)


# NOTE(review): the calls below follow the signatures documented above;
# confirm them against the installed optax release before relying on them.

# Compute the diagonal of the Hessian. For this loss the Hessian is
# x_data x_data^T, so the expected diagonal is x_data ** 2 = [9., 16.].
hess_diag_fn = optax.second_order.hessian_diag(loss_fn)
params = jnp.array([1.0, 2.0])
hess_diag = hess_diag_fn(params)

# Compute a Hessian-vector product. Expected value:
# x_data * (x_data @ tangent) = [3., 4.] * 1.1 = [3.3, 4.4].
tangent = jnp.array([0.1, 0.2])
hvp_result = optax.second_order.hvp(loss_fn, params, tangent)


# Import styles:
import optax.second_order
# or
from optax.second_order import hessian_diag, fisher_diag, hvp

Install with the Tessl CLI:

npx tessl i tessl/pypi-optax