0
# Random Number Generation
1
2
JAX uses a functional approach to pseudo-random number generation with explicit key management. This design enables reproducibility, parallelization, and vectorization while avoiding the hidden global random state used by other libraries (for example, NumPy's legacy `np.random` functions).
3
4
## Core Imports
5
6
```python
7
import jax.random as jr
8
from jax.random import key, split, normal, uniform
9
```
10
11
## Key Concepts
12
13
JAX random numbers require explicit key management:
14
- Keys are created from integer seeds
15
- Keys must be split to generate independent random sequences
16
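- Each random function consumes a key and returns deterministic output: the same key always produces the same samples, so use each key only once and split it whenever more independent randomness is needed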
17
- No global random state - all randomness is explicit
18
19
## Capabilities
20
21
### Key Management
22
23
Generate, split, and manipulate PRNG keys for deterministic random number generation.
24
25
```python { .api }
26
def key(seed: int, impl=None) -> Array:
    """
    Build a new-style (typed) PRNG key from an integer seed.

    Args:
        seed: Integer used to seed the generator.
        impl: Optional PRNG implementation; ``None`` selects the default one.

    Returns:
        A scalar key array whose dtype is a dedicated key type.
    """
37
38
def PRNGKey(seed: int) -> Array:
    """
    Build a legacy raw PRNG key from an integer seed.

    Args:
        seed: Integer used to seed the generator.

    Returns:
        A raw ``uint32`` key array (shape ``(2,)`` for the default threefry
        implementation) rather than a typed key.
    """
48
49
def split(key: Array, num: int = 2) -> Array:
    """
    Derive ``num`` new, mutually independent keys from ``key``.

    Treat the input key as used up afterwards and continue with the
    returned keys.

    Args:
        key: Key to split.
        num: How many keys to produce; defaults to 2.

    Returns:
        The new keys stacked along a leading axis, with shape
        ``(num,) + key.shape``.
    """
60
61
def fold_in(key: Array, data: int) -> Array:
    """
    Derive a new key by mixing an integer into ``key``.

    This is handy for deterministic per-step or per-index keys.

    Args:
        key: Key to derive from.
        data: 32-bit integer to mix into the key.

    Returns:
        A new key that depends on both ``key`` and ``data``.
    """
72
73
def clone(key: Array) -> Array:
    """
    Return a copy of ``key`` that may be consumed again.

    Outside of JAX's key-reuse checking this behaves as an identity function;
    its purpose is to mark intentional key reuse explicitly.

    Args:
        key: Key to copy.

    Returns:
        The cloned key.
    """
83
84
def key_data(keys: Array) -> Array:
    """
    Unwrap typed keys into the raw integer data that backs them.

    Args:
        keys: Typed key array.

    Returns:
        The underlying ``uint32`` data, with implementation-specific trailing
        dimensions appended to ``keys.shape``.
    """
94
95
def wrap_key_data(key_data: Array, *, impl=None) -> Array:
    """
    Turn raw integer key data back into typed keys (inverse of ``key_data``).

    Args:
        key_data: Raw ``uint32`` key data, e.g. as produced by ``key_data``.
        impl: PRNG implementation the data belongs to; ``None`` selects the
            default implementation.

    Returns:
        Typed key array.
    """
106
107
def key_impl(key: Array):
    """
    Report which PRNG implementation backs a typed key.

    Note: the result is an implementation-spec object (``PRNGSpec`` in recent
    JAX releases), not a plain ``str``; call ``str()`` on it for a printable
    form.

    Args:
        key: Typed PRNG key or key array.

    Returns:
        Object describing the key's PRNG implementation.
    """
117
```
118
119
### Continuous Distributions
120
121
Sample from continuous probability distributions.
122
123
```python { .api }
124
def uniform(
    key: Array,
    shape=(),
    dtype=float,
    minval=0.0,
    maxval=1.0
) -> Array:
    """
    Draw samples uniformly from the half-open interval ``[minval, maxval)``.

    Args:
        key: PRNG key to draw from.
        shape: Shape of the result.
        dtype: Floating-point dtype of the result.
        minval: Inclusive lower bound.
        maxval: Exclusive upper bound.

    Returns:
        Array of uniformly distributed samples.
    """
144
145
def normal(key: Array, shape=(), dtype=float) -> Array:
    """
    Draw standard normal samples (mean 0, variance 1).

    For other parameters, transform the result as ``mu + sigma * x``.

    Args:
        key: PRNG key to draw from.
        shape: Shape of the result.
        dtype: Floating-point dtype of the result.

    Returns:
        Samples from N(0, 1).
    """
157
158
def multivariate_normal(
    key: Array,
    mean: Array,
    cov: Array,
    shape=(),
    dtype=float,
    method='cholesky'
) -> Array:
    """
    Draw samples from a multivariate normal distribution N(mean, cov).

    Args:
        key: PRNG key to draw from.
        mean: Mean vector; its last axis is the event dimension ``n``.
        cov: Covariance matrix with trailing shape ``(n, n)``.
        shape: Batch shape of the draw.
        dtype: Floating-point dtype of the result.
        method: Factorisation applied to ``cov``: ``'cholesky'``, ``'eigh'``
            or ``'svd'``.

    Returns:
        Samples whose trailing axis has length ``n``.
    """
180
181
def truncated_normal(
    key: Array,
    lower: float,
    upper: float,
    shape=(),
    dtype=float
) -> Array:
    """
    Draw standard normal samples restricted to lie between two bounds.

    Args:
        key: PRNG key to draw from.
        lower: Lower bound of the truncation interval.
        upper: Upper bound of the truncation interval.
        shape: Shape of the result.
        dtype: Floating-point dtype of the result.

    Returns:
        Samples from the truncated standard normal.
    """
201
202
def beta(key: Array, a: Array, b: Array, shape=(), dtype=float) -> Array:
    """
    Draw samples from a Beta(a, b) distribution on the unit interval.

    Args:
        key: PRNG key to draw from.
        a: First concentration parameter (alpha).
        b: Second concentration parameter (beta).
        shape: Shape of the result.
        dtype: Floating-point dtype of the result.

    Returns:
        Samples from Beta(a, b).
    """
216
217
def gamma(key: Array, a: Array, shape=(), dtype=float) -> Array:
    """
    Draw samples from a unit-scale Gamma distribution.

    Multiply the result by a scale factor to obtain other scales.

    Args:
        key: PRNG key to draw from.
        a: Shape (concentration) parameter.
        shape: Shape of the result.
        dtype: Floating-point dtype of the result.

    Returns:
        Samples from Gamma(a, 1).
    """
230
231
def exponential(key: Array, shape=(), dtype=float) -> Array:
    """
    Draw samples from an exponential distribution with rate 1.

    Args:
        key: PRNG key to draw from.
        shape: Shape of the result.
        dtype: Floating-point dtype of the result.

    Returns:
        Samples from Exponential(1).
    """
243
244
def laplace(key: Array, shape=(), dtype=float) -> Array:
    """
    Draw samples from a Laplace distribution with location 0 and scale 1.

    Args:
        key: PRNG key to draw from.
        shape: Shape of the result.
        dtype: Floating-point dtype of the result.

    Returns:
        Samples from Laplace(0, 1).
    """
256
257
def logistic(key: Array, shape=(), dtype=float) -> Array:
    """
    Draw samples from a logistic distribution with location 0 and scale 1.

    Args:
        key: PRNG key to draw from.
        shape: Shape of the result.
        dtype: Floating-point dtype of the result.

    Returns:
        Samples from Logistic(0, 1).
    """
269
270
def lognormal(key: Array, sigma=1.0, shape=(), dtype=float) -> Array:
    """
    Draw samples whose logarithm is normally distributed.

    Args:
        key: PRNG key to draw from.
        sigma: Standard deviation of the underlying normal distribution.
        shape: Shape of the result.
        dtype: Floating-point dtype of the result.

    Returns:
        Log-normal samples.
    """
283
284
def pareto(key: Array, b: Array, shape=(), dtype=float) -> Array:
    """
    Draw samples from a Pareto distribution with unit scale.

    Args:
        key: PRNG key to draw from.
        b: Shape (tail index) parameter.
        shape: Shape of the result.
        dtype: Floating-point dtype of the result.

    Returns:
        Samples from Pareto(b, 1).
    """
297
298
def cauchy(key: Array, shape=(), dtype=float) -> Array:
    """
    Draw samples from a Cauchy distribution with location 0 and scale 1.

    Args:
        key: PRNG key to draw from.
        shape: Shape of the result.
        dtype: Floating-point dtype of the result.

    Returns:
        Samples from Cauchy(0, 1).
    """
310
311
def double_sided_maxwell(
    key: Array,
    loc: Array,
    scale: Array,
    shape=(),
    dtype=float
) -> Array:
    """
    Draw samples from a double-sided Maxwell distribution.

    Args:
        key: PRNG key to draw from.
        loc: Location parameter.
        scale: Scale parameter.
        shape: Shape of the result.
        dtype: Floating-point dtype of the result.

    Returns:
        Double-sided Maxwell samples.
    """
331
332
def maxwell(key: Array, shape=(), dtype=float) -> Array:
    """
    Draw samples from a (one-sided) Maxwell distribution.

    Args:
        key: PRNG key to draw from.
        shape: Shape of the result.
        dtype: Floating-point dtype of the result.

    Returns:
        Maxwell-distributed samples.
    """
344
345
def rayleigh(key: Array, scale=1.0, shape=(), dtype=float) -> Array:
    """
    Draw samples from a Rayleigh distribution.

    Args:
        key: PRNG key to draw from.
        scale: Scale parameter of the distribution.
        shape: Shape of the result.
        dtype: Floating-point dtype of the result.

    Returns:
        Samples from Rayleigh(scale).
    """
358
359
def wald(key: Array, mean: Array, shape=(), dtype=float) -> Array:
    """
    Draw samples from a Wald (inverse Gaussian) distribution.

    Args:
        key: PRNG key to draw from.
        mean: Mean of the distribution.
        shape: Shape of the result.
        dtype: Floating-point dtype of the result.

    Returns:
        Wald-distributed samples.
    """
372
373
def weibull_min(
    key: Array,
    scale: Array,
    concentration: Array,
    shape=(),
    dtype=float
) -> Array:
    """
    Draw samples from a Weibull (minimum) distribution.

    Note: ``jax.random.weibull_min`` takes ``scale`` *before*
    ``concentration``, and neither has a default value. Positional calls
    must follow this order, or the two parameters are silently swapped.

    Args:
        key: PRNG key to draw from.
        scale: Scale parameter.
        concentration: Shape (concentration) parameter.
        shape: Shape of the result.
        dtype: Floating-point dtype of the result.

    Returns:
        Weibull-minimum samples.
    """
393
394
def gumbel(key: Array, shape=(), dtype=float) -> Array:
    """
    Draw samples from a standard Gumbel distribution (location 0, scale 1).

    Args:
        key: PRNG key to draw from.
        shape: Shape of the result.
        dtype: Floating-point dtype of the result.

    Returns:
        Samples from Gumbel(0, 1).
    """
406
407
def chisquare(key: Array, df: Array, shape=(), dtype=float) -> Array:
    """
    Draw samples from a chi-square distribution.

    Args:
        key: PRNG key to draw from.
        df: Degrees of freedom.
        shape: Shape of the result.
        dtype: Floating-point dtype of the result.

    Returns:
        Samples from chi-square(df).
    """
420
421
def dirichlet(
    key: Array,
    alpha: Array,
    shape=(),
    dtype=float
) -> Array:
    """
    Draw probability vectors from a Dirichlet distribution.

    Args:
        key: PRNG key to draw from.
        alpha: Concentration parameters; categories run along the last axis.
        shape: Batch shape of the draw.
        dtype: Floating-point dtype of the result.

    Returns:
        Samples from Dirichlet(alpha), one probability vector per batch entry.
    """
439
440
def f(key: Array, dfnum: Array, dfden: Array, shape=(), dtype=float) -> Array:
    """
    Draw samples from Snedecor's F-distribution.

    Args:
        key: PRNG key to draw from.
        dfnum: Degrees of freedom of the numerator.
        dfden: Degrees of freedom of the denominator.
        shape: Shape of the result.
        dtype: Floating-point dtype of the result.

    Returns:
        F-distributed samples.
    """
454
455
def t(key: Array, df: Array, shape=(), dtype=float) -> Array:
    """
    Draw samples from Student's t-distribution.

    Args:
        key: PRNG key to draw from.
        df: Degrees of freedom.
        shape: Shape of the result.
        dtype: Floating-point dtype of the result.

    Returns:
        t-distributed samples.
    """
468
469
def triangular(
    key: Array,
    left: Array,
    mode: Array,
    right: Array,
    shape=(),
    dtype=float
) -> Array:
    """
    Draw samples from a triangular distribution on ``[left, right]``.

    Args:
        key: PRNG key to draw from.
        left: Lower end of the support.
        mode: Location of the density's peak.
        right: Upper end of the support.
        shape: Shape of the result.
        dtype: Floating-point dtype of the result.

    Returns:
        Triangular-distributed samples.
    """
491
492
def generalized_normal(
    key: Array,
    p: Array,
    shape=(),
    dtype=float
) -> Array:
    """
    Draw samples from a generalized normal distribution.

    Args:
        key: PRNG key to draw from.
        p: Shape parameter (exponent) of the density.
        shape: Shape of the result.
        dtype: Floating-point dtype of the result.

    Returns:
        Generalized-normal samples.
    """
510
511
def loggamma(key: Array, a: Array, shape=(), dtype=float) -> Array:
    """
    Draw the logarithm of unit-scale Gamma samples.

    Args:
        key: PRNG key to draw from.
        a: Shape parameter of the underlying Gamma distribution.
        shape: Shape of the result.
        dtype: Floating-point dtype of the result.

    Returns:
        Log-gamma samples.
    """
524
```
525
526
### Discrete Distributions
527
528
Sample from discrete probability distributions.
529
530
```python { .api }
531
def bernoulli(key: Array, p=0.5, shape=None) -> Array:
    """
    Draw Boolean samples that are ``True`` with probability ``p``.

    Note: ``jax.random.bernoulli`` has no ``dtype`` argument; it always
    returns a boolean array. Use ``.astype(...)`` if integers are needed.

    Args:
        key: PRNG key to draw from.
        p: Probability of ``True``; may be an array.
        shape: Optional output shape, broadcast-compatible with ``p``;
            ``None`` uses the shape of ``p``.

    Returns:
        Boolean samples from Bernoulli(p).
    """
544
545
def binomial(key: Array, n: Array, p: Array, shape=None, dtype=float) -> Array:
    """
    Draw success counts from a binomial distribution.

    Note: ``jax.random.binomial`` returns the counts in a floating-point
    dtype by default, not an integer one.

    Args:
        key: PRNG key to draw from.
        n: Number of trials.
        p: Success probability of each trial.
        shape: Optional output shape; ``None`` broadcasts ``n`` against ``p``.
        dtype: Floating-point dtype of the result.

    Returns:
        Samples from Binomial(n, p).
    """
559
560
def categorical(
    key: Array,
    logits: Array,
    axis=-1,
    shape=None
) -> Array:
    """
    Draw category indices from unnormalised log-probabilities.

    Args:
        key: PRNG key to draw from.
        logits: Unnormalised log-probabilities; categories run along ``axis``.
        axis: Axis of ``logits`` that indexes the categories.
        shape: Optional output shape. Pass it by keyword, because the third
            positional parameter is ``axis``.

    Returns:
        Integer array of sampled category indices.
    """
578
579
def choice(
580
key: Array,
581
a: int | Array,
582
shape=(),
583
replace=True,
584
p=None,
585
axis=0
586
) -> Array:
587
"""
588
Random choice from array elements.
589
590
Args:
591
key: PRNG key
592
a: Array to sample from or integer (range)
593
shape: Output shape
594
replace: Whether to sample with replacement
595
p: Probabilities for each element
596
axis: Axis to sample along
597
598
Returns:
599
Random samples from input array
600
"""
601
602
def geometric(key: Array, p: Array, shape=(), dtype=int) -> Array:
    """
    Draw samples from a geometric distribution.

    Args:
        key: PRNG key to draw from.
        p: Success probability of each trial.
        shape: Shape of the result.
        dtype: Integer dtype of the result.

    Returns:
        Samples from Geometric(p).
    """
615
616
def poisson(key: Array, lam: Array, shape=(), dtype=int) -> Array:
    """
    Draw event counts from a Poisson distribution.

    Args:
        key: PRNG key to draw from.
        lam: Rate (expected count) parameter.
        shape: Shape of the result.
        dtype: Integer dtype of the result.

    Returns:
        Samples from Poisson(lam).
    """
629
630
def multinomial(
    key: Array,
    n: Array,
    p: Array,
    *,
    shape=None,
    dtype=float
) -> Array:
    """
    Draw per-category counts from a multinomial distribution.

    Note: ``jax.random.multinomial`` names the probabilities ``p`` (not
    ``pvals``), accepts ``shape``/``dtype`` only as keywords, and returns
    floating-point counts by default.

    Args:
        key: PRNG key to draw from.
        n: Number of trials, broadcastable against ``p.shape[:-1]``.
        p: Probability of each category; categories run along the last axis.
        shape: Optional batch shape (the result shape without its last axis);
            ``None`` makes the result shape equal to ``p.shape``.
        dtype: Floating-point dtype of the result.

    Returns:
        Samples from Multinomial(n, p); the last axis indexes the categories.
    """
650
651
def randint(
652
key: Array,
653
minval: int,
654
maxval: int,
655
shape=(),
656
dtype=int
657
) -> Array:
658
"""
659
Sample random integers from [minval, maxval).
660
661
Args:
662
key: PRNG key
663
minval: Minimum value (inclusive)
664
maxval: Maximum value (exclusive)
665
shape: Output shape
666
dtype: Output data type
667
668
Returns:
669
Random integers in specified range
670
"""
671
672
def rademacher(key: Array, shape=(), dtype=int) -> Array:
    """
    Draw random signs: -1 or +1, each with probability 1/2.

    Args:
        key: PRNG key to draw from.
        shape: Shape of the result.
        dtype: Dtype of the result.

    Returns:
        Array of values from {-1, +1}.
    """
684
```
685
686
### Specialized Sampling
687
688
Special sampling functions for geometric shapes and structured sampling.
689
690
```python { .api }
691
def ball(key: Array, d: int, p=2, shape=(), dtype=float) -> Array:
    """
    Draw points uniformly from the ``d``-dimensional unit p-norm ball.

    Args:
        key: PRNG key to draw from.
        d: Dimension of the ball.
        p: Order of the norm defining the ball (2 means Euclidean).
        shape: Batch shape; each batch entry is one ``d``-dimensional point.
        dtype: Floating-point dtype of the result.

    Returns:
        Points inside the unit ball, with trailing axis of length ``d``.
    """
705
706
def orthogonal(key: Array, n: int, shape=(), dtype=float) -> Array:
    """
    Draw random ``n x n`` orthogonal matrices.

    Args:
        key: PRNG key to draw from.
        n: Number of rows and columns of each matrix.
        shape: Batch shape; each batch entry is one matrix.
        dtype: Dtype of the result.

    Returns:
        Orthogonal matrices with trailing shape ``(n, n)``.
    """
719
720
def permutation(key: Array, x: int | Array, axis=0, independent=False) -> Array:
721
"""
722
Generate random permutation of array or integers.
723
724
Args:
725
key: PRNG key
726
x: Array to permute or integer (range)
727
axis: Axis to permute along
728
independent: Whether to permute each batch element independently
729
730
Returns:
731
Randomly permuted array
732
"""
733
734
def bits(key: Array, shape=(), dtype=None) -> Array:
    """
    Draw uniformly random bits, returned as unsigned integers.

    Note: ``jax.random.bits`` has no ``width`` parameter; the number of bits
    per element is determined by ``dtype`` (e.g. ``jnp.uint8``,
    ``jnp.uint32``).

    Args:
        key: PRNG key to draw from.
        shape: Output shape.
        dtype: Unsigned integer dtype of the result. ``None`` means
            ``uint32``, or ``uint64`` when ``jax_enable_x64`` is on.

    Returns:
        Array of random bit patterns.
    """
747
```
748
749
## Usage Examples
750
751
Common patterns for JAX random number generation:
752
753
```python
754
import jax
755
import jax.numpy as jnp
756
import jax.random as jr
757
758
# Create and split keys
759
main_key = jr.key(42)
# A key always produces the same bits, so drawing twice from one key gives
# correlated (identical-stream) samples. Split once, up front, into one
# dedicated subkey per sampling call below, and never reuse main_key.
(normal_key, int_key, batch_key, per_example_key, perm_key,
 choice_key, mvn_key, coin_key, dice_key) = jr.split(main_key, 9)

# Basic sampling
samples = jr.normal(normal_key, (1000,))
# randint takes the output shape *before* the bounds:
# randint(key, shape, minval, maxval)
random_ints = jr.randint(int_key, (100,), 0, 10)

# Batch sampling with a single key: one call fills the whole array
batch_samples = jr.normal(batch_key, (32, 784))  # 32 samples of 784 dims

# Different keys for each batch element (derived from a fresh subkey, not
# from the already-consumed main_key)
keys = jr.split(per_example_key, 32)
independent_samples = jax.vmap(
    lambda k: jr.normal(k, (784,))
)(keys)

# Random choice and permutation
data = jnp.arange(100)
shuffled = jr.permutation(perm_key, data)
selected = jr.choice(choice_key, data, (10,), replace=False)

# Multivariate distributions
mean = jnp.zeros(5)
cov = jnp.eye(5)
mv_samples = jr.multivariate_normal(mvn_key, mean, cov, (1000,))

# Discrete distributions
coin_flips = jr.bernoulli(coin_key, 0.6, (100,))  # boolean array
# categorical's third positional parameter is `axis`, so the sample shape
# must be passed by keyword.
dice_rolls = jr.categorical(dice_key, jnp.log(jnp.ones(6) / 6), shape=(100,))
788
789
# Using in neural network initialization
790
def init_layer_weights(key, input_dim, output_dim):
    """Return ``(weights, biases)`` for a dense layer of the given size.

    Weights use Xavier/Glorot normal scaling, i.e. a standard deviation of
    ``sqrt(2 / (input_dim + output_dim))``. Biases are standard-normal draws
    scaled by 0.01. ``key`` is split so the two draws use separate keys.
    """
    weight_key, bias_key = jr.split(key)
    fan_total = input_dim + output_dim
    glorot_std = jnp.sqrt(2.0 / fan_total)
    weights = jr.normal(weight_key, (input_dim, output_dim)) * glorot_std
    biases = jr.normal(bias_key, (output_dim,)) * 0.01
    return weights, biases
797
798
# Stochastic gradient descent with random batching
799
def get_random_batch(key, data, batch_size):
    """Return ``batch_size`` distinct rows of ``data`` chosen at random.

    Rows are drawn without replacement, so ``batch_size`` must not exceed
    ``len(data)``.
    """
    num_rows = len(data)
    row_ids = jr.choice(key, num_rows, (batch_size,), replace=False)
    return data[row_ids]
802
```