or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

advanced-optimizers.md · assignment.md · contrib.md · index.md · losses.md · monte-carlo.md · optimizers.md · perturbations.md · projections.md · schedules.md · second-order.md · transformations.md · tree-utilities.md · utilities.md

utilities.mddocs/

0

# Utilities and Tree Operations

1

2

Utility functions for parameter updates, tree operations, numerical stability, and working with JAX pytrees. These functions provide essential infrastructure for building and using optimizers effectively.

3

4

## Capabilities

5

6

### Parameter Updates

7

8

#### Core Update Functions

9

10

```python { .api }

11

def apply_updates(params, updates):

12

"""

13

Apply parameter updates to current parameters.

14

15

Args:

16

params: Current parameters (pytree)

17

updates: Parameter updates (pytree with same structure as params)

18

19

Returns:

20

Updated parameters (pytree)

21

"""

22

23

def incremental_update(new_tensors, old_tensors, step_size):

24

"""

25

Incrementally move old tensors toward new ones (Polyak averaging): step_size * new_tensors + (1 - step_size) * old_tensors.

26

27

Args:

28

new_tensors: New tensor values

29

old_tensors: Old tensor values

30

step_size: Step size for interpolation

31

32

Returns:

33

Incrementally updated tensors

34

"""

35

36

def periodic_update(new_tensors, old_tensors, steps, update_period):

37

"""

38

Update tensors periodically based on step count.

39

40

Args:

41

new_tensors: New tensor values

42

old_tensors: Old tensor values

43

steps: Current step count

44

update_period: Period for updates

45

46

Returns:

47

new_tensors when steps is a multiple of update_period, otherwise old_tensors

48

"""

49

```

50

51

### Numerical Utilities

52

53

#### Safe Operations

54

55

```python { .api }

56

def safe_norm(x, min_norm=0.0, ord=None):

57

"""

58

Numerically stable norm computation.

59

60

Args:

61

x: Input tensor

62

min_norm: Minimum norm value for stability (default: 0.0)

63

ord: Norm order (None, 1, 2, 'fro', etc.) (default: None for L2)

64

65

Returns:

66

Norm value with numerical stability

67

"""

68

69

def safe_root_mean_squares(x, min_rms=0.0):

70

"""

71

Numerically stable root mean square computation.

72

73

Args:

74

x: Input tensor

75

min_rms: Minimum RMS value for stability (default: 0.0)

76

77

Returns:

78

RMS value with numerical stability

79

"""

80

81

def safe_increment(count):

82

"""

83

Safely increment counter with overflow protection.

84

85

Args:

86

count: Current counter value

87

88

Returns:

89

Incremented counter value

90

"""

91

92

def safe_int32_increment(count):

93

"""

94

Safely increment int32 counter with overflow protection.

95

96

Args:

97

count: Current int32 counter value

98

99

Returns:

100

Incremented int32 counter value

101

"""

102

```

103

104

### Linear Algebra

105

106

#### Matrix Operations

107

108

```python { .api }

109

def global_norm(updates):

110

"""

111

Compute global norm across all parameters in pytree.

112

113

Args:

114

updates: Parameter updates (pytree)

115

116

Returns:

117

Global norm scalar value

118

"""

119

120

def power_iteration(matrix, num_iters=10, error_tolerance=1e-6, precision=None):

121

"""

122

Compute dominant eigenvalue and eigenvector using power iteration.

123

124

Args:

125

matrix: Input matrix

126

num_iters: Maximum number of iterations (default: 10)

127

error_tolerance: Convergence tolerance (default: 1e-6)

128

precision: Numerical precision (default: None)

129

130

Returns:

131

Tuple of (eigenvalue, eigenvector)

132

"""

133

134

def matrix_inverse_pth_root(matrix, p, num_iters=15, ridge_epsilon=1e-6, error_tolerance=1e-6, precision=None):

135

"""

136

Compute matrix inverse p-th root using Newton's method.

137

138

Args:

139

matrix: Input positive definite matrix

140

p: Root order (e.g., 2 for the inverse square root)

141

num_iters: Maximum iterations (default: 15)

142

ridge_epsilon: Ridge regularization (default: 1e-6)

143

error_tolerance: Convergence tolerance (default: 1e-6)

144

precision: Numerical precision (default: None)

145

146

Returns:

147

Matrix inverse p-th root

148

"""

149

150

def nnls(a, b, max_iters=None, tol=1e-8):

151

"""

152

Non-negative least squares solver.

153

154

Args:

155

a: Coefficient matrix

156

b: Target vector

157

max_iters: Maximum iterations (default: None for auto)

158

tol: Convergence tolerance (default: 1e-8)

159

160

Returns:

161

Non-negative solution vector

162

"""

163

```

164

165

### Core Types and Base Functions

166

167

#### Base Transformations

168

169

```python { .api }

170

def identity():

171

"""

172

Identity transformation that passes gradients unchanged.

173

174

Returns:

175

GradientTransformation

176

"""

177

178

def set_to_zero():

179

"""

180

Transformation that sets all gradients to zero.

181

182

Returns:

183

GradientTransformation

184

"""

185

186

def stateless(f):

187

"""

188

Create stateless transformation from function.

189

190

Args:

191

f: Update function taking (updates, params) and returning updates; params may be None

192

193

Returns:

194

GradientTransformation

195

"""

196

197

def stateless_with_tree_map(f):

198

"""

199

Create stateless transformation with tree mapping.

200

201

Args:

202

f: Function to apply to each leaf of parameter tree

203

204

Returns:

205

GradientTransformation

206

"""

207

208

def with_extra_args_support(transformation):

209

"""

210

Add support for extra arguments to transformation.

211

212

Args:

213

transformation: Base transformation to extend

214

215

Returns:

216

GradientTransformationExtraArgs

217

"""

218

```

219

220

### Utility Functions

221

222

#### Gradient Processing

223

224

```python { .api }

225

def scale_gradient(inputs, scale):

226

"""

227

Scale gradients in the backward pass; the forward pass is the identity.

228

229

Args:

230

inputs: Input values (forward pass is identity)

231

scale: Scale factor for gradients in backward pass

232

233

Returns:

234

Inputs (unchanged in forward pass)

235

"""

236

237

def value_and_grad_from_state(fun, argnums=0, has_aux=False):

238

"""

239

Compute value and gradient while maintaining state.

240

241

Args:

242

fun: Function to differentiate

243

argnums: Argument indices to differentiate (default: 0)

244

has_aux: Whether function returns auxiliary data (default: False)

245

246

Returns:

247

Function that returns (value, grad) tuple

248

"""

249

```

250

251

#### Random Utilities

252

253

```python { .api }

254

def multi_normal(loc, scale_tril, random_key):

255

"""

256

Sample from multivariate normal distribution.

257

258

Args:

259

loc: Mean vector

260

scale_tril: Lower triangular scale matrix

261

random_key: JAX random key

262

263

Returns:

264

Random sample from multivariate normal

265

"""

266

```

267

268

### Tree Operations

269

270

#### Basic Tree Arithmetic

271

272

```python { .api }

273

# Tree-level operations in optax.tree module

274

def add(tree1, tree2):

275

"""Element-wise addition of two pytrees."""

276

277

def sub(tree1, tree2):

278

"""Element-wise subtraction of two pytrees."""

279

280

def mul(tree1, tree2):

281

"""Element-wise multiplication of two pytrees."""

282

283

def div(tree1, tree2):

284

"""Element-wise division of two pytrees."""

285

286

def scale(tree, scalar):

287

"""Scale all elements in pytree by scalar."""

288

289

def norm(tree, ord=2):

290

"""Compute norm of pytree."""

291

292

def sum(tree):

293

"""Sum all elements in pytree."""

294

295

def max(tree):

296

"""Find maximum element in pytree."""

297

```

298

299

#### Tree Utilities

300

301

```python { .api }

302

def zeros_like(tree):

303

"""Create pytree of zeros with same structure."""

304

305

def ones_like(tree):

306

"""Create pytree of ones with same structure."""

307

308

def full_like(tree, fill_value):

309

"""Create pytree filled with specified value."""

310

```

311

312

### Assignment Module

313

314

#### Hungarian Algorithm

315

316

```python { .api }

317

def hungarian_algorithm(cost_matrix):

318

"""

319

Hungarian algorithm for solving assignment problems.

320

321

Args:

322

cost_matrix: 2D cost matrix for assignments

323

324

Returns:

325

Optimal assignment indices

326

"""

327

```

328

329

### Tree Utils Module

330

331

#### Parameter Tree Manipulation

332

333

```python { .api }

334

def tree_map_params(fn, tree):

335

"""

336

Map function over parameters in pytree.

337

338

Args:

339

fn: Function to apply to each parameter

340

tree: Parameter pytree

341

342

Returns:

343

Transformed pytree

344

"""

345

346

def tree_bias_correction(moment, decay, count):

347

"""

348

Apply bias correction to moment estimates.

349

350

Args:

351

moment: Moment estimate

352

decay: Decay rate used for moment

353

count: Step count for bias correction

354

355

Returns:

356

Bias-corrected moment

357

"""

358

```

359

360

#### Moment Updates

361

362

```python { .api }

363

def tree_update_moment(updates, moments, decay, order):

364

"""

365

Update moment estimates for optimizer state.

366

367

Args:

368

updates: Current gradient updates

369

moments: Previous moment estimates

370

decay: Exponential decay rate

371

order: Moment order (1 for mean, 2 for variance)

372

373

Returns:

374

Updated moment estimates

375

"""

376

377

def tree_update_moment_per_elem_norm(updates, moments, decay, order):

378

"""

379

Update moments with per-element normalization.

380

381

Args:

382

updates: Current gradient updates

383

moments: Previous moment estimates

384

decay: Exponential decay rate

385

order: Moment order

386

387

Returns:

388

Updated moment estimates with per-element normalization

389

"""

390

391

def tree_update_infinity_moment(updates, moments, decay):

392

"""

393

Update infinity moments (max absolute values).

394

395

Args:

396

updates: Current gradient updates

397

moments: Previous infinity moments

398

decay: Exponential decay rate

399

400

Returns:

401

Updated infinity moments

402

"""

403

```

404

405

### Type Definitions

406

407

```python { .api }

408

# Type aliases

409

OptState = chex.ArrayTree # Optimizer state

410

Params = chex.ArrayTree # Model parameters

411

Updates = Params # Gradient updates

412

Schedule = Callable[[chex.Numeric], chex.Numeric] # Schedule function

413

ScalarOrSchedule = Union[float, jax.Array, Schedule] # Flexible numeric type

414

MaskOrFn = Union[Any, Callable[[Params], Any]] # Mask or masking function

415

416

# Function type definitions

417

TransformInitFn = Callable[[Params], OptState]

418

TransformUpdateFn = Callable[[Updates, OptState, Optional[Params]], Tuple[Updates, OptState]]

419

TransformUpdateExtraArgsFn = Callable[[Updates, OptState, Optional[Params], ...], Tuple[Updates, OptState]]

420

421

# Core classes

422

class GradientTransformation(NamedTuple):

423

"""Core gradient transformation with init and update functions."""

424

init: TransformInitFn

425

update: TransformUpdateFn

426

427

class GradientTransformationExtraArgs(NamedTuple):

428

"""Extended transformation supporting extra arguments."""

429

init: TransformInitFn

430

update: TransformUpdateExtraArgsFn

431

432

class EmptyState(NamedTuple):

433

"""Empty state for stateless transformations."""

434

pass

435

436

class FactoredState(NamedTuple):

437

"""State for factorized operations."""

438

count: chex.Array

439

v_row: chex.ArrayTree

440

v_col: chex.ArrayTree

441

```

442

443

## Usage Examples

444

445

### Basic Parameter Updates

446

447

```python

448

import jax.numpy as jnp

449

import optax

450

451

# Parameters and updates

452

params = {'w': jnp.ones((5, 3)), 'b': jnp.zeros((3,))}

453

updates = {'w': jnp.ones((5, 3)) * 0.01, 'b': jnp.ones((3,)) * 0.001}

454

455

# Apply updates

456

new_params = optax.apply_updates(params, updates)

457

458

# Compute global norm

459

grad_norm = optax.global_norm(updates)

460

print(f"Global gradient norm: {grad_norm}")

461

```

462

463

### Numerical Stability

464

465

```python

466

# Safe operations for numerical stability

467

x = jnp.array([1e-8, 1e-6, 1.0, 1e6])

468

469

safe_norm_val = optax.safe_norm(x, min_norm=1e-8)

470

safe_rms_val = optax.safe_root_mean_squares(x, min_rms=1e-8)

471

472

# Safe counting

473

step_count = jnp.array(2147483647, dtype=jnp.int32) # Exactly int32 max; a plain +1 would overflow

474

next_count = optax.safe_int32_increment(step_count)

475

```

476

477

### Tree Operations

478

479

```python

480

# Tree arithmetic

481

tree1 = {'a': jnp.array([1, 2, 3]), 'b': jnp.array([4, 5])}

482

tree2 = {'a': jnp.array([6, 7, 8]), 'b': jnp.array([9, 10])}

483

484

# Element-wise operations

485

sum_tree = optax.tree.add(tree1, tree2)

486

scaled_tree = optax.tree.scale(tree1, 0.5)

487

tree_norm = optax.tree.norm(tree1)

488

489

# Tree utilities

490

zero_tree = optax.tree.zeros_like(tree1)

491

ones_tree = optax.tree.ones_like(tree1)

492

```

493

494

### Custom Transformations

495

496

```python

497

# Create custom stateless transformation

498

def my_scaling_fn(updates, params=None):

499

return jax.tree.map(lambda x: 0.01 * x, updates)

500

501

my_transform = optax.stateless(my_scaling_fn)

502

503

# Use with other transformations

504

optimizer = optax.chain(

505

optax.clip_by_global_norm(1.0),

506

my_transform,

507

optax.scale_by_adam()

508

)

509

```

510

511

### Advanced Usage

512

513

```python

514

# Matrix operations for second-order methods

515

def compute_preconditioner(gradients):

516

# Flatten gradients for matrix operations

517

flat_grads = jax.flatten_util.ravel_pytree(gradients)[0]

518

519

# Compute outer product approximation

520

outer_prod = jnp.outer(flat_grads, flat_grads)

521

522

# Compute matrix inverse square root

523

inv_sqrt = optax.matrix_inverse_pth_root(

524

outer_prod + 1e-6 * jnp.eye(len(flat_grads)),

525

p=2,

526

num_iters=10

527

)

528

529

return inv_sqrt

530

531

# Gradient scaling with state

532

def scale_with_state(inputs, state):

533

scale_factor = jnp.sqrt(state['step_count'])

534

return optax.scale_gradient(inputs, scale_factor)

535

```