or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

core-transformations.mddevice-memory.mdexperimental.mdindex.mdlow-level-ops.mdneural-networks.mdnumpy-compatibility.mdrandom-numbers.mdscipy-compatibility.mdtree-operations.md

random-numbers.mddocs/

0

# Random Number Generation

1

2

JAX uses a functional approach to pseudo-random number generation with explicit key management. This design enables reproducibility, parallelization, and vectorization, and avoids the hidden global state used by libraries such as NumPy's `np.random`.

3

4

## Core Imports

5

6

```python

7

import jax.random as jr

8

from jax.random import key, split, normal, uniform

9

```

10

11

## Key Concepts

12

13

JAX random numbers require explicit key management:

14

- Keys are created from integer seeds

15

- Keys must be split to generate independent random sequences

16

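- Each random function consumes a key and returns deterministic output: the same key always yields the same samples, so never reuse a key after passing it to a sampling function — split it first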

17

- No global random state - all randomness is explicit

18

19

## Capabilities

20

21

### Key Management

22

23

Generate, split, and manipulate PRNG keys for deterministic random number generation.

24

25

```python { .api }

26

def key(seed: int, impl=None) -> Array:

27

"""

28

Create a typed PRNG key from integer seed.

29

30

Args:

31

seed: Integer seed value

32

impl: PRNG implementation to use

33

34

Returns:

35

PRNG key array

36

"""

37

38

def PRNGKey(seed: int) -> Array:

39

"""

40

Create legacy PRNG key (uint32 format).

41

42

Args:

43

seed: Integer seed value

44

45

Returns:

46

Legacy format PRNG key

47

"""

48

49

def split(key: Array, num: int = 2) -> Array:

50

"""

51

Split PRNG key into multiple independent keys.

52

53

Args:

54

key: PRNG key to split

55

num: Number of keys to generate (default: 2)

56

57

Returns:

58

Array of shape (num,) + key.shape containing new keys

59

"""

60

61

def fold_in(key: Array, data: int) -> Array:

62

"""

63

Fold integer data into PRNG key.

64

65

Args:

66

key: PRNG key

67

data: Integer to fold into key

68

69

Returns:

70

New PRNG key with data folded in

71

"""

72

73

def clone(key: Array) -> Array:

74

"""

75

Clone PRNG key for reuse.

76

77

Args:

78

key: PRNG key to clone

79

80

Returns:

81

Cloned PRNG key

82

"""

83

84

def key_data(keys: Array) -> Array:

85

"""

86

Extract raw key data from PRNG keys.

87

88

Args:

89

keys: PRNG key array

90

91

Returns:

92

Raw key data

93

"""

94

95

def wrap_key_data(key_data: Array, *, impl=None) -> Array:

96

"""

97

Wrap raw key data as PRNG keys.

98

99

Args:

100

key_data: Raw key data

101

impl: PRNG implementation

102

103

Returns:

104

PRNG key array

105

"""

106

107

def key_impl(key: Array) -> str:

108

"""

109

Get PRNG implementation name for key.

110

111

Args:

112

key: PRNG key

113

114

Returns:

115

Implementation name string

116

"""

117

```

118

119

### Continuous Distributions

120

121

Sample from continuous probability distributions.

122

123

```python { .api }

def uniform(
    key: Array,
    shape=(),
    dtype=float,
    minval=0.0,
    maxval=1.0,
) -> Array:
    """
    Sample from a uniform distribution.

    Args:
        key: PRNG key
        shape: Output shape
        dtype: Output data type
        minval: Minimum value (inclusive)
        maxval: Maximum value (exclusive)

    Returns:
        Random samples from U[minval, maxval)
    """

def normal(key: Array, shape=(), dtype=float) -> Array:
    """
    Sample from the standard normal (Gaussian) distribution.

    Args:
        key: PRNG key
        shape: Output shape
        dtype: Output data type

    Returns:
        Random samples from N(0, 1)
    """

def multivariate_normal(
    key: Array,
    mean: Array,
    cov: Array,
    shape=None,
    dtype=float,
    method='cholesky',
) -> Array:
    """
    Sample from a multivariate normal distribution.

    Args:
        key: PRNG key
        mean: Mean vector (leading dimensions are batch dimensions)
        cov: Covariance matrix (leading dimensions are batch dimensions)
        shape: Batch shape (None: broadcast batch shape of mean and cov)
        dtype: Output data type
        method: Decomposition method ('cholesky', 'eigh', 'svd')

    Returns:
        Random samples of shape shape + mean.shape[-1:]
    """

def truncated_normal(
    key: Array,
    lower: float,
    upper: float,
    shape=None,
    dtype=float,
) -> Array:
    """
    Sample from a standard normal distribution truncated to (lower, upper).

    Args:
        key: PRNG key
        lower: Lower truncation bound
        upper: Upper truncation bound
        shape: Output shape (None: broadcast shape of lower and upper)
        dtype: Output data type

    Returns:
        Random samples from the truncated normal
    """

def beta(key: Array, a: Array, b: Array, shape=None, dtype=float) -> Array:
    """
    Sample from a beta distribution.

    Args:
        key: PRNG key
        a: Alpha parameter (concentration)
        b: Beta parameter (concentration)
        shape: Output shape (None: broadcast shape of a and b)
        dtype: Output data type

    Returns:
        Random samples from Beta(a, b)
    """

def gamma(key: Array, a: Array, shape=None, dtype=float) -> Array:
    """
    Sample from a gamma distribution with unit scale.

    Args:
        key: PRNG key
        a: Shape parameter
        shape: Output shape (None: shape of a)
        dtype: Output data type

    Returns:
        Random samples from Gamma(a, 1)
    """

def exponential(key: Array, shape=(), dtype=float) -> Array:
    """
    Sample from the standard exponential distribution.

    Args:
        key: PRNG key
        shape: Output shape
        dtype: Output data type

    Returns:
        Random samples from Exponential(1)
    """

def laplace(key: Array, shape=(), dtype=float) -> Array:
    """
    Sample from the standard Laplace distribution.

    Args:
        key: PRNG key
        shape: Output shape
        dtype: Output data type

    Returns:
        Random samples from Laplace(0, 1)
    """

def logistic(key: Array, shape=(), dtype=float) -> Array:
    """
    Sample from the standard logistic distribution.

    Args:
        key: PRNG key
        shape: Output shape
        dtype: Output data type

    Returns:
        Random samples from Logistic(0, 1)
    """

def lognormal(key: Array, sigma=1.0, shape=None, dtype=float) -> Array:
    """
    Sample from a log-normal distribution.

    Args:
        key: PRNG key
        sigma: Standard deviation of the underlying normal
        shape: Output shape (None: shape of sigma)
        dtype: Output data type

    Returns:
        Random samples from the log-normal distribution
    """

def pareto(key: Array, b: Array, shape=None, dtype=float) -> Array:
    """
    Sample from a Pareto distribution.

    Args:
        key: PRNG key
        b: Shape parameter
        shape: Output shape (None: shape of b)
        dtype: Output data type

    Returns:
        Random samples from Pareto(b, 1)
    """

def cauchy(key: Array, shape=(), dtype=float) -> Array:
    """
    Sample from the standard Cauchy distribution.

    Args:
        key: PRNG key
        shape: Output shape
        dtype: Output data type

    Returns:
        Random samples from Cauchy(0, 1)
    """

def double_sided_maxwell(
    key: Array,
    loc: Array,
    scale: Array,
    shape=(),
    dtype=float,
) -> Array:
    """
    Sample from a double-sided Maxwell distribution.

    Args:
        key: PRNG key
        loc: Location parameter
        scale: Scale parameter
        shape: Output shape
        dtype: Output data type

    Returns:
        Random samples from the double-sided Maxwell distribution
    """

def maxwell(key: Array, shape=(), dtype=float) -> Array:
    """
    Sample from the Maxwell distribution.

    Args:
        key: PRNG key
        shape: Output shape
        dtype: Output data type

    Returns:
        Random samples from the Maxwell distribution
    """

def rayleigh(key: Array, scale=1.0, shape=(), dtype=float) -> Array:
    """
    Sample from a Rayleigh distribution.

    Args:
        key: PRNG key
        scale: Scale parameter
        shape: Output shape
        dtype: Output data type

    Returns:
        Random samples from Rayleigh(scale)
    """

def wald(key: Array, mean: Array, shape=None, dtype=float) -> Array:
    """
    Sample from a Wald (inverse Gaussian) distribution.

    Args:
        key: PRNG key
        mean: Mean parameter
        shape: Output shape (None: shape of mean)
        dtype: Output data type

    Returns:
        Random samples from the Wald distribution
    """

def weibull_min(
    key: Array,
    scale: Array,
    concentration: Array,
    shape=(),
    dtype=float,
) -> Array:
    """
    Sample from a Weibull minimum distribution.

    Note the argument order: scale comes before concentration.

    Args:
        key: PRNG key
        scale: Scale parameter
        concentration: Shape (concentration) parameter
        shape: Output shape
        dtype: Output data type

    Returns:
        Random samples from the Weibull minimum distribution
    """

def gumbel(key: Array, shape=(), dtype=float) -> Array:
    """
    Sample from the standard Gumbel distribution.

    Args:
        key: PRNG key
        shape: Output shape
        dtype: Output data type

    Returns:
        Random samples from Gumbel(0, 1)
    """

def chisquare(key: Array, df: Array, shape=None, dtype=float) -> Array:
    """
    Sample from a chi-square distribution.

    Args:
        key: PRNG key
        df: Degrees of freedom
        shape: Output shape (None: shape of df)
        dtype: Output data type

    Returns:
        Random samples from chi-square(df)
    """

def dirichlet(
    key: Array,
    alpha: Array,
    shape=None,
    dtype=float,
) -> Array:
    """
    Sample from a Dirichlet distribution.

    Args:
        key: PRNG key
        alpha: Concentration parameters (last axis indexes categories)
        shape: Batch shape (None: alpha.shape[:-1])
        dtype: Output data type

    Returns:
        Random samples of shape shape + alpha.shape[-1:]
    """

def f(key: Array, dfnum: Array, dfden: Array, shape=None, dtype=float) -> Array:
    """
    Sample from an F-distribution.

    Args:
        key: PRNG key
        dfnum: Numerator degrees of freedom
        dfden: Denominator degrees of freedom
        shape: Output shape (None: broadcast shape of dfnum and dfden)
        dtype: Output data type

    Returns:
        Random samples from the F-distribution
    """

def t(key: Array, df: Array, shape=(), dtype=float) -> Array:
    """
    Sample from Student's t-distribution.

    Args:
        key: PRNG key
        df: Degrees of freedom
        shape: Output shape
        dtype: Output data type

    Returns:
        Random samples from the t-distribution
    """

def triangular(
    key: Array,
    left: Array,
    mode: Array,
    right: Array,
    shape=None,
    dtype=float,
) -> Array:
    """
    Sample from a triangular distribution.

    Args:
        key: PRNG key
        left: Left boundary
        mode: Mode (peak) value
        right: Right boundary
        shape: Output shape (None: broadcast shape of left, mode and right)
        dtype: Output data type

    Returns:
        Random samples from the triangular distribution
    """

def generalized_normal(
    key: Array,
    p: Array,
    shape=(),
    dtype=float,
) -> Array:
    """
    Sample from a generalized normal distribution.

    Args:
        key: PRNG key
        p: Shape parameter
        shape: Output shape
        dtype: Output data type

    Returns:
        Random samples from the generalized normal distribution
    """

def loggamma(key: Array, a: Array, shape=None, dtype=float) -> Array:
    """
    Sample log-gamma random variables (log of Gamma(a, 1) samples).

    Args:
        key: PRNG key
        a: Shape parameter
        shape: Output shape (None: shape of a)
        dtype: Output data type

    Returns:
        Random samples from the log-gamma distribution
    """

```

525

526

### Discrete Distributions

527

528

Sample from discrete probability distributions.

529

530

```python { .api }

531

def bernoulli(key: Array, p=0.5, shape=(), dtype=int) -> Array:

532

"""

533

Sample from Bernoulli distribution.

534

535

Args:

536

key: PRNG key

537

p: Success probability

538

shape: Output shape

539

dtype: Output data type

540

541

Returns:

542

Random samples from Bernoulli(p)

543

"""

544

545

def binomial(key: Array, n: Array, p: Array, shape=(), dtype=int) -> Array:

546

"""

547

Sample from binomial distribution.

548

549

Args:

550

key: PRNG key

551

n: Number of trials

552

p: Success probability per trial

553

shape: Output shape

554

dtype: Output data type

555

556

Returns:

557

Random samples from Binomial(n, p)

558

"""

559

560

def categorical(

561

key: Array,

562

logits: Array,

563

axis=-1,

564

shape=None

565

) -> Array:

566

"""

567

Sample from categorical distribution.

568

569

Args:

570

key: PRNG key

571

logits: Log-probability array

572

axis: Axis over which to normalize

573

shape: Output shape

574

575

Returns:

576

Random categorical indices

577

"""

578

579

def choice(

580

key: Array,

581

a: int | Array,

582

shape=(),

583

replace=True,

584

p=None,

585

axis=0

586

) -> Array:

587

"""

588

Random choice from array elements.

589

590

Args:

591

key: PRNG key

592

a: Array to sample from or integer (range)

593

shape: Output shape

594

replace: Whether to sample with replacement

595

p: Probabilities for each element

596

axis: Axis to sample along

597

598

Returns:

599

Random samples from input array

600

"""

601

602

def geometric(key: Array, p: Array, shape=(), dtype=int) -> Array:

603

"""

604

Sample from geometric distribution.

605

606

Args:

607

key: PRNG key

608

p: Success probability

609

shape: Output shape

610

dtype: Output data type

611

612

Returns:

613

Random samples from Geometric(p)

614

"""

615

616

def poisson(key: Array, lam: Array, shape=(), dtype=int) -> Array:

617

"""

618

Sample from Poisson distribution.

619

620

Args:

621

key: PRNG key

622

lam: Rate parameter

623

shape: Output shape

624

dtype: Output data type

625

626

Returns:

627

Random samples from Poisson(lam)

628

"""

629

630

def multinomial(

631

key: Array,

632

n: Array,

633

pvals: Array,

634

shape=(),

635

dtype=int

636

) -> Array:

637

"""

638

Sample from multinomial distribution.

639

640

Args:

641

key: PRNG key

642

n: Number of trials

643

pvals: Probability values for each category

644

shape: Batch shape

645

dtype: Output data type

646

647

Returns:

648

Random samples from Multinomial(n, pvals)

649

"""

650

651

def randint(

652

key: Array,

653

minval: int,

654

maxval: int,

655

shape=(),

656

dtype=int

657

) -> Array:

658

"""

659

Sample random integers from [minval, maxval).

660

661

Args:

662

key: PRNG key

663

minval: Minimum value (inclusive)

664

maxval: Maximum value (exclusive)

665

shape: Output shape

666

dtype: Output data type

667

668

Returns:

669

Random integers in specified range

670

"""

671

672

def rademacher(key: Array, shape=(), dtype=int) -> Array:

673

"""

674

Sample from Rademacher distribution (±1 with equal probability).

675

676

Args:

677

key: PRNG key

678

shape: Output shape

679

dtype: Output data type

680

681

Returns:

682

Random samples from {-1, +1}

683

"""

684

```

685

686

### Specialized Sampling

687

688

Special sampling functions for geometric shapes and structured sampling.

689

690

```python { .api }

691

def ball(key: Array, d: int, p=2, shape=(), dtype=float) -> Array:

692

"""

693

Sample uniformly from d-dimensional unit ball.

694

695

Args:

696

key: PRNG key

697

d: Dimension of ball

698

p: Norm type (default: 2 for Euclidean)

699

shape: Batch shape

700

dtype: Output data type

701

702

Returns:

703

Random samples from unit ball

704

"""

705

706

def orthogonal(key: Array, n: int, shape=(), dtype=float) -> Array:

707

"""

708

Sample random orthogonal matrix.

709

710

Args:

711

key: PRNG key

712

n: Matrix dimension

713

shape: Batch shape

714

dtype: Output data type

715

716

Returns:

717

Random orthogonal matrix of size (n, n)

718

"""

719

720

def permutation(key: Array, x: int | Array, axis=0, independent=False) -> Array:

721

"""

722

Generate random permutation of array or integers.

723

724

Args:

725

key: PRNG key

726

x: Array to permute or integer (range)

727

axis: Axis to permute along

728

independent: Whether to permute each batch element independently

729

730

Returns:

731

Randomly permuted array

732

"""

733

734

def bits(key: Array, width=64, shape=(), dtype=None) -> Array:

735

"""

736

Generate random bits.

737

738

Args:

739

key: PRNG key

740

width: Number of bits per sample

741

shape: Output shape

742

dtype: Output data type

743

744

Returns:

745

Random bit patterns

746

"""

747

```

748

749

## Usage Examples

750

751

Common patterns for JAX random number generation:

752

753

```python

import jax
import jax.numpy as jnp
import jax.random as jr

# Create a root key and split it into one fresh key per sampling call.
# Never pass the same key to two sampling calls: identical keys produce
# identical (fully correlated) samples.
main_key = jr.key(42)
(sample_key, int_key, batch_key, vmap_key, perm_key,
 choice_key, mvn_key, coin_key, dice_key) = jr.split(main_key, 9)

# Basic sampling
samples = jr.normal(sample_key, (1000,))
# randint takes the shape before the bounds: randint(key, shape, minval, maxval)
random_ints = jr.randint(int_key, (100,), 0, 10)

# Batch sampling: one key and one call produce the whole (32, 784) array
batch_samples = jr.normal(batch_key, (32, 784))  # 32 samples of 784 dims

# Different keys for each batch element
per_example_keys = jr.split(vmap_key, 32)
independent_samples = jax.vmap(
    lambda k: jr.normal(k, (784,))
)(per_example_keys)

# Random choice and permutation
data = jnp.arange(100)
shuffled = jr.permutation(perm_key, data)
selected = jr.choice(choice_key, data, (10,), replace=False)

# Multivariate distributions
mean = jnp.zeros(5)
cov = jnp.eye(5)
mv_samples = jr.multivariate_normal(mvn_key, mean, cov, (1000,))  # (1000, 5)

# Discrete distributions
coin_flips = jr.bernoulli(coin_key, 0.6, (100,))  # boolean array
# categorical's third positional parameter is `axis`, so pass shape by keyword
dice_rolls = jr.categorical(dice_key, jnp.log(jnp.ones(6) / 6), shape=(100,))


# Using in neural network initialization
def init_layer_weights(key, input_dim, output_dim):
    """Return (weights, biases) for a dense layer, drawn from `key`.

    The key is split so weights and biases use independent randomness.
    """
    w_key, b_key = jr.split(key)
    # Xavier/Glorot normal initialization: std = sqrt(2 / (fan_in + fan_out))
    std = jnp.sqrt(2.0 / (input_dim + output_dim))
    weights = jr.normal(w_key, (input_dim, output_dim)) * std
    biases = jr.normal(b_key, (output_dim,)) * 0.01
    return weights, biases


# Stochastic gradient descent with random batching
def get_random_batch(key, data, batch_size):
    """Return `batch_size` distinct rows of `data`, chosen without replacement."""
    indices = jr.choice(key, len(data), (batch_size,), replace=False)
    return data[indices]

```