0
# NumPy Compatibility API
1
2
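JAX provides a NumPy-compatible API through `jax.numpy` (commonly imported as `jnp`) that covers most of NumPy, with the added benefits of JIT compilation, automatic differentiation, and device acceleration. Unlike NumPy arrays, JAX arrays are immutable: operations that modify an array in place in NumPy instead return an updated copy in JAX, and indexed assignment is written `x = x.at[idx].set(value)` rather than `x[idx] = value`.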
3
4
## Core Imports
5
6
```python
7
import jax.numpy as jnp
8
import jax
9
```
10
11
## Capabilities
12
13
### Array Creation
14
15
Create JAX arrays from various data sources and specifications.
16
17
```python { .api }
18
def array(object, dtype=None, copy=None, order=None, ndmin=0) -> Array:
19
"""Create array from array-like object."""
20
21
def asarray(a, dtype=None, order=None) -> Array:
22
"""Convert input to array."""
23
24
def zeros(shape, dtype=None) -> Array:
25
"""Create array filled with zeros."""
26
27
def zeros_like(a, dtype=None, shape=None) -> Array:
28
"""Create zeros array with same shape as input."""
29
30
def ones(shape, dtype=None) -> Array:
31
"""Create array filled with ones."""
32
33
def ones_like(a, dtype=None, shape=None) -> Array:
34
"""Create ones array with same shape as input."""
35
36
def full(shape, fill_value, dtype=None) -> Array:
37
"""Create array filled with constant value."""
38
39
def full_like(a, fill_value, dtype=None, shape=None) -> Array:
40
"""Create filled array with same shape as input."""
41
42
def empty(shape, dtype=None) -> Array:
43
"""Create array of given shape; XLA cannot allocate uninitialized memory, so JAX fills it with zeros."""
44
45
def empty_like(a, dtype=None, shape=None) -> Array:
46
"""Create array with same shape as input; JAX has no uninitialized memory, so the result is all zeros."""
47
48
def eye(N, M=None, k=0, dtype=None) -> Array:
49
"""Create identity matrix."""
50
51
def identity(n, dtype=None) -> Array:
52
"""Create square identity matrix."""
53
54
def arange(start, stop=None, step=None, dtype=None) -> Array:
55
"""Create evenly spaced values within interval."""
56
57
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0) -> Array:
58
"""Create evenly spaced numbers over interval."""
59
60
def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0) -> Array:
61
"""Create numbers spaced evenly on log scale."""
62
63
def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0) -> Array:
64
"""Create numbers spaced evenly on log scale (geometric progression)."""
65
66
def meshgrid(*xi, copy=True, sparse=False, indexing='xy') -> list[Array]:
67
"""Create coordinate matrices from coordinate vectors."""
68
69
mgrid: MGridClass  # Not callable: index with slices, e.g. jnp.mgrid[0:3, 0:4]
70
"""Multi-dimensional mesh creation."""
71
72
ogrid: OGridClass  # Not callable: index with slices, e.g. jnp.ogrid[0:3, 0:4]
73
"""Open multi-dimensional mesh creation."""
74
75
def indices(dimensions, dtype=int, sparse=False) -> Array:
76
"""Create arrays of indices."""
77
78
def tri(N, M=None, k=0, dtype=None) -> Array:
79
"""Create array with ones at and below diagonal."""
80
```
81
82
### Mathematical Functions
83
84
Element-wise mathematical operations following NumPy conventions.
85
86
```python { .api }
87
# Arithmetic operations
88
def add(x1, x2) -> Array: ...
89
def subtract(x1, x2) -> Array: ...
90
def multiply(x1, x2) -> Array: ...
91
def divide(x1, x2) -> Array: ...
92
def true_divide(x1, x2) -> Array: ...
93
def floor_divide(x1, x2) -> Array: ...
94
def power(x1, x2) -> Array: ...
95
def float_power(x1, x2) -> Array: ...
96
def mod(x1, x2) -> Array: ...
97
def remainder(x1, x2) -> Array: ...
98
def divmod(x1, x2) -> tuple[Array, Array]: ...
99
100
# Trigonometric functions
101
def sin(x) -> Array: ...
102
def cos(x) -> Array: ...
103
def tan(x) -> Array: ...
104
def asin(x) -> Array: ...
105
def acos(x) -> Array: ...
106
def atan(x) -> Array: ...
107
def atan2(x1, x2) -> Array: ...
108
def sinh(x) -> Array: ...
109
def cosh(x) -> Array: ...
110
def tanh(x) -> Array: ...
111
def asinh(x) -> Array: ...
112
def acosh(x) -> Array: ...
113
def atanh(x) -> Array: ...
114
def degrees(x) -> Array: ...
115
def radians(x) -> Array: ...
116
def deg2rad(x) -> Array: ...
117
def rad2deg(x) -> Array: ...
118
119
# Exponential and logarithmic
120
def exp(x) -> Array: ...
121
def exp2(x) -> Array: ...
122
def expm1(x) -> Array: ...
123
def log(x) -> Array: ...
124
def log10(x) -> Array: ...
125
def log2(x) -> Array: ...
126
def log1p(x) -> Array: ...
127
128
# Rounding and precision
129
def round(a, decimals=0) -> Array: ...
130
def rint(x) -> Array: ...
131
def fix(x) -> Array: ...
132
def floor(x) -> Array: ...
133
def ceil(x) -> Array: ...
134
def trunc(x) -> Array: ...
135
136
# Arithmetic functions
137
def abs(x) -> Array: ...
138
def absolute(x) -> Array: ...
139
def fabs(x) -> Array: ...
140
def sign(x) -> Array: ...
141
def signbit(x) -> Array: ...
142
def copysign(x1, x2) -> Array: ...
143
def sqrt(x) -> Array: ...
144
def square(x) -> Array: ...
145
def cbrt(x) -> Array: ...
146
def reciprocal(x) -> Array: ...
147
def positive(x) -> Array: ...
148
def negative(x) -> Array: ...
149
150
# Extrema functions
151
def maximum(x1, x2) -> Array: ...
152
def minimum(x1, x2) -> Array: ...
153
def fmax(x1, x2) -> Array: ...
154
def fmin(x1, x2) -> Array: ...
155
def clip(a, a_min=None, a_max=None) -> Array: ...
156
157
# Complex number functions
158
def real(val) -> Array: ...
159
def imag(val) -> Array: ...
160
def conj(x) -> Array: ...
161
def conjugate(x) -> Array: ...
162
def angle(z, deg=False) -> Array: ...
163
def isreal(x) -> Array: ...
164
def iscomplex(x) -> Array: ...
165
166
# Floating point functions
167
def isfinite(x) -> Array: ...
168
def isinf(x) -> Array: ...
169
def isnan(x) -> Array: ...
170
def isneginf(x) -> Array: ...
171
def isposinf(x) -> Array: ...
172
def nextafter(x1, x2) -> Array: ...
173
def spacing(x) -> Array: ...
174
def modf(x) -> tuple[Array, Array]: ...
175
def frexp(x) -> tuple[Array, Array]: ...
176
def ldexp(x1, x2) -> Array: ...
177
```
178
179
### Array Manipulation
180
181
Functions for reshaping, combining, and transforming arrays.
182
183
```python { .api }
184
# Shape manipulation
185
def reshape(a, newshape, order='C') -> Array: ...
186
def ravel(a, order='C') -> Array: ...
187
Array.flatten(order='C') -> Array  # method only; jnp has no flatten function (use jnp.ravel)
188
189
# Transpose operations
190
def transpose(a, axes=None) -> Array: ...
191
def swapaxes(a, axis1, axis2) -> Array: ...
192
def moveaxis(a, source, destination) -> Array: ...
193
def rollaxis(a, axis, start=0) -> Array: ...
194
195
# Dimension manipulation
196
def expand_dims(a, axis) -> Array: ...
197
def squeeze(a, axis=None) -> Array: ...
198
199
# Array reversal and rotation
200
def flip(m, axis=None) -> Array: ...
201
def fliplr(m) -> Array: ...
202
def flipud(m) -> Array: ...
203
def rot90(m, k=1, axes=(0, 1)) -> Array: ...
204
def roll(a, shift, axis=None) -> Array: ...
205
206
# Broadcasting
207
def broadcast_to(array, shape) -> Array: ...
208
def broadcast_arrays(*args) -> list[Array]: ...
209
210
# Joining arrays
211
def concatenate(arrays, axis=0) -> Array: ...
212
def stack(arrays, axis=0) -> Array: ...
213
def vstack(tup) -> Array: ...
214
def hstack(tup) -> Array: ...
215
def dstack(tup) -> Array: ...
216
def column_stack(tup) -> Array: ...
217
def append(arr, values, axis=None) -> Array: ...
218
219
# Splitting arrays
220
def split(ary, indices_or_sections, axis=0) -> list[Array]: ...
221
def array_split(ary, indices_or_sections, axis=0) -> list[Array]: ...
222
def hsplit(ary, indices_or_sections) -> list[Array]: ...
223
def vsplit(ary, indices_or_sections) -> list[Array]: ...
224
def dsplit(ary, indices_or_sections) -> list[Array]: ...
225
226
# Tiling and repeating
227
def tile(A, reps) -> Array: ...
228
def repeat(a, repeats, axis=None) -> Array: ...
229
230
# Array modification
231
def insert(arr, obj, values, axis=None) -> Array: ...
232
def delete(arr, obj, axis=None) -> Array: ...
233
def place(arr, mask, vals, *, inplace=True) -> Array: ...  # must pass inplace=False; returns an updated copy
234
def put(a, ind, v, mode=None, *, inplace=True) -> Array: ...  # must pass inplace=False; returns an updated copy
235
def put_along_axis(arr, indices, values, axis, inplace=True, *, mode=None) -> Array: ...  # must pass inplace=False; returns an updated copy
236
237
def unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None, *, equal_nan=True, size=None, fill_value=None) -> Array | tuple[Array, ...]: ...
238
```
239
240
### Indexing and Selection
241
242
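Advanced indexing, selection, and conditional operations. Functions whose output shape depends on the input values (`nonzero`, `argwhere`, `flatnonzero`, and single-argument `where`) accept a static `size=` keyword, with an optional `fill_value=`. You must supply `size=` when calling them inside `jax.jit`.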
243
244
```python { .api }
245
def take(a, indices, axis=None, mode=None) -> Array:
246
"""Take elements from array along axis."""
247
248
def take_along_axis(arr, indices, axis) -> Array:
249
"""Take values from array using indices along axis."""
250
251
def choose(a, choices, mode='raise') -> Array:
252
"""Construct array from index array and choice arrays."""
253
254
def compress(condition, a, axis=None) -> Array:
255
"""Return selected slices along axis."""
256
257
def extract(condition, arr) -> Array:
258
"""Return elements satisfying condition."""
259
260
def select(condlist, choicelist, default=0) -> Array:
261
"""Return elements chosen from choicelist based on conditions."""
262
263
def where(condition, x=None, y=None) -> Array:
264
"""Return elements chosen from x or y based on condition."""
265
266
def nonzero(a) -> tuple[Array, ...]:
267
"""Return indices of non-zero elements."""
268
269
def argwhere(a) -> Array:
270
"""Return indices where condition is True."""
271
272
def flatnonzero(a) -> Array:
273
"""Return indices of flattened array that are non-zero."""
274
275
def ix_(*args) -> tuple[Array, ...]:
276
"""Construct open mesh from multiple sequences."""
277
```
278
279
### Reduction Operations
280
281
Functions that reduce arrays along axes or compute aggregates.
282
283
```python { .api }
284
# Basic reductions
285
def sum(a, axis=None, dtype=None, keepdims=False, initial=None, where=None) -> Array: ...
286
def prod(a, axis=None, dtype=None, keepdims=False, initial=None, where=None) -> Array: ...
287
def mean(a, axis=None, dtype=None, keepdims=False, where=None) -> Array: ...
288
def median(a, axis=None, keepdims=False) -> Array: ...
289
def std(a, axis=None, dtype=None, ddof=0, keepdims=False, where=None) -> Array: ...
290
def var(a, axis=None, dtype=None, ddof=0, keepdims=False, where=None) -> Array: ...
291
292
# Extrema
293
def min(a, axis=None, keepdims=False, initial=None, where=None) -> Array: ...
294
def max(a, axis=None, keepdims=False, initial=None, where=None) -> Array: ...
295
def amin(a, axis=None, keepdims=False, initial=None, where=None) -> Array: ...
296
def amax(a, axis=None, keepdims=False, initial=None, where=None) -> Array: ...
297
def ptp(a, axis=None, keepdims=False) -> Array: ...
298
299
# Percentiles and quantiles
300
def percentile(a, q, axis=None, method='linear', keepdims=False) -> Array: ...
301
def quantile(a, q, axis=None, method='linear', keepdims=False) -> Array: ...
302
303
# Cumulative operations
304
def cumsum(a, axis=None, dtype=None) -> Array: ...
305
def cumprod(a, axis=None, dtype=None) -> Array: ...
306
307
# Logical reductions
308
def all(a, axis=None, keepdims=False, where=None) -> Array: ...
309
def any(a, axis=None, keepdims=False, where=None) -> Array: ...
310
311
# Counting
312
def count_nonzero(a, axis=None, keepdims=False) -> Array: ...
313
314
# NaN-aware reductions
315
def nansum(a, axis=None, dtype=None, keepdims=False, where=None) -> Array: ...
316
def nanprod(a, axis=None, dtype=None, keepdims=False, where=None) -> Array: ...
317
def nanmean(a, axis=None, dtype=None, keepdims=False, where=None) -> Array: ...
318
def nanmedian(a, axis=None, keepdims=False) -> Array: ...
319
def nanstd(a, axis=None, dtype=None, ddof=0, keepdims=False, where=None) -> Array: ...
320
def nanvar(a, axis=None, dtype=None, ddof=0, keepdims=False, where=None) -> Array: ...
321
def nanmin(a, axis=None, keepdims=False, initial=None, where=None) -> Array: ...
322
def nanmax(a, axis=None, keepdims=False, initial=None, where=None) -> Array: ...
323
def nanpercentile(a, q, axis=None, method='linear', keepdims=False) -> Array: ...
324
def nanquantile(a, q, axis=None, method='linear', keepdims=False) -> Array: ...
325
def nancumsum(a, axis=None, dtype=None) -> Array: ...
326
def nancumprod(a, axis=None, dtype=None) -> Array: ...
327
328
# Indices of extrema
329
def argmin(a, axis=None, keepdims=False) -> Array: ...
330
def argmax(a, axis=None, keepdims=False) -> Array: ...
331
def nanargmin(a, axis=None, keepdims=False) -> Array: ...
332
def nanargmax(a, axis=None, keepdims=False) -> Array: ...
333
```
334
335
### Linear Algebra
336
337
Core linear algebra operations for matrix computations.
338
339
```python { .api }
340
# Matrix multiplication
341
def dot(a, b) -> Array: ...
342
def matmul(x1, x2) -> Array: ...
343
def inner(a, b) -> Array: ...
344
def outer(a, b) -> Array: ...
345
def tensordot(a, b, axes=2) -> Array: ...
346
def kron(a, b) -> Array: ...
347
348
# Vector operations
349
def vdot(a, b) -> Array: ...
350
def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None) -> Array: ...
351
352
# Matrix operations
353
def trace(a, offset=0, axis1=0, axis2=1, dtype=None) -> Array: ...
354
def diagonal(a, offset=0, axis1=0, axis2=1) -> Array: ...
355
def diag(v, k=0) -> Array: ...
356
def diagflat(v, k=0) -> Array: ...
357
358
# Triangular matrices
359
def tril(m, k=0) -> Array: ...
360
def triu(m, k=0) -> Array: ...
361
def tril_indices(n, k=0, m=None) -> tuple[Array, Array]: ...
362
def triu_indices(n, k=0, m=None) -> tuple[Array, Array]: ...
363
def diag_indices(n, ndim=2) -> tuple[Array, ...]: ...
364
365
# Matrix transpose
366
def matrix_transpose(x) -> Array: ...
367
```
368
369
### Sorting and Searching
370
371
Functions for sorting arrays and searching for values.
372
373
```python { .api }
374
def sort(a, axis=-1, kind='stable', order=None) -> Array: ...
375
def argsort(a, axis=-1, kind='stable', order=None) -> Array: ...
376
def lexsort(keys, axis=-1) -> Array: ...
377
def partition(a, kth, axis=-1) -> Array: ...
378
def argpartition(a, kth, axis=-1) -> Array: ...
379
def searchsorted(a, v, side='left', sorter=None) -> Array: ...
380
def sort_complex(a) -> Array: ...
381
```
382
383
### Set Operations
384
385
Set-like operations on arrays.
386
387
```python { .api }
388
def unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None, *, equal_nan=True, size=None, fill_value=None) -> Array | tuple[Array, ...]: ...
389
def intersect1d(ar1, ar2, assume_unique=False, return_indices=False) -> Array: ...
390
def union1d(ar1, ar2) -> Array: ...
391
def setdiff1d(ar1, ar2, assume_unique=False) -> Array: ...
392
def setxor1d(ar1, ar2, assume_unique=False) -> Array: ...
393
def isin(element, test_elements, assume_unique=False, invert=False) -> Array: ...
394
```
395
396
### Statistical Functions
397
398
Histograms, binning, weighted averages, covariance/correlation, and numerical gradients.
399
400
```python { .api }
401
def bincount(x, weights=None, minlength=0, length=None) -> Array: ...
402
def histogram(a, bins=10, range=None, weights=None, density=None) -> tuple[Array, Array]: ...
403
def histogram2d(x, y, bins=10, range=None, weights=None, density=None) -> tuple[Array, Array, Array]: ...
404
def histogramdd(sample, bins=10, range=None, weights=None, density=None) -> tuple[Array, list[Array]]: ...
405
def histogram_bin_edges(a, bins=10, range=None, weights=None) -> Array: ...
406
def digitize(x, bins, right=False) -> Array: ...
407
def average(a, axis=None, weights=None, returned=False, keepdims=False) -> Array: ...
408
def corrcoef(x, y=None, rowvar=True, dtype=None) -> Array: ...
409
def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=None, dtype=None) -> Array: ...
410
def gradient(f, *varargs, axis=None, edge_order=1) -> Array: ...
411
```
412
413
### Data Types and Conversion
414
415
Type information, checking, and conversion functions.
416
417
```python { .api }
418
# Type checking
419
def issubdtype(arg1, arg2) -> bool: ...
420
def can_cast(from_, to, casting='safe') -> bool: ...
421
def result_type(*arrays_and_dtypes): ...
422
def promote_types(type1, type2): ...
423
def isscalar(element) -> bool: ...
424
def isrealobj(x) -> bool: ...
425
def iscomplexobj(x) -> bool: ...
426
427
# Type information
428
def finfo(dtype): ...
429
def iinfo(dtype): ...
430
431
# Array properties
432
def ndim(a) -> int: ...
433
def shape(a) -> tuple: ...
434
def size(a) -> int: ...
435
436
# Comparison functions
437
def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False) -> Array: ...  # 0-d boolean Array, not a Python bool
438
def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False) -> Array: ...
439
def array_equal(a1, a2, equal_nan=False) -> Array: ...  # 0-d boolean Array, not a Python bool
440
def array_equiv(a1, a2) -> Array: ...  # 0-d boolean Array, not a Python bool
441
442
# Utility functions
443
def copy(a, order='K') -> Array: ...
444
def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None) -> Array: ...
445
```
446
447
### Comparison Operations
448
449
Element-wise comparison functions returning boolean arrays.
450
451
```python { .api }
452
def equal(x1, x2) -> Array: ...
453
def not_equal(x1, x2) -> Array: ...
454
def less(x1, x2) -> Array: ...
455
def less_equal(x1, x2) -> Array: ...
456
def greater(x1, x2) -> Array: ...
457
def greater_equal(x1, x2) -> Array: ...
458
```
459
460
### Logical Operations
461
462
Element-wise logical operations; non-boolean inputs are treated as True wherever they are nonzero.
463
464
```python { .api }
465
def logical_and(x1, x2) -> Array: ...
466
def logical_or(x1, x2) -> Array: ...
467
def logical_not(x) -> Array: ...
468
def logical_xor(x1, x2) -> Array: ...
469
```
470
471
### Bitwise Operations
472
473
Element-wise bitwise operations on integer and boolean arrays.
474
475
```python { .api }
476
def bitwise_and(x1, x2) -> Array: ...
477
def bitwise_or(x1, x2) -> Array: ...
478
def bitwise_xor(x1, x2) -> Array: ...
479
def bitwise_not(x) -> Array: ...
480
def bitwise_left_shift(x1, x2) -> Array: ...
481
def bitwise_right_shift(x1, x2) -> Array: ...
482
def left_shift(x1, x2) -> Array: ...
483
def right_shift(x1, x2) -> Array: ...
484
def invert(x) -> Array: ...
485
def bitwise_count(x) -> Array: ...
486
```
487
488
### Constants and Special Values
489
490
Mathematical and numerical constants.
491
492
```python { .api }
493
pi: float # π (3.14159...)
494
e: float # Euler's number (2.71828...)
495
euler_gamma: float # Euler-Mascheroni constant
496
inf: float # Positive infinity
497
nan: float # Not a Number
498
newaxis: None # Used for adding dimensions in indexing
499
```
500
501
## NumPy Submodules
502
503
### FFT Operations
504
505
```python { .api }
506
import jax.numpy.fft as jfft
507
508
# 1D transforms
509
jfft.fft(a, n=None, axis=-1, norm=None) -> Array
510
jfft.ifft(a, n=None, axis=-1, norm=None) -> Array
511
jfft.rfft(a, n=None, axis=-1, norm=None) -> Array
512
jfft.irfft(a, n=None, axis=-1, norm=None) -> Array
513
514
# 2D transforms
515
jfft.fft2(a, s=None, axes=(-2, -1), norm=None) -> Array
516
jfft.ifft2(a, s=None, axes=(-2, -1), norm=None) -> Array
517
jfft.rfft2(a, s=None, axes=(-2, -1), norm=None) -> Array
518
jfft.irfft2(a, s=None, axes=(-2, -1), norm=None) -> Array
519
520
# N-D transforms
521
jfft.fftn(a, s=None, axes=None, norm=None) -> Array
522
jfft.ifftn(a, s=None, axes=None, norm=None) -> Array
523
jfft.rfftn(a, s=None, axes=None, norm=None) -> Array
524
jfft.irfftn(a, s=None, axes=None, norm=None) -> Array
525
526
# Helper functions
527
jfft.fftfreq(n, d=1.0) -> Array
528
jfft.rfftfreq(n, d=1.0) -> Array
529
jfft.fftshift(x, axes=None) -> Array
530
jfft.ifftshift(x, axes=None) -> Array
531
```
532
533
### Linear Algebra Operations
534
535
```python { .api }
536
import jax.numpy.linalg as jla
537
538
# Matrix decompositions
539
jla.cholesky(a) -> Array
540
jla.qr(a, mode='reduced') -> tuple[Array, Array]
541
jla.svd(a, full_matrices=True, compute_uv=True, hermitian=False) -> tuple[Array, Array, Array]
542
jla.eig(a) -> tuple[Array, Array]
543
jla.eigh(a, UPLO='L') -> tuple[Array, Array]
544
jla.eigvals(a) -> Array
545
jla.eigvalsh(a, UPLO='L') -> Array
546
547
# Matrix properties
548
jla.det(a) -> Array
549
jla.slogdet(a) -> tuple[Array, Array]
550
jla.matrix_rank(M, tol=None, hermitian=False) -> Array
551
jla.trace(x, /, *, offset=0, dtype=None) -> Array  # Array API form: traces over the last two axes
552
553
# Matrix solutions
554
jla.solve(a, b) -> Array
555
jla.lstsq(a, b, rcond=None) -> tuple[Array, Array, Array, Array]
556
jla.inv(a) -> Array
557
jla.pinv(a, rcond=None, hermitian=False) -> Array
558
559
# Norms and distances
560
jla.norm(x, ord=None, axis=None, keepdims=False) -> Array
561
jla.cond(x, p=None) -> Array
562
563
# Matrix functions
564
jla.matrix_power(a, n) -> Array
565
```