# Neural Network Functions

JAX provides a comprehensive set of neural network functions through `jax.nn`, including activation functions, normalization utilities, and attention mechanisms commonly used in machine learning and deep learning applications.

## Core Imports

```python
import jax.nn as jnn
from jax.nn import relu, sigmoid, softmax, gelu
```

## Capabilities

### ReLU and Variants

Rectified Linear Unit activations and their variants for introducing non-linearity while maintaining computational efficiency.

```python { .api }
def relu(x) -> Array:
    """
    Rectified Linear Unit activation: max(0, x).

    Args:
        x: Input array

    Returns:
        Array with ReLU applied element-wise
    """

def relu6(x) -> Array:
    """
    ReLU capped at 6: min(max(0, x), 6).

    Args:
        x: Input array

    Returns:
        Array with ReLU6 applied element-wise
    """

def leaky_relu(x, negative_slope=0.01) -> Array:
    """
    Leaky ReLU: max(negative_slope * x, x).

    Args:
        x: Input array
        negative_slope: Slope for negative values (default: 0.01)

    Returns:
        Array with Leaky ReLU applied element-wise
    """

def elu(x, alpha=1.0) -> Array:
    """
    Exponential Linear Unit: x if x > 0 else alpha * (exp(x) - 1).

    Args:
        x: Input array
        alpha: Scale for negative values (default: 1.0)

    Returns:
        Array with ELU applied element-wise
    """

def selu(x) -> Array:
    """
    Scaled Exponential Linear Unit with fixed alpha and scale.

    Args:
        x: Input array

    Returns:
        Array with SELU applied element-wise
    """

def celu(x, alpha=1.0) -> Array:
    """
    Continuously Differentiable Exponential Linear Unit.

    Args:
        x: Input array
        alpha: Scale parameter (default: 1.0)

    Returns:
        Array with CELU applied element-wise
    """
```
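
To make the differences concrete, here is a minimal sketch applying a few of these variants to the same input (approximate outputs shown as comments):

```python
import jax.numpy as jnp
import jax.nn as jnn

x = jnp.array([-2.0, -0.5, 0.0, 1.0, 8.0])

print(jnn.relu(x))        # [0.  0.  0.  1.  8.]
print(jnn.relu6(x))       # caps large values: [0. 0. 0. 1. 6.]
print(jnn.leaky_relu(x))  # small negative slope: [-0.02 -0.005 0. 1. 8.]
print(jnn.elu(x))         # negative side saturates toward -alpha
```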

### Modern Activations

Contemporary activation functions that have shown improved performance in various architectures.

```python { .api }
def gelu(x, approximate=True) -> Array:
    """
    Gaussian Error Linear Unit: x * Φ(x) where Φ is CDF of standard normal.

    Args:
        x: Input array
        approximate: Whether to use tanh approximation (default: True)

    Returns:
        Array with GELU applied element-wise
    """

def silu(x) -> Array:
    """
    Sigmoid Linear Unit (Swish): x * sigmoid(x).

    Args:
        x: Input array

    Returns:
        Array with SiLU applied element-wise
    """

def swish(x) -> Array:
    """
    Swish activation (alias for SiLU): x * sigmoid(x).

    Args:
        x: Input array

    Returns:
        Array with Swish applied element-wise
    """

def mish(x) -> Array:
    """
    Mish activation: x * tanh(softplus(x)).

    Args:
        x: Input array

    Returns:
        Array with Mish applied element-wise
    """

def hard_silu(x) -> Array:
    """
    Hard SiLU (Hard Swish variant): x * hard_sigmoid(x).

    Args:
        x: Input array

    Returns:
        Array with Hard SiLU applied element-wise
    """

def hard_swish(x) -> Array:
    """
    Hard Swish: x * relu6(x + 3) / 6.

    Args:
        x: Input array

    Returns:
        Array with Hard Swish applied element-wise
    """

def squareplus(x, b=4.0) -> Array:
    """
    Squareplus activation: (x + sqrt(x^2 + b)) / 2.

    Args:
        x: Input array
        b: Shape parameter (default: 4.0)

    Returns:
        Array with Squareplus applied element-wise
    """
```
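
A short sketch contrasting the exact and tanh-approximate GELU, and confirming that `silu` and `swish` are aliases:

```python
import jax.numpy as jnp
import jax.nn as jnn

x = jnp.linspace(-3.0, 3.0, 7)

exact = jnn.gelu(x, approximate=False)   # erf-based definition
approx = jnn.gelu(x, approximate=True)   # tanh approximation (the default)
print(jnp.max(jnp.abs(exact - approx)))  # small but nonzero difference

assert bool(jnp.allclose(jnn.silu(x), jnn.swish(x)))  # same function
```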

### Sigmoid and Tanh Variants

Sigmoid-based activations and their approximations for bounded outputs.

```python { .api }
def sigmoid(x) -> Array:
    """
    Sigmoid activation: 1 / (1 + exp(-x)).

    Args:
        x: Input array

    Returns:
        Array with sigmoid applied element-wise
    """

def hard_sigmoid(x) -> Array:
    """
    Hard sigmoid approximation: relu6(x + 3) / 6, i.e. max(0, min(1, (x + 3) / 6)).

    Args:
        x: Input array

    Returns:
        Array with hard sigmoid applied element-wise
    """

def log_sigmoid(x) -> Array:
    """
    Log sigmoid: log(sigmoid(x)) computed in a numerically stable way.

    Args:
        x: Input array

    Returns:
        Array with log sigmoid applied element-wise
    """

def soft_sign(x) -> Array:
    """
    Soft sign activation: x / (1 + |x|).

    Args:
        x: Input array

    Returns:
        Array with soft sign applied element-wise
    """

def tanh(x) -> Array:
    """
    Hyperbolic tangent activation.

    Args:
        x: Input array

    Returns:
        Array with tanh applied element-wise
    """

def hard_tanh(x) -> Array:
    """
    Hard tanh activation: max(-1, min(1, x)).

    Args:
        x: Input array

    Returns:
        Array with hard tanh applied element-wise
    """
```
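
A quick sketch of the bounded activations and the numerically stable `log_sigmoid`:

```python
import jax.numpy as jnp
import jax.nn as jnn

x = jnp.array([-4.0, -1.0, 0.0, 1.0, 4.0])

print(jnn.sigmoid(x))       # smooth values in (0, 1)
print(jnn.hard_sigmoid(x))  # piecewise linear: exactly 0 below -3, exactly 1 above 3
print(jnn.hard_tanh(x))     # clipped to [-1, 1]

# log_sigmoid avoids the -inf that a naive log(sigmoid(x)) would give here
print(jnn.log_sigmoid(jnp.array(-1000.0)))  # ≈ -1000.0
```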

### Softmax and Normalization

Normalization functions for probability distributions and feature standardization.

```python { .api }
def softmax(x, axis=-1, where=None, initial=None) -> Array:
    """
    Softmax activation: exp(x_i) / sum(exp(x)) along axis.

    Args:
        x: Input array
        axis: Axis to apply softmax along (default: -1)
        where: Mask for conditional computation
        initial: Initial value for reduction

    Returns:
        Array with softmax applied along specified axis
    """

def log_softmax(x, axis=-1, where=None, initial=None) -> Array:
    """
    Log softmax: log(softmax(x)) computed in a numerically stable way.

    Args:
        x: Input array
        axis: Axis to apply log softmax along (default: -1)
        where: Mask for conditional computation
        initial: Initial value for reduction

    Returns:
        Array with log softmax applied along specified axis
    """

def softplus(x) -> Array:
    """
    Softplus activation: log(1 + exp(x)).

    Args:
        x: Input array

    Returns:
        Array with softplus applied element-wise
    """

def standardize(x, axis=None, mean=None, variance=None, epsilon=1e-5) -> Array:
    """
    Standardize array to zero mean and unit variance.

    Args:
        x: Input array to standardize
        axis: Axis to compute statistics along
        mean: Pre-computed mean (computed if None)
        variance: Pre-computed variance (computed if None)
        epsilon: Small value for numerical stability

    Returns:
        Standardized array
    """

def glu(x, axis=-1) -> Array:
    """
    Gated Linear Unit: split x in half along axis, return a * sigmoid(b).

    Args:
        x: Input array (size along axis must be even)
        axis: Axis to split along (default: -1)

    Returns:
        Array with GLU applied
    """
```
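
A minimal sketch of these normalization functions on concrete inputs:

```python
import jax.numpy as jnp
import jax.nn as jnn

logits = jnp.array([2.0, 1.0, 0.1])

probs = jnn.softmax(logits)
print(probs, probs.sum())         # probabilities summing to 1

# log_softmax matches log(softmax(x)) but stays stable for extreme logits
print(jnn.log_softmax(logits))

# glu halves the size of the given axis: a * sigmoid(b)
x = jnp.arange(8.0).reshape(2, 4)
print(jnn.glu(x, axis=-1).shape)  # (2, 2)
```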

### Specialized Functions

Utility functions for neural network operations and transformations.

```python { .api }
def one_hot(x, num_classes, dtype=None, axis=-1) -> Array:
    """
    One-hot encode an array of integers.

    Args:
        x: Integer array to encode
        num_classes: Number of classes
        dtype: Output data type
        axis: Axis to insert one-hot dimension

    Returns:
        One-hot encoded array
    """

def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False, where=None) -> Array:
    """
    Compute log(sum(exp(a))) in a numerically stable way.

    Args:
        a: Input array
        axis: Axis to sum along
        b: Scaling factor array
        keepdims: Whether to keep reduced dimensions
        return_sign: Whether to return sign separately
        where: Mask for conditional computation

    Returns:
        Log-sum-exp result
    """

def logmeanexp(a, axis=None, b=None, keepdims=False, where=None) -> Array:
    """
    Compute log(mean(exp(a))) in a numerically stable way.

    Args:
        a: Input array
        axis: Axis to average along
        b: Scaling factor array
        keepdims: Whether to keep reduced dimensions
        where: Mask for conditional computation

    Returns:
        Log-mean-exp result
    """

def log1mexp(x) -> Array:
    """
    Compute log(1 - exp(x)) in a numerically stable way.

    Args:
        x: Input array (should be <= 0)

    Returns:
        Array with log(1 - exp(x)) applied element-wise
    """

def sparse_plus(x) -> Array:
    """
    Sparse plus function, a piecewise-smooth approximation of ReLU:
    0 if x <= -1, (x + 1)**2 / 4 if -1 < x < 1, x if x >= 1.

    Args:
        x: Input array

    Returns:
        Array with sparse plus applied element-wise
    """

def sparse_sigmoid(x) -> Array:
    """
    Sparse sigmoid: a piecewise-linear sigmoid that is exactly 0 for x <= -1
    and exactly 1 for x >= 1.

    Args:
        x: Input array

    Returns:
        Array with sparse sigmoid applied element-wise
    """
```
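
Two of these come up constantly in practice; a small sketch:

```python
import jax.numpy as jnp
import jax.nn as jnn

# one_hot turns integer labels into indicator vectors
labels = jnp.array([0, 2, 1])
print(jnn.one_hot(labels, num_classes=3))
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]

# logsumexp stays finite where a naive log(sum(exp(...))) overflows
a = jnp.array([1000.0, 1000.0])
print(jnn.logsumexp(a))              # ≈ 1000.693 (1000 + log 2)
print(jnp.log(jnp.sum(jnp.exp(a))))  # inf
```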

### Attention Mechanisms

Attention functions for transformer and neural attention models.

```python { .api }
def dot_product_attention(
    query,
    key,
    value,
    bias=None,
    mask=None,
    broadcast_dropout=True,
    dropout_rng=None,
    dropout_rate=0.0,
    deterministic=False,
    dtype=None,
    precision=None
) -> Array:
    """
    Dot-product attention mechanism.

    Args:
        query: Query array (..., length_q, depth_q)
        key: Key array (..., length_kv, depth_q)
        value: Value array (..., length_kv, depth_v)
        bias: Optional attention bias
        mask: Optional attention mask
        broadcast_dropout: Whether to broadcast dropout
        dropout_rng: Random key for dropout
        dropout_rate: Dropout probability
        deterministic: Whether to use deterministic mode
        dtype: Output data type
        precision: Computation precision

    Returns:
        Attention output array (..., length_q, depth_v)
    """

def scaled_dot_general(
    lhs,
    rhs,
    dimension_numbers,
    alpha=1.0,
    precision=None,
    preferred_element_type=None
) -> Array:
    """
    Scaled general dot product for attention computations.

    Args:
        lhs: Left-hand side array
        rhs: Right-hand side array
        dimension_numbers: Contraction specification
        alpha: Scaling factor
        precision: Computation precision
        preferred_element_type: Preferred output type

    Returns:
        Scaled dot product result
    """

def scaled_matmul(
    a,
    b,
    alpha=1.0,
    precision=None,
    preferred_element_type=None
) -> Array:
    """
    Scaled matrix multiplication: alpha * (a @ b).

    Args:
        a: First matrix
        b: Second matrix
        alpha: Scaling factor
        precision: Computation precision
        preferred_element_type: Preferred output type

    Returns:
        Scaled matrix multiplication result
    """

def get_scaled_dot_general_config() -> dict:
    """
    Get configuration for scaled dot product attention.

    Returns:
        Configuration dictionary for attention operations
    """
```
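
A minimal usage sketch. The `(batch, seq_len, num_heads, head_dim)` input layout used here is an assumption matching recent `jax.nn.dot_product_attention` releases; the exact keyword set varies across JAX versions:

```python
import jax
import jax.nn as jnn

key = jax.random.key(0)
kq, kk, kv = jax.random.split(key, 3)

# Assumed layout: (batch, seq_len, num_heads, head_dim)
B, T, N, H = 2, 5, 4, 16
q = jax.random.normal(kq, (B, T, N, H))
k = jax.random.normal(kk, (B, T, N, H))
v = jax.random.normal(kv, (B, T, N, H))

out = jnn.dot_product_attention(q, k, v)
print(out.shape)  # (2, 5, 4, 16)
```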

### Utility Functions

Additional utilities for neural network operations.

```python { .api }
def identity(x) -> Array:
    """
    Identity function that returns input unchanged.

    Args:
        x: Input array

    Returns:
        Input array unchanged
    """
```
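
`identity` is mainly useful as a no-op when the activation is chosen by configuration. A sketch with a hypothetical `apply_activation` helper (not part of the API above):

```python
import jax.numpy as jnp
import jax.nn as jnn

# Hypothetical dispatch table; "none" maps to the no-op activation
ACTIVATIONS = {"relu": jnn.relu, "gelu": jnn.gelu, "none": jnn.identity}

def apply_activation(x, name="none"):
    return ACTIVATIONS[name](x)

x = jnp.array([-1.0, 2.0])
print(apply_activation(x, "none"))  # unchanged
print(apply_activation(x, "relu"))  # [0. 2.]
```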

## Neural Network Initializers

JAX provides weight initialization functions through `jax.nn.initializers`:

```python { .api }
import jax.nn.initializers as init

# Standard initializers
init.zeros(key, shape, dtype=jnp.float32) -> Array
init.ones(key, shape, dtype=jnp.float32) -> Array
init.constant(value, dtype=jnp.float32) -> Callable

# Random initializers
init.uniform(scale=1e-2, dtype=jnp.float32) -> Callable
init.normal(stddev=1e-2, dtype=jnp.float32) -> Callable
init.truncated_normal(stddev=1e-2, dtype=jnp.float32) -> Callable

# Variance scaling initializers
init.variance_scaling(scale, mode, distribution, dtype=jnp.float32) -> Callable
init.glorot_uniform(dtype=jnp.float32) -> Callable
init.glorot_normal(dtype=jnp.float32) -> Callable
init.lecun_uniform(dtype=jnp.float32) -> Callable
init.lecun_normal(dtype=jnp.float32) -> Callable
init.he_uniform(dtype=jnp.float32) -> Callable
init.he_normal(dtype=jnp.float32) -> Callable

# Orthogonal initializer
init.orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float32) -> Callable

# Delta orthogonal initializer (for deep convolutional networks)
init.delta_orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float32) -> Callable
```
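
The named initializers are wrappers around `variance_scaling`; for example, `glorot_uniform()` behaves like `variance_scaling(1.0, "fan_avg", "uniform")`. A sketch that checks this, assuming both use the same default in/out axes:

```python
import jax
import jax.numpy as jnp
import jax.nn.initializers as init

key = jax.random.key(0)
shape = (256, 128)

w1 = init.glorot_uniform()(key, shape)
w2 = init.variance_scaling(1.0, "fan_avg", "uniform")(key, shape)

# Same distribution sampled with the same key
print(jnp.allclose(w1, w2))
```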

Usage examples:

```python
import jax
import jax.numpy as jnp
import jax.nn as jnn
from jax.nn import initializers as init

# Initialize weights
key = jax.random.key(42)
weights = init.glorot_uniform()(key, (784, 128))
biases = init.zeros(key, (128,))

# Apply activations in a simple neural network layer
def dense_layer(x, weights, biases):
    return jnn.relu(x @ weights + biases)

# Multi-layer example with different activations
def mlp(x, params):
    x = jnn.relu(x @ params['w1'] + params['b1'])
    x = jnn.gelu(x @ params['w2'] + params['b2'])
    x = jnn.softmax(x @ params['w3'] + params['b3'])
    return x

# Attention example
def simple_attention(q, k, v):
    # dot_product_attention returns the attention output, not raw scores
    return jnn.dot_product_attention(q, k, v)
```