# Core Optimizers

Popular optimization algorithms that are ready for immediate use in training loops. These optimizers combine multiple gradient transformations into complete optimization strategies with sensible defaults.

## Capabilities

### Adam Optimizer

The Adam optimizer with optional Nesterov momentum. Combines adaptive learning rates with momentum for efficient optimization across a wide range of problems.

```python { .api }
def adam(learning_rate, b1=0.9, b2=0.999, eps=1e-8, eps_root=0.0, mu_dtype=None, *, nesterov=False):
    """
    Adam optimizer.

    Args:
        learning_rate: Learning rate or schedule
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for second moment estimates (default: 0.999)
        eps: Small constant for numerical stability (default: 1e-8)
        eps_root: Small constant for numerical stability in denominator (default: 0.0)
        mu_dtype: Optional dtype for momentum accumulator (default: None)
        nesterov: Whether to use Nesterov momentum (default: False)

    Returns:
        GradientTransformationExtraArgs
    """
```
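
A minimal sketch of the init/update cycle with `adam`; the toy parameters and quadratic loss are illustrative, not part of the API.

```python
import jax
import jax.numpy as jnp
import optax

# Toy parameters and loss for illustration only.
params = jnp.array([1.0, -2.0, 3.0])
loss_fn = lambda p: jnp.sum(p ** 2)

opt = optax.adam(learning_rate=1e-3, nesterov=True)  # Nesterov variant via the keyword flag
opt_state = opt.init(params)

grads = jax.grad(loss_fn)(params)
updates, opt_state = opt.update(grads, opt_state)
params = optax.apply_updates(params, updates)
```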

### AdamW Optimizer

Adam optimizer with decoupled weight decay. Separates weight decay from gradient-based updates for better generalization.

```python { .api }
def adamw(learning_rate, b1=0.9, b2=0.999, eps=1e-8, weight_decay=1e-4, *, nesterov=False):
    """
    AdamW optimizer with decoupled weight decay.

    Args:
        learning_rate: Learning rate or schedule
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for second moment estimates (default: 0.999)
        eps: Small constant for numerical stability (default: 1e-8)
        weight_decay: Weight decay coefficient (default: 1e-4)
        nesterov: Whether to use Nesterov momentum (default: False)

    Returns:
        GradientTransformation
    """
```
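
A sketch of `adamw` with a schedule as the learning rate. Because the decoupled weight-decay step needs the current parameter values, `params` is passed as the third argument to `update`; the schedule, parameters, and loss here are illustrative.

```python
import jax
import jax.numpy as jnp
import optax

params = {'w': jnp.ones((3, 3)), 'b': jnp.zeros(3)}
loss_fn = lambda p: jnp.sum(p['w'] ** 2) + jnp.sum(p['b'] ** 2)

# The learning rate can be a schedule instead of a constant.
schedule = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=1000)
opt = optax.adamw(learning_rate=schedule, weight_decay=1e-4)
opt_state = opt.init(params)

grads = jax.grad(loss_fn)(params)
updates, opt_state = opt.update(grads, opt_state, params)  # params needed for weight decay
params = optax.apply_updates(params, updates)
```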

### Stochastic Gradient Descent

Classic SGD optimizer with optional momentum and Nesterov acceleration.

```python { .api }
def sgd(learning_rate, momentum=None, nesterov=False):
    """
    Stochastic gradient descent optimizer.

    Args:
        learning_rate: Learning rate or schedule
        momentum: Momentum coefficient (default: None for no momentum)
        nesterov: Whether to use Nesterov momentum (default: False)

    Returns:
        GradientTransformation
    """
```
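
The three common configurations side by side; the learning rates and toy loss are illustrative.

```python
import jax
import jax.numpy as jnp
import optax

params = jnp.array([0.5, -1.5])
loss_fn = lambda p: jnp.sum((p - 1.0) ** 2)

plain = optax.sgd(learning_rate=0.1)                                     # vanilla SGD
heavy_ball = optax.sgd(learning_rate=0.1, momentum=0.9)                  # classical momentum
accelerated = optax.sgd(learning_rate=0.1, momentum=0.9, nesterov=True)  # Nesterov momentum

state = accelerated.init(params)
grads = jax.grad(loss_fn)(params)
updates, state = accelerated.update(grads, state)
params = optax.apply_updates(params, updates)
```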

### RMSprop Optimizer

RMSprop optimizer with adaptive learning rates based on recent gradient magnitudes.

```python { .api }
def rmsprop(learning_rate, decay=0.9, eps=1e-8):
    """
    RMSprop optimizer.

    Args:
        learning_rate: Learning rate or schedule
        decay: Decay rate for moving average of squared gradients (default: 0.9)
        eps: Small constant for numerical stability (default: 1e-8)

    Returns:
        GradientTransformation
    """
```
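
A short loop showing the update cycle; the moving average of squared gradients rescales each parameter's step. The loss and step count are illustrative.

```python
import jax
import jax.numpy as jnp
import optax

params = jnp.array([2.0, -3.0])
loss_fn = lambda p: jnp.sum(p ** 2)

opt = optax.rmsprop(learning_rate=0.01, decay=0.9, eps=1e-8)
state = opt.init(params)

for _ in range(3):
    grads = jax.grad(loss_fn)(params)
    updates, state = opt.update(grads, state)
    params = optax.apply_updates(params, updates)
```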

### Adagrad Optimizer

Adagrad optimizer with adaptive learning rates that decrease over time.

```python { .api }
def adagrad(learning_rate, initial_accumulator_value=0.1, eps=1e-7):
    """
    Adagrad optimizer.

    Args:
        learning_rate: Learning rate or schedule
        initial_accumulator_value: Initial value for accumulator (default: 0.1)
        eps: Small constant for numerical stability (default: 1e-7)

    Returns:
        GradientTransformation
    """
```
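
A brief sketch; because squared gradients only accumulate, the effective step size shrinks as training proceeds. The toy setup is illustrative.

```python
import jax
import jax.numpy as jnp
import optax

params = jnp.array([1.0, 1.0])
loss_fn = lambda p: jnp.sum(p ** 2)

opt = optax.adagrad(learning_rate=0.1, initial_accumulator_value=0.1)
state = opt.init(params)

for _ in range(5):
    grads = jax.grad(loss_fn)(params)
    updates, state = opt.update(grads, state)  # later updates shrink as the accumulator grows
    params = optax.apply_updates(params, updates)
```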

### Adadelta Optimizer

Adadelta optimizer that adapts learning rates based on a moving window of gradient updates.

```python { .api }
def adadelta(learning_rate=1.0, rho=0.9, eps=1e-6):
    """
    Adadelta optimizer.

    Args:
        learning_rate: Learning rate (default: 1.0)
        rho: Decay rate for moving averages (default: 0.9)
        eps: Small constant for numerical stability (default: 1e-6)

    Returns:
        GradientTransformation
    """
```
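
A minimal sketch; Adadelta is commonly left at the default `learning_rate=1.0`, since the update magnitude is driven by the ratio of its two running averages. The toy loss is illustrative.

```python
import jax
import jax.numpy as jnp
import optax

params = jnp.array([4.0, -4.0])
loss_fn = lambda p: jnp.sum(p ** 2)

opt = optax.adadelta(learning_rate=1.0, rho=0.9, eps=1e-6)
state = opt.init(params)

grads = jax.grad(loss_fn)(params)
updates, state = opt.update(grads, state)
params = optax.apply_updates(params, updates)
```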

### Adamax Optimizer

Adamax optimizer, a variant of Adam based on the infinity norm.

```python { .api }
def adamax(learning_rate, b1=0.9, b2=0.999, eps=1e-8):
    """
    Adamax optimizer.

    Args:
        learning_rate: Learning rate or schedule
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for exponentially weighted infinity norm (default: 0.999)
        eps: Small constant for numerical stability (default: 1e-8)

    Returns:
        GradientTransformation
    """
```
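
Construction and a single step, following the same cycle as the optimizers above; hyperparameters and loss are illustrative.

```python
import jax
import jax.numpy as jnp
import optax

params = jnp.array([1.0, -1.0, 2.0])
loss_fn = lambda p: jnp.sum((p - 1.0) ** 2)

opt = optax.adamax(learning_rate=2e-3)  # scales steps by an exponentially weighted infinity norm
state = opt.init(params)

grads = jax.grad(loss_fn)(params)
updates, state = opt.update(grads, state)
params = optax.apply_updates(params, updates)
```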

### Nadam Optimizer

Nesterov-accelerated Adam optimizer, which applies Nesterov-style momentum to Adam's first-moment update.

```python { .api }
def nadam(learning_rate, b1=0.9, b2=0.999, eps=1e-8):
    """
    Nadam optimizer (Nesterov-accelerated Adam).

    Args:
        learning_rate: Learning rate or schedule
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for second moment estimates (default: 0.999)
        eps: Small constant for numerical stability (default: 1e-8)

    Returns:
        GradientTransformation
    """
```
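
A minimal sketch; conceptually this is close to calling `adam` with `nesterov=True`. The hyperparameters and loss are illustrative.

```python
import jax
import jax.numpy as jnp
import optax

params = jnp.array([0.0, 1.0])
loss_fn = lambda p: jnp.sum((p - 2.0) ** 2)

opt = optax.nadam(learning_rate=1e-3)
state = opt.init(params)

grads = jax.grad(loss_fn)(params)
updates, state = opt.update(grads, state)
params = optax.apply_updates(params, updates)
```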

### AdaBelief Optimizer

AdaBelief optimizer that adapts the step size according to the "belief" in the observed gradients.

```python { .api }
def adabelief(learning_rate, b1=0.9, b2=0.999, eps=1e-16, eps_root=1e-16, *, nesterov=False):
    """
    AdaBelief optimizer.

    Args:
        learning_rate: Learning rate or schedule
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for second moment estimates (default: 0.999)
        eps: Small constant for numerical stability (default: 1e-16)
        eps_root: Small constant for numerical stability in denominator (default: 1e-16)
        nesterov: Whether to use Nesterov momentum (default: False)

    Returns:
        GradientTransformation
    """
```
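
A minimal sketch; note the much smaller default `eps` than Adam's, which matters because the second-moment term here tracks the spread of gradients around their running mean rather than their raw magnitude. The toy setup is illustrative.

```python
import jax
import jax.numpy as jnp
import optax

params = {'w': jnp.ones((2, 2))}
loss_fn = lambda p: jnp.sum(p['w'] ** 2)

opt = optax.adabelief(learning_rate=1e-3)
state = opt.init(params)

grads = jax.grad(loss_fn)(params)
updates, state = opt.update(grads, state)
params = optax.apply_updates(params, updates)
```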

## Usage Example

```python
import optax
import jax.numpy as jnp

# Initialize parameters
params = {'weights': jnp.ones((10, 5)), 'bias': jnp.zeros((5,))}

# Create different optimizers
adam_opt = optax.adam(learning_rate=0.001)
sgd_opt = optax.sgd(learning_rate=0.01, momentum=0.9)
adamw_opt = optax.adamw(learning_rate=0.001, weight_decay=1e-4)

# Initialize optimizer state
adam_state = adam_opt.init(params)
sgd_state = sgd_opt.init(params)
adamw_state = adamw_opt.init(params)

# In training loop (example with Adam)
def training_step(params, opt_state, gradients):
    updates, new_opt_state = adam_opt.update(gradients, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state
```
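
For completeness, one way to produce the `gradients` consumed by `training_step` above, using `jax.value_and_grad` and `jax.jit`; the data and mean-squared-error loss are illustrative, not part of the Optax API.

```python
import jax
import jax.numpy as jnp
import optax

# Illustrative data; any differentiable loss works the same way.
x = jnp.ones((32, 10))
y = jnp.zeros((32, 5))

def loss_fn(params, x, y):
    preds = x @ params['weights'] + params['bias']
    return jnp.mean((preds - y) ** 2)

params = {'weights': jnp.ones((10, 5)), 'bias': jnp.zeros((5,))}
adam_opt = optax.adam(learning_rate=0.001)
opt_state = adam_opt.init(params)

@jax.jit
def training_step(params, opt_state, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = adam_opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

for step in range(100):
    params, opt_state, loss = training_step(params, opt_state, x, y)
```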