or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

core-transformations.md, device-memory.md, experimental.md, index.md, low-level-ops.md, neural-networks.md, numpy-compatibility.md, random-numbers.md, scipy-compatibility.md, tree-operations.md

docs/numpy-compatibility.md

0

# NumPy Compatibility API

1

2

JAX provides a comprehensive NumPy-compatible API through `jax.numpy` (commonly imported as `jnp`). JAX arrays are immutable, so NumPy's in-place operations are replaced by functional equivalents that return new arrays; in exchange, most of the NumPy API is available with the added benefits of JIT compilation, automatic differentiation, and device acceleration.

3

4

## Core Imports

5

6

```python

7

import jax.numpy as jnp

8

import jax

9

```

10

11

## Capabilities

12

13

### Array Creation

14

15

Create JAX arrays from various data sources and specifications.

16

17

```python { .api }

18

def array(object, dtype=None, copy=None, order=None, ndmin=0) -> Array:

19

"""Create array from array-like object."""

20

21

def asarray(a, dtype=None, order=None) -> Array:

22

"""Convert input to array."""

23

24

def zeros(shape, dtype=None) -> Array:

25

"""Create array filled with zeros."""

26

27

def zeros_like(a, dtype=None, shape=None) -> Array:

28

"""Create array of zeros with the same shape and (by default) dtype as input."""

29

30

def ones(shape, dtype=None) -> Array:

31

"""Create array filled with ones."""

32

33

def ones_like(a, dtype=None, shape=None) -> Array:

34

"""Create array of ones with the same shape and (by default) dtype as input."""

35

36

def full(shape, fill_value, dtype=None) -> Array:

37

"""Create array filled with constant value."""

38

39

def full_like(a, fill_value, dtype=None, shape=None) -> Array:

40

"""Create array filled with fill_value, with the same shape and (by default) dtype as input."""

41

42

def empty(shape, dtype=None) -> Array:

43

"""Create array without specifying its contents (XLA cannot leave memory uninitialized, so JAX returns an array of zeros)."""

44

45

def empty_like(a, dtype=None, shape=None) -> Array:

46

"""Create array with the same shape and (by default) dtype as input, without specifying its contents (JAX returns zeros)."""

47

48

def eye(N, M=None, k=0, dtype=None) -> Array:

49

"""Create identity matrix."""

50

51

def identity(n, dtype=None) -> Array:

52

"""Create square identity matrix."""

53

54

def arange(start, stop=None, step=None, dtype=None) -> Array:

55

"""Create evenly spaced values within interval."""

56

57

def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0) -> Array:

58

"""Create evenly spaced numbers over interval."""

59

60

def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0) -> Array:

61

"""Create numbers spaced evenly on log scale."""

62

63

def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0) -> Array:

64

"""Create numbers spaced evenly on log scale (geometric progression)."""

65

66

def meshgrid(*xi, copy=True, sparse=False, indexing='xy') -> list[Array]:

67

"""Create coordinate matrices from coordinate vectors."""

68

69

def mgrid() -> MGridClass:

70

"""Dense multi-dimensional mesh creation. Indexable object, not a callable: use jnp.mgrid[0:3, 0:3]."""

71

72

def ogrid() -> OGridClass:

73

"""Open (sparse) multi-dimensional mesh creation. Indexable object, not a callable: use jnp.ogrid[0:3, 0:3]."""

74

75

def indices(dimensions, dtype=int, sparse=False) -> Array:

76

"""Create arrays of indices."""

77

78

def tri(N, M=None, k=0, dtype=None) -> Array:

79

"""Create array with ones at and below diagonal."""

80

```

81

82

### Mathematical Functions

83

84

Element-wise mathematical operations following NumPy conventions.

85

86

```python { .api }

87

# Arithmetic operations

88

def add(x1, x2) -> Array: ...

89

def subtract(x1, x2) -> Array: ...

90

def multiply(x1, x2) -> Array: ...

91

def divide(x1, x2) -> Array: ...

92

def true_divide(x1, x2) -> Array: ...

93

def floor_divide(x1, x2) -> Array: ...

94

def power(x1, x2) -> Array: ...

95

def float_power(x1, x2) -> Array: ...

96

def mod(x1, x2) -> Array: ...

97

def remainder(x1, x2) -> Array: ...

98

def divmod(x1, x2) -> tuple[Array, Array]: ...

99

100

# Trigonometric functions

101

def sin(x) -> Array: ...

102

def cos(x) -> Array: ...

103

def tan(x) -> Array: ...

104

def asin(x) -> Array: ...

105

def acos(x) -> Array: ...

106

def atan(x) -> Array: ...

107

def atan2(x1, x2) -> Array: ...

108

def sinh(x) -> Array: ...

109

def cosh(x) -> Array: ...

110

def tanh(x) -> Array: ...

111

def asinh(x) -> Array: ...

112

def acosh(x) -> Array: ...

113

def atanh(x) -> Array: ...

114

def degrees(x) -> Array: ...

115

def radians(x) -> Array: ...

116

def deg2rad(x) -> Array: ...

117

def rad2deg(x) -> Array: ...

118

119

# Exponential and logarithmic

120

def exp(x) -> Array: ...

121

def exp2(x) -> Array: ...

122

def expm1(x) -> Array: ...

123

def log(x) -> Array: ...

124

def log10(x) -> Array: ...

125

def log2(x) -> Array: ...

126

def log1p(x) -> Array: ...

127

128

# Rounding and precision

129

def round(a, decimals=0) -> Array: ...

130

def rint(x) -> Array: ...

131

def fix(x) -> Array: ...

132

def floor(x) -> Array: ...

133

def ceil(x) -> Array: ...

134

def trunc(x) -> Array: ...

135

136

# Arithmetic functions

137

def abs(x) -> Array: ...

138

def absolute(x) -> Array: ...

139

def fabs(x) -> Array: ...

140

def sign(x) -> Array: ...

141

def signbit(x) -> Array: ...

142

def copysign(x1, x2) -> Array: ...

143

def sqrt(x) -> Array: ...

144

def square(x) -> Array: ...

145

def cbrt(x) -> Array: ...

146

def reciprocal(x) -> Array: ...

147

def positive(x) -> Array: ...

148

def negative(x) -> Array: ...

149

150

# Extrema functions

151

def maximum(x1, x2) -> Array: ...

152

def minimum(x1, x2) -> Array: ...

153

def fmax(x1, x2) -> Array: ...

154

def fmin(x1, x2) -> Array: ...

155

def clip(a, a_min=None, a_max=None) -> Array: ...

156

157

# Complex number functions

158

def real(val) -> Array: ...

159

def imag(val) -> Array: ...

160

def conj(x) -> Array: ...

161

def conjugate(x) -> Array: ...

162

def angle(z, deg=False) -> Array: ...

163

def isreal(x) -> Array: ...

164

def iscomplex(x) -> Array: ...

165

166

# Floating point functions

167

def isfinite(x) -> Array: ...

168

def isinf(x) -> Array: ...

169

def isnan(x) -> Array: ...

170

def isneginf(x) -> Array: ...

171

def isposinf(x) -> Array: ...

172

def nextafter(x1, x2) -> Array: ...

173

def spacing(x) -> Array: ...

174

def modf(x) -> tuple[Array, Array]: ...

175

def frexp(x) -> tuple[Array, Array]: ...

176

def ldexp(x1, x2) -> Array: ...

177

```

178

179

### Array Manipulation

180

181

Functions for reshaping, combining, and transforming arrays.

182

183

```python { .api }

184

# Shape manipulation

185

def reshape(a, newshape, order='C') -> Array: ...

186

def ravel(a, order='C') -> Array: ...

187

def flatten(a, order='C') -> Array: ...  # available only as the array method a.flatten(order); use jnp.ravel(a) as a function

188

189

# Transpose operations

190

def transpose(a, axes=None) -> Array: ...

191

def swapaxes(a, axis1, axis2) -> Array: ...

192

def moveaxis(a, source, destination) -> Array: ...

193

def rollaxis(a, axis, start=0) -> Array: ...

194

195

# Dimension manipulation

196

def expand_dims(a, axis) -> Array: ...

197

def squeeze(a, axis=None) -> Array: ...

198

199

# Array reversal and rotation

200

def flip(m, axis=None) -> Array: ...

201

def fliplr(m) -> Array: ...

202

def flipud(m) -> Array: ...

203

def rot90(m, k=1, axes=(0, 1)) -> Array: ...

204

def roll(a, shift, axis=None) -> Array: ...

205

206

# Broadcasting

207

def broadcast_to(array, shape) -> Array: ...

208

def broadcast_arrays(*args) -> list[Array]: ...

209

210

# Joining arrays

211

def concatenate(arrays, axis=0) -> Array: ...

212

def stack(arrays, axis=0) -> Array: ...

213

def vstack(tup) -> Array: ...

214

def hstack(tup) -> Array: ...

215

def dstack(tup) -> Array: ...

216

def column_stack(tup) -> Array: ...

217

def append(arr, values, axis=None) -> Array: ...

218

219

# Splitting arrays

220

def split(ary, indices_or_sections, axis=0) -> list[Array]: ...

221

def array_split(ary, indices_or_sections, axis=0) -> list[Array]: ...

222

def hsplit(ary, indices_or_sections) -> list[Array]: ...

223

def vsplit(ary, indices_or_sections) -> list[Array]: ...

224

def dsplit(ary, indices_or_sections) -> list[Array]: ...

225

226

# Tiling and repeating

227

def tile(A, reps) -> Array: ...

228

def repeat(a, repeats, axis=None) -> Array: ...

229

230

# Array modification

231

def insert(arr, obj, values, axis=None) -> Array: ...

232

def delete(arr, obj, axis=None) -> Array: ...

233

def place(arr, mask, vals, *, inplace=True) -> Array: ...  # arrays are immutable: pass inplace=False; returns an updated copy

234

def put(a, ind, v, mode=None, *, inplace=True) -> Array: ...  # arrays are immutable: pass inplace=False; returns an updated copy

235

def put_along_axis(arr, indices, values, axis, inplace=True) -> Array: ...  # arrays are immutable: pass inplace=False; returns an updated copy

236

237

def unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None, equal_nan=True) -> Array: ...

238

```

239

240

### Indexing and Selection

241

242

Advanced indexing, selection, and conditional operations.

243

244

```python { .api }

245

def take(a, indices, axis=None, mode=None) -> Array:

246

"""Take elements from array along axis."""

247

248

def take_along_axis(arr, indices, axis) -> Array:

249

"""Take values from array using indices along axis."""

250

251

def choose(a, choices, mode='raise') -> Array:

252

"""Construct array from index array and choice arrays."""

253

254

def compress(condition, a, axis=None) -> Array:

255

"""Return selected slices along axis."""

256

257

def extract(condition, arr) -> Array:

258

"""Return elements satisfying condition."""

259

260

def select(condlist, choicelist, default=0) -> Array:

261

"""Return elements chosen from choicelist based on conditions."""

262

263

def where(condition, x=None, y=None) -> Array:

264

"""Return elements chosen from x or y based on condition."""

265

266

def nonzero(a) -> tuple[Array, ...]:

267

"""Return indices of non-zero elements."""

268

269

def argwhere(a) -> Array:

270

"""Return indices of non-zero elements, one row of coordinates per element."""

271

272

def flatnonzero(a) -> Array:

273

"""Return indices of flattened array that are non-zero."""

274

275

def ix_(*args) -> tuple[Array, ...]:

276

"""Construct open mesh from multiple sequences."""

277

```

278

279

### Reduction Operations

280

281

Functions that reduce arrays along axes or compute aggregates.

282

283

```python { .api }

284

# Basic reductions

285

def sum(a, axis=None, dtype=None, keepdims=False, initial=None, where=None) -> Array: ...

286

def prod(a, axis=None, dtype=None, keepdims=False, initial=None, where=None) -> Array: ...

287

def mean(a, axis=None, dtype=None, keepdims=False, where=None) -> Array: ...

288

def median(a, axis=None, keepdims=False) -> Array: ...

289

def std(a, axis=None, dtype=None, ddof=0, keepdims=False, where=None) -> Array: ...

290

def var(a, axis=None, dtype=None, ddof=0, keepdims=False, where=None) -> Array: ...

291

292

# Extrema

293

def min(a, axis=None, keepdims=False, initial=None, where=None) -> Array: ...

294

def max(a, axis=None, keepdims=False, initial=None, where=None) -> Array: ...

295

def amin(a, axis=None, keepdims=False, initial=None, where=None) -> Array: ...

296

def amax(a, axis=None, keepdims=False, initial=None, where=None) -> Array: ...

297

def ptp(a, axis=None, keepdims=False) -> Array: ...

298

299

# Percentiles and quantiles

300

def percentile(a, q, axis=None, method='linear', keepdims=False) -> Array: ...

301

def quantile(a, q, axis=None, method='linear', keepdims=False) -> Array: ...

302

303

# Cumulative operations

304

def cumsum(a, axis=None, dtype=None) -> Array: ...

305

def cumprod(a, axis=None, dtype=None) -> Array: ...

306

307

# Logical reductions

308

def all(a, axis=None, keepdims=False, where=None) -> Array: ...

309

def any(a, axis=None, keepdims=False, where=None) -> Array: ...

310

311

# Counting

312

def count_nonzero(a, axis=None, keepdims=False) -> Array: ...

313

314

# NaN-aware reductions

315

def nansum(a, axis=None, dtype=None, keepdims=False, where=None) -> Array: ...

316

def nanprod(a, axis=None, dtype=None, keepdims=False, where=None) -> Array: ...

317

def nanmean(a, axis=None, dtype=None, keepdims=False, where=None) -> Array: ...

318

def nanmedian(a, axis=None, keepdims=False) -> Array: ...

319

def nanstd(a, axis=None, dtype=None, ddof=0, keepdims=False, where=None) -> Array: ...

320

def nanvar(a, axis=None, dtype=None, ddof=0, keepdims=False, where=None) -> Array: ...

321

def nanmin(a, axis=None, keepdims=False, initial=None, where=None) -> Array: ...

322

def nanmax(a, axis=None, keepdims=False, initial=None, where=None) -> Array: ...

323

def nanpercentile(a, q, axis=None, method='linear', keepdims=False) -> Array: ...

324

def nanquantile(a, q, axis=None, method='linear', keepdims=False) -> Array: ...

325

def nancumsum(a, axis=None, dtype=None) -> Array: ...

326

def nancumprod(a, axis=None, dtype=None) -> Array: ...

327

328

# Indices of extrema

329

def argmin(a, axis=None, keepdims=False) -> Array: ...

330

def argmax(a, axis=None, keepdims=False) -> Array: ...

331

def nanargmin(a, axis=None, keepdims=False) -> Array: ...

332

def nanargmax(a, axis=None, keepdims=False) -> Array: ...

333

```

334

335

### Linear Algebra

336

337

Core linear algebra operations for matrix computations.

338

339

```python { .api }

340

# Matrix multiplication

341

def dot(a, b) -> Array: ...

342

def matmul(x1, x2) -> Array: ...

343

def inner(a, b) -> Array: ...

344

def outer(a, b) -> Array: ...

345

def tensordot(a, b, axes=2) -> Array: ...

346

def kron(a, b) -> Array: ...

347

348

# Vector operations

349

def vdot(a, b) -> Array: ...

350

def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None) -> Array: ...

351

352

# Matrix operations

353

def trace(a, offset=0, axis1=0, axis2=1, dtype=None) -> Array: ...

354

def diagonal(a, offset=0, axis1=0, axis2=1) -> Array: ...

355

def diag(v, k=0) -> Array: ...

356

def diagflat(v, k=0) -> Array: ...

357

358

# Triangular matrices

359

def tril(m, k=0) -> Array: ...

360

def triu(m, k=0) -> Array: ...

361

def tril_indices(n, k=0, m=None) -> tuple[Array, Array]: ...

362

def triu_indices(n, k=0, m=None) -> tuple[Array, Array]: ...

363

def diag_indices(n, ndim=2) -> tuple[Array, ...]: ...

364

365

# Matrix transpose

366

def matrix_transpose(x) -> Array: ...

367

```

368

369

### Sorting and Searching

370

371

Functions for sorting arrays and searching for values.

372

373

```python { .api }

374

def sort(a, axis=-1, kind='stable', order=None) -> Array: ...

375

def argsort(a, axis=-1, kind='stable', order=None) -> Array: ...

376

def lexsort(keys, axis=-1) -> Array: ...

377

def partition(a, kth, axis=-1, kind='introselect', order=None) -> Array: ...

378

def argpartition(a, kth, axis=-1, kind='introselect', order=None) -> Array: ...

379

def searchsorted(a, v, side='left', sorter=None) -> Array: ...

380

def sort_complex(a) -> Array: ...

381

```

382

383

### Set Operations

384

385

Set-like operations on arrays.

386

387

```python { .api }

388

def unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None, equal_nan=True) -> Array: ...

389

def intersect1d(ar1, ar2, assume_unique=False, return_indices=False) -> Array: ...

390

def union1d(ar1, ar2) -> Array: ...

391

def setdiff1d(ar1, ar2, assume_unique=False) -> Array: ...

392

def setxor1d(ar1, ar2, assume_unique=False) -> Array: ...

393

def isin(element, test_elements, assume_unique=False, invert=False) -> Array: ...

394

```

395

396

### Statistical Functions

397

398

Statistical analysis and distribution functions.

399

400

```python { .api }

401

def bincount(x, weights=None, minlength=0, length=None) -> Array: ...

402

def histogram(a, bins=10, range=None, weights=None, density=None) -> tuple[Array, Array]: ...

403

def histogram2d(x, y, bins=10, range=None, weights=None, density=None) -> tuple[Array, Array, Array]: ...

404

def histogramdd(sample, bins=10, range=None, weights=None, density=None) -> tuple[Array, list[Array]]: ...

405

def histogram_bin_edges(a, bins=10, range=None, weights=None) -> Array: ...

406

def digitize(x, bins, right=False) -> Array: ...

407

def average(a, axis=None, weights=None, returned=False, keepdims=False) -> Array: ...

408

def corrcoef(x, y=None, rowvar=True, dtype=None) -> Array: ...

409

def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=None, dtype=None) -> Array: ...

410

def gradient(f, *varargs, axis=None, edge_order=1) -> Array: ...

411

```

412

413

### Data Types and Conversion

414

415

Type information, checking, and conversion functions.

416

417

```python { .api }

418

# Type checking

419

def issubdtype(arg1, arg2) -> bool: ...

420

def can_cast(from_, to, casting='safe') -> bool: ...

421

def result_type(*arrays_and_dtypes): ...

422

def promote_types(type1, type2): ...

423

def isscalar(element) -> bool: ...

424

def isrealobj(x) -> bool: ...

425

def iscomplexobj(x) -> bool: ...

426

427

# Type information

428

def finfo(dtype): ...

429

def iinfo(dtype): ...

430

431

# Array properties

432

def ndim(a) -> int: ...

433

def shape(a) -> tuple: ...

434

def size(a) -> int: ...

435

436

# Comparison functions

437

def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False) -> bool: ...

438

def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False) -> Array: ...

439

def array_equal(a1, a2, equal_nan=False) -> bool: ...

440

def array_equiv(a1, a2) -> bool: ...

441

442

# Utility functions

443

def copy(a, order='K') -> Array: ...

444

def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None) -> Array: ...

445

```

446

447

### Comparison Operations

448

449

Element-wise comparison functions returning boolean arrays.

450

451

```python { .api }

452

def equal(x1, x2) -> Array: ...

453

def not_equal(x1, x2) -> Array: ...

454

def less(x1, x2) -> Array: ...

455

def less_equal(x1, x2) -> Array: ...

456

def greater(x1, x2) -> Array: ...

457

def greater_equal(x1, x2) -> Array: ...

458

```

459

460

### Logical Operations

461

462

Element-wise logical operations on boolean arrays.

463

464

```python { .api }

465

def logical_and(x1, x2) -> Array: ...

466

def logical_or(x1, x2) -> Array: ...

467

def logical_not(x) -> Array: ...

468

def logical_xor(x1, x2) -> Array: ...

469

```

470

471

### Bitwise Operations

472

473

Element-wise bitwise operations on integer arrays.

474

475

```python { .api }

476

def bitwise_and(x1, x2) -> Array: ...

477

def bitwise_or(x1, x2) -> Array: ...

478

def bitwise_xor(x1, x2) -> Array: ...

479

def bitwise_not(x) -> Array: ...

480

def bitwise_left_shift(x1, x2) -> Array: ...

481

def bitwise_right_shift(x1, x2) -> Array: ...

482

def left_shift(x1, x2) -> Array: ...

483

def right_shift(x1, x2) -> Array: ...

484

def invert(x) -> Array: ...

485

def bitwise_count(x) -> Array: ...

486

```

487

488

### Constants and Special Values

489

490

Mathematical and numerical constants.

491

492

```python { .api }

493

pi: float # π (3.14159...)

494

e: float # Euler's number (2.71828...)

495

euler_gamma: float # Euler-Mascheroni constant

496

inf: float # Positive infinity

497

nan: float # Not a Number

498

newaxis: None # Used for adding dimensions in indexing

499

```

500

501

## NumPy Submodules

502

503

### FFT Operations

504

505

```python { .api }

506

import jax.numpy.fft as jfft

507

508

# 1D transforms

509

jfft.fft(a, n=None, axis=-1, norm=None) -> Array

510

jfft.ifft(a, n=None, axis=-1, norm=None) -> Array

511

jfft.rfft(a, n=None, axis=-1, norm=None) -> Array

512

jfft.irfft(a, n=None, axis=-1, norm=None) -> Array

513

514

# 2D transforms

515

jfft.fft2(a, s=None, axes=(-2, -1), norm=None) -> Array

516

jfft.ifft2(a, s=None, axes=(-2, -1), norm=None) -> Array

517

jfft.rfft2(a, s=None, axes=(-2, -1), norm=None) -> Array

518

jfft.irfft2(a, s=None, axes=(-2, -1), norm=None) -> Array

519

520

# N-D transforms

521

jfft.fftn(a, s=None, axes=None, norm=None) -> Array

522

jfft.ifftn(a, s=None, axes=None, norm=None) -> Array

523

jfft.rfftn(a, s=None, axes=None, norm=None) -> Array

524

jfft.irfftn(a, s=None, axes=None, norm=None) -> Array

525

526

# Helper functions

527

jfft.fftfreq(n, d=1.0) -> Array

528

jfft.rfftfreq(n, d=1.0) -> Array

529

jfft.fftshift(x, axes=None) -> Array

530

jfft.ifftshift(x, axes=None) -> Array

531

```

532

533

### Linear Algebra Operations

534

535

```python { .api }

536

import jax.numpy.linalg as jla

537

538

# Matrix decompositions

539

jla.cholesky(a) -> Array

540

jla.qr(a, mode='reduced') -> tuple[Array, Array]

541

jla.svd(a, full_matrices=True, compute_uv=True, hermitian=False) -> tuple[Array, Array, Array]

542

jla.eig(a) -> tuple[Array, Array]

543

jla.eigh(a, UPLO='L') -> tuple[Array, Array]

544

jla.eigvals(a) -> Array

545

jla.eigvalsh(a, UPLO='L') -> Array

546

547

# Matrix properties

548

jla.det(a) -> Array

549

jla.slogdet(a) -> tuple[Array, Array]

550

jla.matrix_rank(M, tol=None, hermitian=False) -> Array

551

jla.trace(x, offset=0, dtype=None) -> Array  # operates on the last two axes; use jnp.trace for axis1/axis2

552

553

# Matrix solutions

554

jla.solve(a, b) -> Array

555

jla.lstsq(a, b, rcond=None) -> tuple[Array, Array, Array, Array]

556

jla.inv(a) -> Array

557

jla.pinv(a, rcond=None, hermitian=False) -> Array

558

559

# Norms and distances

560

jla.norm(x, ord=None, axis=None, keepdims=False) -> Array

561

jla.cond(x, p=None) -> Array

562

563

# Matrix functions

564

jla.matrix_power(a, n) -> Array

565

```