# Second-Order Optimization Utilities

Utilities for second-order optimization methods including Hessian computations, Fisher information, and Hessian-vector products. These functions support advanced optimization techniques that utilize curvature information.

## Capabilities

### Hessian Computations

```python { .api }
def hessian_diag(fun):
    """
    Compute diagonal elements of the Hessian matrix.

    Args:
        fun: Function for which to compute the Hessian diagonal

    Returns:
        Function that computes diagonal Hessian elements
    """
```
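
To make the quantity concrete, here is a minimal sketch of a Hessian diagonal computed in plain JAX. It assumes `params` is a 1-D array and `fun` returns a scalar; `hessian_diag_sketch` is a hypothetical helper written for illustration, not the library's implementation, and the signature in your installed Optax version may differ.

```python
import jax
import jax.numpy as jnp


def hessian_diag_sketch(fun):
    """Return a function mapping a 1-D params array to diag(Hessian of fun).

    Hypothetical helper for illustration; not the library implementation.
    """
    def diag_fn(params):
        # Hessian-vector product via forward-over-reverse differentiation.
        def hvp(v):
            return jax.jvp(jax.grad(fun), (params,), (v,))[1]

        # Row i of the identity is the basis vector e_i, and
        # e_i . (H e_i) is the i-th diagonal entry of H.
        basis = jnp.eye(params.size, dtype=params.dtype)
        return jax.vmap(lambda e: jnp.vdot(e, hvp(e)))(basis)

    return diag_fn


# f(p) = sum(p**3) has Hessian diag(6 * p).
cubic = lambda p: jnp.sum(p ** 3)
diag = hessian_diag_sketch(cubic)(jnp.array([1.0, 2.0, 3.0]))
```

This approach needs one Hessian-vector product per parameter, so it avoids storing the full matrix but still scales linearly in the number of parameters.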

### Fisher Information

```python { .api }
def fisher_diag(log_likelihood):
    """
    Compute diagonal Fisher information matrix.

    Args:
        log_likelihood: Log-likelihood function

    Returns:
        Function that computes diagonal Fisher information
    """
```
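
One common way to estimate a diagonal Fisher matrix is the empirical Fisher: the mean of the squared per-example gradients of the log-likelihood. The sketch below shows that estimate in plain JAX. The helper `fisher_diag_sketch`, the per-example signature `log_likelihood(params, x, y)`, and the toy data are all assumptions made for illustration; the library function may compute or batch this differently.

```python
import jax
import jax.numpy as jnp


def fisher_diag_sketch(log_likelihood):
    """Return a function computing the empirical diagonal Fisher.

    Hypothetical helper; `log_likelihood(params, x, y)` is assumed to score
    a single example and return a scalar.
    """
    def diag_fn(params, xs, ys):
        # Per-example gradients of the log-likelihood w.r.t. params.
        per_example_grads = jax.vmap(
            jax.grad(log_likelihood), in_axes=(None, 0, 0)
        )(params, xs, ys)
        # Diagonal of the averaged outer product g g^T is the mean of g**2.
        return jnp.mean(per_example_grads ** 2, axis=0)

    return diag_fn


# Unit-variance Gaussian log-likelihood of a linear model (up to a constant).
def gaussian_log_lik(params, x, y):
    return -0.5 * (params @ x - y) ** 2


params = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0]])  # toy inputs, one example per row
ys = jnp.array([0.0, 0.0])                # toy targets
fisher = fisher_diag_sketch(gaussian_log_lik)(params, xs, ys)
```

Because only squared first-order gradients are involved, this estimate costs about as much as a batched gradient computation and is always non-negative.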

### Hessian-Vector Products

```python { .api }
def hvp(fun, primals, tangents):
    """
    Compute Hessian-vector product efficiently.

    Args:
        fun: Function for Hessian computation
        primals: Point at which to evaluate Hessian
        tangents: Vector for Hessian-vector product

    Returns:
        Result of Hessian-vector product
    """
```
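
The efficiency comes from never materializing the Hessian: differentiating the gradient along a single direction yields the product directly. A minimal sketch of this forward-over-reverse pattern in plain JAX follows. `hvp_sketch` and the test function `loss` are hypothetical names used only for illustration, and the dense `jax.hessian` call is included purely as a reference for comparison.

```python
import jax
import jax.numpy as jnp


def hvp_sketch(fun, primals, tangents):
    """Hessian-vector product H(primals) @ tangents without forming H.

    Hypothetical helper: differentiates grad(fun) in the direction `tangents`.
    """
    return jax.jvp(jax.grad(fun), (primals,), (tangents,))[1]


def loss(p):
    return jnp.sum(jnp.sin(p) * p ** 2)


p = jnp.array([0.5, -1.0, 2.0])
v = jnp.array([1.0, 0.0, -1.0])

fast = hvp_sketch(loss, p, v)
dense = jax.hessian(loss)(p) @ v  # reference: builds the full n x n Hessian
```

The matrix-free version costs a small constant multiple of one gradient evaluation, whereas the dense reference needs memory quadratic in the number of parameters.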

## Usage Examples

```python
import optax
import jax
import jax.numpy as jnp

# Example data (illustrative values): one input vector and its scalar target
x_data = jnp.array([3.0, -1.0])
y_data = 0.5

# Define a quadratic function
def quadratic_loss(params, x, y):
    pred = params @ x
    return 0.5 * (pred - y)**2

# Compute diagonal Hessian
hess_diag_fn = optax.second_order.hessian_diag(
    lambda p: quadratic_loss(p, x_data, y_data)
)
params = jnp.array([1.0, 2.0])
hess_diag = hess_diag_fn(params)

# Compute Hessian-vector product
tangent = jnp.array([0.1, 0.2])
hvp_result = optax.second_order.hvp(
    lambda p: quadratic_loss(p, x_data, y_data),
    params,
    tangent
)
```
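
For the quadratic loss in the example above, the results can be checked by hand: the Hessian with respect to `params` is the outer product of `x` with itself, so its diagonal is `x**2` and its product with a vector `v` is `x * (x @ v)`. The sketch below verifies this with plain JAX only, so it does not depend on the `optax.second_order` signatures. The values of `x_data` and `y_data` are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

x_data = jnp.array([3.0, -1.0])  # illustrative input vector
y_data = 0.5                     # illustrative scalar target


def loss(p):
    return 0.5 * (p @ x_data - y_data) ** 2


params = jnp.array([1.0, 2.0])
tangent = jnp.array([0.1, 0.2])

# Dense Hessian from JAX; for this loss it equals outer(x_data, x_data).
hessian = jax.hessian(loss)(params)

# Closed-form values to compare against.
expected_diag = x_data ** 2                      # diagonal of x x^T
expected_hvp = x_data * (x_data @ tangent)       # (x x^T) v = x (x . v)
```

Since the loss is quadratic, the Hessian does not depend on `params` or `y_data`, which makes this a convenient sanity check for any second-order routine.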

## Import

```python
import optax.second_order
# or
from optax.second_order import hessian_diag, fisher_diag, hvp
```