# Second-Order Optimization Utilities

Utilities for second-order optimization methods including Hessian computations, Fisher information, and Hessian-vector products. These functions support advanced optimization techniques that utilize curvature information.

## Capabilities

### Hessian Computations

```python { .api }
def hessian_diag(fun):
    """
    Compute diagonal elements of the Hessian matrix.

    Args:
        fun: Function for which to compute the Hessian diagonal

    Returns:
        Function that computes diagonal Hessian elements
    """
```
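
To make the quantity concrete, here is a minimal sketch of a Hessian diagonal computed with plain JAX. It is not the library's implementation: `hessian_diag_sketch` and the test function `f` are hypothetical names introduced for illustration, and the sketch assumes the parameters are a flat 1-D array small enough that building the dense Hessian is affordable.

```python
import jax
import jax.numpy as jnp


def hessian_diag_sketch(fun):
    """Hypothetical helper: returns a function that computes diag(H) of `fun`."""
    def diag_fn(params):
        # Build the dense Hessian (O(n^2) memory) and keep only its diagonal.
        return jnp.diag(jax.hessian(fun)(params))
    return diag_fn


def f(p):
    # Illustrative scalar function: cubic terms plus one cross term.
    return jnp.sum(p ** 3) + p[0] * p[1]


params = jnp.array([1.0, 2.0, 3.0])
diag = hessian_diag_sketch(f)(params)
# The cubic terms contribute 6 * p_i to the diagonal; the cross term
# p[0] * p[1] only affects off-diagonal entries.
```

For large parameter vectors the dense Hessian is usually too expensive, which is where Hessian-vector products become the more practical building block.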

### Fisher Information

```python { .api }
def fisher_diag(log_likelihood):
    """
    Compute diagonal Fisher information matrix.

    Args:
        log_likelihood: Log-likelihood function

    Returns:
        Function that computes diagonal Fisher information
    """
```
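
As an illustration of the idea, the following sketch estimates a diagonal empirical Fisher matrix as the mean of squared per-example gradients of the log-likelihood. This is one common estimator, not necessarily what the library computes. The names `fisher_diag_sketch` and `gaussian_log_likelihood`, the `(params, xs, ys)` call signature, and the data values are all assumptions made for this example.

```python
import jax
import jax.numpy as jnp


def gaussian_log_likelihood(params, x, y):
    # Unit-variance Gaussian log-likelihood of one example, up to a constant.
    return -0.5 * (jnp.dot(params, x) - y) ** 2


def fisher_diag_sketch(log_likelihood):
    """Hypothetical helper: diagonal empirical Fisher from per-example gradients."""
    def diag_fn(params, xs, ys):
        # Gradient w.r.t. params for every example in the batch.
        per_example_grads = jax.vmap(
            jax.grad(log_likelihood), in_axes=(None, 0, 0)
        )(params, xs, ys)
        # Average the elementwise squared gradients over the batch.
        return jnp.mean(per_example_grads ** 2, axis=0)
    return diag_fn


params = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0]])  # illustrative inputs
ys = jnp.array([0.0, 0.0])                # illustrative targets
fisher = fisher_diag_sketch(gaussian_log_likelihood)(params, xs, ys)
```

Because the gradients are squared, the result is the same whether the function passed in is a log-likelihood or a negative log-likelihood.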

### Hessian-Vector Products

```python { .api }
def hvp(fun, primals, tangents):
    """
    Compute Hessian-vector product efficiently.

    Args:
        fun: Function for Hessian computation
        primals: Point at which to evaluate Hessian
        tangents: Vector for Hessian-vector product

    Returns:
        Result of Hessian-vector product
    """
```
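
A Hessian-vector product can be computed without ever forming the Hessian by differentiating the gradient in a single direction (forward-over-reverse mode). The sketch below shows that technique in plain JAX and cross-checks it against a dense Hessian. `hvp_sketch` and `f` are hypothetical names for this example, and the sketch is not a claim about how the library implements `hvp`.

```python
import jax
import jax.numpy as jnp


def hvp_sketch(fun, primals, tangents):
    """Hypothetical helper: forward-mode derivative of grad(fun) along `tangents`."""
    # jax.jvp returns (grad(fun)(primals), H @ tangents); keep the second item.
    return jax.jvp(jax.grad(fun), (primals,), (tangents,))[1]


def f(p):
    # Illustrative scalar function with a non-diagonal Hessian.
    return jnp.sum(p ** 3) + p[0] * p[1]


params = jnp.array([1.0, 2.0])
v = jnp.array([0.1, 0.2])

hv = hvp_sketch(f, params, v)

# Reference value via the dense Hessian, only feasible for small problems.
dense = jax.hessian(f)(params) @ v
```

The product costs roughly a small constant multiple of one gradient evaluation, whereas the dense reference scales quadratically in memory with the number of parameters.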

## Usage Examples

```python
import optax
import jax
import jax.numpy as jnp

# Define a quadratic function
def quadratic_loss(params, x, y):
    pred = params @ x
    return 0.5 * (pred - y)**2

# Example data (illustrative values; the original snippet left these undefined)
x_data = jnp.array([0.5, -1.0])
y_data = 1.0

# Compute diagonal Hessian
hess_diag_fn = optax.second_order.hessian_diag(
    lambda p: quadratic_loss(p, x_data, y_data)
)
params = jnp.array([1.0, 2.0])
hess_diag = hess_diag_fn(params)

# Compute Hessian-vector product
tangent = jnp.array([0.1, 0.2])
hvp_result = optax.second_order.hvp(
    lambda p: quadratic_loss(p, x_data, y_data),
    params,
    tangent
)
```

## Import

```python
import optax.second_order
# or
from optax.second_order import hessian_diag, fisher_diag, hvp
```