0
# Neural Network Functions
1
2
JAX provides a comprehensive set of neural network functions through `jax.nn`, including activation functions, normalization utilities, and attention mechanisms commonly used in machine learning and deep learning applications.
3
4
## Core Imports
5
6
```python
7
import jax.nn as jnn
8
from jax.nn import relu, sigmoid, softmax, gelu
9
```
10
11
## Capabilities
12
13
### ReLU and Variants
14
15
Rectified Linear Unit activations and their variants for introducing non-linearity while maintaining computational efficiency.
16
17
```python { .api }
18
def relu(x) -> Array:
19
"""
20
Rectified Linear Unit activation: max(0, x).
21
22
Args:
23
x: Input array
24
25
Returns:
26
Array with ReLU applied element-wise
27
"""
28
29
def relu6(x) -> Array:
30
"""
31
ReLU capped at 6: min(max(0, x), 6).
32
33
Args:
34
x: Input array
35
36
Returns:
37
Array with ReLU6 applied element-wise
38
"""
39
40
def leaky_relu(x, negative_slope=0.01) -> Array:
41
"""
42
Leaky ReLU: x if x >= 0 else negative_slope * x.
43
44
Args:
45
x: Input array
46
negative_slope: Slope for negative values (default: 0.01)
47
48
Returns:
49
Array with Leaky ReLU applied element-wise
50
"""
51
52
def elu(x, alpha=1.0) -> Array:
53
"""
54
Exponential Linear Unit: x if x > 0 else alpha * (exp(x) - 1).
55
56
Args:
57
x: Input array
58
alpha: Scale for negative values (default: 1.0)
59
60
Returns:
61
Array with ELU applied element-wise
62
"""
63
64
def selu(x) -> Array:
65
"""
66
Scaled Exponential Linear Unit: scale * (x if x > 0 else alpha * (exp(x) - 1)), with fixed alpha ≈ 1.6733 and scale ≈ 1.0507.
67
68
Args:
69
x: Input array
70
71
Returns:
72
Array with SELU applied element-wise
73
"""
74
75
def celu(x, alpha=1.0) -> Array:
76
"""
77
Continuously Differentiable Exponential Linear Unit: x if x > 0 else alpha * (exp(x / alpha) - 1).
78
79
Args:
80
x: Input array
81
alpha: Scale parameter (default: 1.0)
82
83
Returns:
84
Array with CELU applied element-wise
85
"""
86
```
87
88
### Modern Activations
89
90
Contemporary activation functions that have shown improved performance in various architectures.
91
92
```python { .api }
93
def gelu(x, approximate=True) -> Array:
94
"""
95
Gaussian Error Linear Unit: x * Φ(x) where Φ is CDF of standard normal.
96
97
Args:
98
x: Input array
99
approximate: Whether to use tanh approximation (default: True)
100
101
Returns:
102
Array with GELU applied element-wise
103
"""
104
105
def silu(x) -> Array:
106
"""
107
Sigmoid Linear Unit (Swish): x * sigmoid(x).
108
109
Args:
110
x: Input array
111
112
Returns:
113
Array with SiLU applied element-wise
114
"""
115
116
def swish(x) -> Array:
117
"""
118
Swish activation (alias for SiLU): x * sigmoid(x).
119
120
Args:
121
x: Input array
122
123
Returns:
124
Array with Swish applied element-wise
125
"""
126
127
def mish(x) -> Array:
128
"""
129
Mish activation: x * tanh(softplus(x)).
130
131
Args:
132
x: Input array
133
134
Returns:
135
Array with Mish applied element-wise
136
"""
137
138
def hard_silu(x) -> Array:
139
"""
140
Hard SiLU (Hard Swish variant): x * hard_sigmoid(x).
141
142
Args:
143
x: Input array
144
145
Returns:
146
Array with Hard SiLU applied element-wise
147
"""
148
149
def hard_swish(x) -> Array:
150
"""
151
Hard Swish: x * relu6(x + 3) / 6.
152
153
Args:
154
x: Input array
155
156
Returns:
157
Array with Hard Swish applied element-wise
158
"""
159
160
def squareplus(x, b=4.0) -> Array:
161
"""
162
Squareplus activation: (x + sqrt(x^2 + b)) / 2.
163
164
Args:
165
x: Input array
166
b: Shape parameter (default: 4.0)
167
168
Returns:
169
Array with Squareplus applied element-wise
170
"""
171
```
172
173
### Sigmoid and Tanh Variants
174
175
Sigmoid-based activations and their approximations for bounded outputs.
176
177
```python { .api }
178
def sigmoid(x) -> Array:
179
"""
180
Sigmoid activation: 1 / (1 + exp(-x)).
181
182
Args:
183
x: Input array
184
185
Returns:
186
Array with sigmoid applied element-wise
187
"""
188
189
def hard_sigmoid(x) -> Array:
190
"""
191
Hard sigmoid approximation: relu6(x + 3) / 6, i.e. max(0, min(1, (x + 3) / 6)).
192
193
Args:
194
x: Input array
195
196
Returns:
197
Array with hard sigmoid applied element-wise
198
"""
199
200
def log_sigmoid(x) -> Array:
201
"""
202
Log sigmoid: log(sigmoid(x)) computed in numerically stable way.
203
204
Args:
205
x: Input array
206
207
Returns:
208
Array with log sigmoid applied element-wise
209
"""
210
211
def soft_sign(x) -> Array:
212
"""
213
Soft sign activation: x / (1 + |x|).
214
215
Args:
216
x: Input array
217
218
Returns:
219
Array with soft sign applied element-wise
220
"""
221
222
def tanh(x) -> Array:
223
"""
224
Hyperbolic tangent activation.
225
226
Args:
227
x: Input array
228
229
Returns:
230
Array with tanh applied element-wise
231
"""
232
233
def hard_tanh(x) -> Array:
234
"""
235
Hard tanh activation: max(-1, min(1, x)).
236
237
Args:
238
x: Input array
239
240
Returns:
241
Array with hard tanh applied element-wise
242
"""
243
```
244
245
### Softmax and Normalization
246
247
Normalization functions for probability distributions and feature standardization.
248
249
```python { .api }
250
def softmax(x, axis=-1, where=None, initial=None) -> Array:
251
"""
252
Softmax activation: exp(x_i) / sum(exp(x)) along axis.
253
254
Args:
255
x: Input array
256
axis: Axis to apply softmax along (default: -1)
257
where: Mask for conditional computation
258
initial: Initial value for reduction
259
260
Returns:
261
Array with softmax applied along specified axis
262
"""
263
264
def log_softmax(x, axis=-1, where=None, initial=None) -> Array:
265
"""
266
Log softmax: log(softmax(x)) computed in numerically stable way.
267
268
Args:
269
x: Input array
270
axis: Axis to apply log softmax along (default: -1)
271
where: Mask for conditional computation
272
initial: Initial value for reduction
273
274
Returns:
275
Array with log softmax applied along specified axis
276
"""
277
278
def softplus(x) -> Array:
279
"""
280
Softplus activation: log(1 + exp(x)).
281
282
Args:
283
x: Input array
284
285
Returns:
286
Array with softplus applied element-wise
287
"""
288
289
def standardize(x, axis=-1, mean=None, variance=None, epsilon=1e-5) -> Array:
290
"""
291
Standardize array to zero mean and unit variance.
292
293
Args:
294
x: Input array to standardize
295
axis: Axis or axes to compute statistics along (default: -1)
296
mean: Pre-computed mean (computed if None)
297
variance: Pre-computed variance (computed if None)
298
epsilon: Small value for numerical stability
299
300
Returns:
301
Standardized array
302
"""
303
304
def glu(x, axis=-1) -> Array:
305
"""
306
Gated Linear Unit: split x in half along axis, return a * sigmoid(b).
307
308
Args:
309
x: Input array (size along axis must be even)
310
axis: Axis to split along (default: -1)
311
312
Returns:
313
Array with GLU applied
314
"""
315
```
316
317
### Specialized Functions
318
319
Utility functions for neural network operations and transformations.
320
321
```python { .api }
322
def one_hot(x, num_classes, dtype=None, axis=-1) -> Array:
323
"""
324
One-hot encode array of integers.
325
326
Args:
327
x: Integer array to encode
328
num_classes: Number of classes
329
dtype: Output data type
330
axis: Axis to insert one-hot dimension
331
332
Returns:
333
One-hot encoded array
334
"""
335
336
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False, where=None) -> Array:
337
"""
338
Compute log(sum(exp(a))) in numerically stable way.
339
340
Args:
341
a: Input array
342
axis: Axis to sum along
343
b: Scaling factor array
344
keepdims: Whether to keep reduced dimensions
345
return_sign: Whether to return sign separately
346
where: Mask for conditional computation
347
348
Returns:
349
Log-sum-exp result
350
"""
351
352
def logmeanexp(a, axis=None, b=None, keepdims=False, where=None) -> Array:
353
"""
354
Compute log(mean(exp(a))) in numerically stable way.
355
356
Args:
357
a: Input array
358
axis: Axis to average along
359
b: Scaling factor array
360
keepdims: Whether to keep reduced dimensions
361
where: Mask for conditional computation
362
363
Returns:
364
Log-mean-exp result
365
"""
366
367
def log1mexp(x) -> Array:
368
"""
369
Compute log(1 - exp(x)) in numerically stable way.
370
371
Args:
372
x: Input array (should be <= 0)
373
374
Returns:
375
Array with log(1 - exp(x)) applied element-wise
376
"""
377
378
def sparse_plus(x) -> Array:
379
"""
380
Sparse plus activation: 0 if x <= -1, (x + 1)^2 / 4 if -1 < x < 1, x if x >= 1.
381
382
Args:
383
x: Input array
385
386
Returns:
387
Array with sparse plus applied element-wise
388
"""
389
390
def sparse_sigmoid(x) -> Array:
391
"""
392
Sparse sigmoid activation: 0 if x <= -1, (x + 1) / 2 if -1 < x < 1, 1 if x >= 1 (the derivative of sparse_plus).
393
394
Args:
395
x: Input array
396
397
Returns:
398
Array with sparse sigmoid applied element-wise
399
"""
400
```
401
402
### Attention Mechanisms
403
404
Attention functions for transformer and neural attention models.
405
406
```python { .api }
407
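def dot_product_attention(
408
query,
409
key,
410
value,
411
bias=None,
412
mask=None,
413
*,
414
scale=None,
415
is_causal=False,
416
query_seq_lengths=None, key_value_seq_lengths=None,
417
local_window_size=None,
418
implementation=None
419
) -> Array:
420
"""
421
Scaled dot-product attention: softmax(scale * (Q @ K^T) + bias, with mask applied) @ V.
422
423
Args:
424
query: Query array of shape (B, T, N, H) or (T, N, H)
425
key: Key array of shape (B, S, K, H) or (S, K, H); N must be a multiple of K
426
value: Value array with the same shape as key
427
bias: Optional bias added to the attention logits, broadcastable to (B, N, T, S)
428
mask: Optional boolean mask, broadcastable to (B, N, T, S); False positions are masked out
429
scale: Scale applied to the logits (default: 1 / sqrt(H))
430
is_causal: Whether to apply causal masking (default: False)
431
query_seq_lengths: Optional int32 array of shape (B,) giving query sequence lengths
432
key_value_seq_lengths: Optional int32 array of shape (B,) giving key/value sequence lengths
433
local_window_size: Window size for local attention, an int or a (left, right) tuple
434
implementation: Backend to use, 'xla' or 'cudnn'; None selects one automatically
435
436
Returns:
437
Attention output array with the same shape as query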
438
"""
439
440
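def scaled_dot_general(
441
lhs,
442
rhs,
443
dimension_numbers,
444
preferred_element_type=jnp.float32,
445
configs=None,
446
implementation=None
447
) -> Array:
448
"""
449
Generalized dot product with block-scaled quantization applied to lhs and rhs.
450
451
Args:
452
lhs: Left-hand side array
453
rhs: Right-hand side array
454
dimension_numbers: Contracting and batch dimensions, as for lax.dot_general
455
preferred_element_type: Output data type (default: jnp.float32)
456
configs: Block-scaling configs for lhs, rhs, and gradients (see get_scaled_dot_general_config); None falls back to lax.dot_general
457
implementation: Backend to use, 'cudnn' or None to choose automatically
458
459
Returns:
460
Result of the block-scaled dot_general operation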
461
"""
462
463
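def scaled_matmul(
464
lhs,
465
rhs,
466
lhs_scales,
467
rhs_scales,
468
preferred_element_type=jnp.float32
469
) -> Array:
470
"""
471
Block-scaled matrix multiplication: each operand is multiplied by its per-block scales along the contracting (last) dimension, then contracted as einsum('BMK,BNK->BMN').
472
473
Args:
474
lhs: Left operand of shape (B, M, K)
475
rhs: Right operand of shape (B, N, K)
476
lhs_scales: Per-block scales for lhs, shape (B, M, K_blocks) where K is divisible by K_blocks
477
rhs_scales: Per-block scales for rhs, shape (B, N, K_blocks) where K is divisible by K_blocks
478
preferred_element_type: Output data type (default: jnp.float32)
479
480
Returns:
481
Array of shape (B, M, N) holding the scaled matrix multiplication result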
482
"""
483
484
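def get_scaled_dot_general_config(mode, global_scale=None) -> BlockScaleConfig:
485
"""
486
Get a block-scaling quantization config for scaled_dot_general; mode is 'nvfp4' or 'mxfp8', and global_scale is an optional global scaling factor.
487
488
Returns:
489
BlockScaleConfig describing the quantization mode, block size, and scale types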
490
"""
491
```
492
493
### Utility Functions
494
495
Additional utilities for neural network operations.
496
497
```python { .api }
498
def identity(x) -> Array:
499
"""
500
Identity function that returns input unchanged.
501
502
Args:
503
x: Input array
504
505
Returns:
506
Input array unchanged
507
"""
508
```
509
510
## Neural Network Initializers
511
512
JAX provides weight initialization functions through `jax.nn.initializers`:
513
514
```python { .api }
515
import jax.nn.initializers as init
516
517
# Standard initializers
518
init.zeros(key, shape, dtype=jnp.float32) -> Array
519
init.ones(key, shape, dtype=jnp.float32) -> Array
520
init.constant(value, dtype=jnp.float32) -> Callable
521
522
# Random initializers
523
init.uniform(scale=1e-2, dtype=jnp.float32) -> Callable
524
init.normal(stddev=1e-2, dtype=jnp.float32) -> Callable
525
init.truncated_normal(stddev=1e-2, dtype=jnp.float32) -> Callable
526
527
# Variance scaling initializers
528
init.variance_scaling(scale, mode, distribution, dtype=jnp.float32) -> Callable
529
init.glorot_uniform(dtype=jnp.float32) -> Callable
530
init.glorot_normal(dtype=jnp.float32) -> Callable
531
init.lecun_uniform(dtype=jnp.float32) -> Callable
532
init.lecun_normal(dtype=jnp.float32) -> Callable
533
init.he_uniform(dtype=jnp.float32) -> Callable
534
init.he_normal(dtype=jnp.float32) -> Callable
535
536
# Orthogonal initializer
537
init.orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float32) -> Callable
538
539
# Delta orthogonal initializer (for convolutional kernels; shape must be 3D, 4D, or 5D)
540
init.delta_orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float32) -> Callable
541
```
542
543
Usage examples:
544
545
```python
546
import jax
547
import jax.numpy as jnp
548
import jax.nn as jnn
549
from jax.nn import initializers as init
550
551
# Initialize weights
552
key = jax.random.key(42)
553
weights = init.glorot_uniform()(key, (784, 128))
554
biases = init.zeros(key, (128,))
555
556
# Apply activations in a simple neural network layer
557
def dense_layer(x, weights, biases):
558
return jnn.relu(x @ weights + biases)
559
560
# Multi-layer example with different activations
561
def mlp(x, params):
562
x = jnn.relu(x @ params['w1'] + params['b1'])
563
x = jnn.gelu(x @ params['w2'] + params['b2'])
564
x = jnn.softmax(x @ params['w3'] + params['b3'])
565
return x
566
567
# Attention example
568
def simple_attention(q, k, v):
569
# Scaled dot-product attention
570
scores = jnn.dot_product_attention(q, k, v)
571
return scores
572
```