or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

advanced-optimizers.md assignment.md contrib.md index.md losses.md monte-carlo.md optimizers.md perturbations.md projections.md schedules.md second-order.md transformations.md tree-utilities.md utilities.md

docs/transformations.md

0

# Gradient Transformations

1

2

Building blocks for creating custom optimizers including scaling, clipping, noise addition, and momentum accumulation. These transformations can be combined using `chain()` to build custom optimization strategies with fine-grained control over gradient processing.

3

4

## Capabilities

5

6

### Chaining Transformations

7

8

Combine multiple gradient transformations into a single optimizer.

9

10

```python { .api }

11

def chain(*args):

12

"""

13

Chain multiple gradient transformations.

14

15

Args:

16

*args: Variable number of GradientTransformation objects

17

18

Returns:

19

GradientTransformationExtraArgs: Combined transformation

20

"""

21

22

def named_chain(**transformations):

23

"""

24

Chain transformations with names for easier debugging.

25

26

Args:

27

**transformations: Named GradientTransformation objects

28

29

Returns:

30

GradientTransformation: Combined transformation with named states

31

"""

32

```

33

34

### Scaling Transformations

35

36

#### Basic Scaling

37

38

```python { .api }

39

def scale(step_size):

40

"""

41

Scale updates by a constant factor.

42

43

Args:

44

step_size: Scaling factor (typically negative learning rate)

45

46

Returns:

47

GradientTransformation

48

"""

49

50

def scale_by_learning_rate(learning_rate):

51

"""

52

Scale updates by learning rate (with negative sign).

53

54

Args:

55

learning_rate: Learning rate value or schedule

56

57

Returns:

58

GradientTransformation

59

"""

60

61

def scale_by_schedule(schedule):

62

"""

63

Scale updates by a schedule function.

64

65

Args:

66

schedule: Schedule function taking step count and returning scale factor

67

68

Returns:

69

GradientTransformation

70

"""

71

```

72

73

#### Adaptive Scaling

74

75

```python { .api }

76

def scale_by_adam(b1=0.9, b2=0.999, eps=1e-8, *, nesterov=False):

77

"""

78

Scale updates using Adam-style adaptive scaling.

79

80

Args:

81

b1: Exponential decay rate for first moment estimates (default: 0.9)

82

b2: Exponential decay rate for second moment estimates (default: 0.999)

83

eps: Small constant for numerical stability (default: 1e-8)

84

nesterov: Whether to use Nesterov momentum (default: False)

85

86

Returns:

87

GradientTransformation

88

"""

89

90

def scale_by_rms(decay=0.9, eps=1e-8):

91

"""

92

Scale updates by root mean square of gradients.

93

94

Args:

95

decay: Decay rate for moving average (default: 0.9)

96

eps: Small constant for numerical stability (default: 1e-8)

97

98

Returns:

99

GradientTransformation

100

"""

101

102

def scale_by_stddev(decay=0.9, eps=1e-8):

103

"""

104

Scale updates by standard deviation of gradients.

105

106

Args:

107

decay: Decay rate for moving average (default: 0.9)

108

eps: Small constant for numerical stability (default: 1e-8)

109

110

Returns:

111

GradientTransformation

112

"""

113

```

114

115

### Momentum and Accumulation

116

117

```python { .api }

118

def trace(decay, nesterov=False, accumulator_dtype=None):

119

"""

120

Add momentum/trace to gradient updates.

121

122

Args:

123

decay: Decay rate for momentum (required; no default)

124

nesterov: Whether to use Nesterov momentum (default: False)

125

accumulator_dtype: Data type for accumulator (default: None)

126

127

Returns:

128

GradientTransformation

129

"""

130

131

def ema(decay, debias=True, accumulator_dtype=None):

132

"""

133

Exponential moving average of parameters.

134

135

Args:

136

decay: Decay rate for moving average (required; no default)

137

debias: Whether to debias the moving average (default: True)

138

accumulator_dtype: Data type for accumulator (default: None)

139

140

Returns:

141

GradientTransformation

142

"""

143

```

144

145

### Gradient Clipping

146

147

```python { .api }

148

def clip(max_delta):

149

"""

150

Clip updates element-wise to maximum absolute value.

151

152

Args:

153

max_delta: Maximum absolute value for updates

154

155

Returns:

156

GradientTransformation

157

"""

158

159

def clip_by_global_norm(max_norm):

160

"""

161

Clip updates by global norm.

162

163

Args:

164

max_norm: Maximum global norm for updates

165

166

Returns:

167

GradientTransformation

168

"""

169

170

def clip_by_block_rms(threshold):

171

"""

172

Clip updates by block-wise RMS.

173

174

Args:

175

threshold: RMS threshold for clipping

176

177

Returns:

178

GradientTransformation

179

"""

180

181

def adaptive_grad_clip(clipping, eps=1e-3):

182

"""

183

Adaptive gradient clipping.

184

185

Args:

186

clipping: Clipping threshold

187

eps: Small constant for numerical stability (default: 1e-3)

188

189

Returns:

190

GradientTransformation

191

"""

192

193

def per_example_global_norm_clip(l2_norm_clip, single_batch_element=False):

194

"""

195

Per-example gradient clipping for differential privacy.

196

197

Args:

198

l2_norm_clip: L2 norm clipping threshold

199

single_batch_element: Whether input is a single batch element (default: False)

200

201

Returns:

202

GradientTransformation

203

"""

204

```

205

206

### Regularization

207

208

```python { .api }

209

def add_decayed_weights(weight_decay, mask=None):

210

"""

211

Add L2 weight decay (weight regularization).

212

213

Args:

214

weight_decay: Weight decay coefficient

215

mask: Optional mask for parameter selection

216

217

Returns:

218

GradientTransformation

219

"""

220

221

def add_noise(eta, gamma, seed):

222

"""

223

Add gradient noise for improved generalization.

224

225

Args:

226

eta: Noise scaling parameter

227

gamma: Annealing rate for noise

228

seed: Random seed

229

230

Returns:

231

GradientTransformation

232

"""

233

```

234

235

### Conditioning and Normalization

236

237

```python { .api }

238

def centralize():

239

"""

240

Centralize gradients by subtracting their mean.

241

242

Returns:

243

GradientTransformation

244

"""

245

246

def normalize_by_update_norm():

247

"""

248

Normalize updates by their norm.

249

250

Returns:

251

GradientTransformation

252

"""

253

254

def scale_by_trust_ratio():

255

"""

256

Scale updates by trust ratio (parameter norm / update norm).

257

258

Returns:

259

GradientTransformation

260

"""

261

```

262

263

### Conditional Operations

264

265

```python { .api }

266

def apply_if_finite(transformation):

267

"""

268

Apply transformation only if gradients are finite.

269

270

Args:

271

transformation: Transformation to apply conditionally

272

273

Returns:

274

GradientTransformation

275

"""

276

277

def apply_every(k, transformation):

278

"""

279

Apply transformation every k steps.

280

281

Args:

282

k: Step interval

283

transformation: Transformation to apply periodically

284

285

Returns:

286

GradientTransformation

287

"""

288

289

def conditionally_transform(condition_fn, transformation):

290

"""

291

Apply transformation based on condition function.

292

293

Args:

294

condition_fn: Function that returns boolean condition

295

transformation: Transformation to apply conditionally

296

297

Returns:

298

GradientTransformation

299

"""

300

```

301

302

### Parameter Partitioning

303

304

```python { .api }

305

def partition(selector_fn, *transformations):

306

"""

307

Apply different transformations to different parameter subsets.

308

309

Args:

310

selector_fn: Function to select parameter subsets

311

*transformations: Transformations for each subset

312

313

Returns:

314

GradientTransformation

315

"""

316

317

def masked(mask_fn, transformation):

318

"""

319

Apply transformation with parameter masking.

320

321

Args:

322

mask_fn: Function to generate parameter mask

323

transformation: Transformation to apply with mask

324

325

Returns:

326

GradientTransformation

327

"""

328

```

329

330

### Parameter Constraints

331

332

```python { .api }

333

def keep_params_nonnegative():

334

"""

335

Keep parameters non-negative by projecting to positive orthant.

336

337

Returns:

338

GradientTransformation

339

"""

340

341

def zero_nans():

342

"""

343

Set NaN gradients to zero.

344

345

Returns:

346

GradientTransformation

347

"""

348

```

349

350

### Multi-Step Accumulation

351

352

```python { .api }

353

class MultiSteps:

354

"""Multi-step gradient accumulation."""

355

356

def __init__(self, every_k_schedule, use_grad_mean=True):

357

"""

358

Initialize multi-step accumulation.

359

360

Args:

361

every_k_schedule: Schedule for accumulation steps

362

use_grad_mean: Whether to use gradient mean instead of sum (default: True)

363

"""

364

365

def skip_not_finite(updates, state, params=None):

366

"""

367

Skip updates that are not finite.

368

369

Args:

370

updates: Gradient updates

371

state: Optimizer state

372

params: Optional parameters

373

374

Returns:

375

Tuple of (updates, state)

376

"""

377

378

def skip_large_updates(updates, state, max_norm):

379

"""

380

Skip updates with norm larger than threshold.

381

382

Args:

383

updates: Gradient updates

384

state: Optimizer state

385

max_norm: Maximum allowed update norm

386

387

Returns:

388

Tuple of (updates, state)

389

"""

390

```

391

392

## Usage Examples

393

394

### Custom Optimizer with Chaining

395

396

```python

397

import optax
import jax.numpy as jnp

398

399

# Create custom optimizer by chaining transformations

400

custom_optimizer = optax.chain(

401

optax.clip_by_global_norm(1.0), # Gradient clipping

402

optax.add_decayed_weights(weight_decay=1e-4), # Weight decay

403

optax.scale_by_adam(b1=0.9, b2=0.999), # Adam scaling

404

optax.scale(-0.001) # Learning rate

405

)

406

407

# Initialize with parameters

408

params = {'w': jnp.ones((10, 5)), 'b': jnp.zeros((5,))}

409

opt_state = custom_optimizer.init(params)

410

```

411

412

### Conditional and Partitioned Updates

413

414

```python

415

# Apply different learning rates to different parameter groups

416

def is_bias(path, param):

417

return 'bias' in path

418

419

bias_tx = optax.scale(-0.01) # Higher learning rate for biases

420

weight_tx = optax.scale(-0.001) # Lower learning rate for weights

421

422

partitioned_optimizer = optax.partition(is_bias, bias_tx, weight_tx)

423

424

# Apply transformation only every 5 steps

425

sparse_optimizer = optax.apply_every(5, optax.adam(0.001))

426

```

427

428

### Robust Training Setup

429

430

```python

431

# Robust optimizer with multiple safeguards

432

robust_optimizer = optax.chain(

433

optax.clip_by_global_norm(1.0), # Prevent exploding gradients

434

optax.apply_if_finite( # Skip non-finite updates

435

optax.chain(

436

optax.centralize(), # Center gradients

437

optax.scale_by_adam(), # Adaptive scaling

438

optax.add_decayed_weights(1e-4), # Weight regularization

439

)

440

),

441

optax.scale_by_schedule( # Learning rate schedule

442

optax.cosine_decay_schedule(0.001, 1000)

443

)

444

)

445

```