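# Low-Level Operations

`jax.lax` is JAX's library of low-level primitives, most of which map directly onto XLA operations. These functions give precise control over computation and serve as the building blocks for higher-level APIs such as `jax.numpy`. They are also stricter: unlike `jax.numpy`, most `lax` functions do not automatically promote mismatched dtypes or broadcast operands, so inputs usually have to be prepared explicitly.

## Core Imports

```python
import jax.lax as lax
from jax.lax import add, mul, dot_general, cond, scan
```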

## Capabilities

### Arithmetic Operations

Element-wise arithmetic operations that map directly to XLA primitives.

```python { .api }
def add(x, y) -> Array:
    """Element-wise addition."""

def sub(x, y) -> Array:
    """Element-wise subtraction."""

def mul(x, y) -> Array:
    """Element-wise multiplication."""

def div(x, y) -> Array:
    """Element-wise division (integer division rounds toward zero)."""

def rem(x, y) -> Array:
    """Element-wise remainder (the result takes the sign of x, as in C)."""

def max(x, y) -> Array:
    """Element-wise maximum."""

def min(x, y) -> Array:
    """Element-wise minimum."""

def abs(x) -> Array:
    """Element-wise absolute value."""

def neg(x) -> Array:
    """Element-wise negation."""

def sign(x) -> Array:
    """Element-wise sign function."""

def pow(x, y) -> Array:
    """Element-wise power x ** y."""

def integer_pow(x, y) -> Array:
    """Element-wise power with a static Python int exponent y."""

def reciprocal(x) -> Array:
    """Element-wise reciprocal (1/x)."""

def square(x) -> Array:
    """Element-wise square."""

def sqrt(x) -> Array:
    """Element-wise square root."""

def rsqrt(x) -> Array:
    """Element-wise reciprocal square root (1/√x)."""

def cbrt(x) -> Array:
    """Element-wise cube root."""

def clamp(min, x, max) -> Array:
    """
    Clamp values between minimum and maximum.

    Args:
        min: Minimum value (a scalar or an array with the same shape as x)
        x: Input array
        max: Maximum value (a scalar or an array with the same shape as x)

    Returns:
        Array with values clamped to [min, max]
    """
```
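
Here is a short sketch of the arithmetic primitives, using `float32`/`int32` inputs of matching shape. It also shows two semantics worth knowing:

- `integer_pow` takes a static Python `int` exponent.
- Integer `div` and `rem` truncate toward zero (C-style), unlike Python's floor division.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 4.0, 9.0], dtype=jnp.float32)
y = jnp.full((3,), 2.0, dtype=jnp.float32)

s = lax.add(x, y)          # [3., 6., 11.]
r = lax.sqrt(x)            # [1., 2., 3.]
p = lax.integer_pow(x, 2)  # exponent is a static Python int: [1., 16., 81.]
c = lax.clamp(jnp.float32(2.0), x, jnp.float32(5.0))  # [2., 4., 5.]

# Integer division and remainder truncate toward zero (C semantics).
q = lax.div(jnp.int32(-7), jnp.int32(2))  # -3 (Python's -7 // 2 gives -4)
m = lax.rem(jnp.int32(-7), jnp.int32(2))  # -1 (Python's -7 % 2 gives 1)
```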

### Mathematical Functions

Transcendental and special mathematical functions.

```python { .api }
# Trigonometric functions
def sin(x) -> Array: ...
def cos(x) -> Array: ...
def tan(x) -> Array: ...
def asin(x) -> Array: ...
def acos(x) -> Array: ...
def atan(x) -> Array: ...
def atan2(x, y) -> Array: ...  # atan(x / y), quadrant chosen from the signs of x and y

# Hyperbolic functions
def sinh(x) -> Array: ...
def cosh(x) -> Array: ...
def tanh(x) -> Array: ...
def asinh(x) -> Array: ...
def acosh(x) -> Array: ...
def atanh(x) -> Array: ...

# Exponential and logarithmic
def exp(x) -> Array: ...
def exp2(x) -> Array: ...
def expm1(x) -> Array: ...     # exp(x) - 1, accurate for small x
def log(x) -> Array: ...
def log1p(x) -> Array: ...     # log(1 + x), accurate for small x
def logistic(x) -> Array: ...  # 1 / (1 + exp(-x))

# Rounding operations
def ceil(x) -> Array: ...
def floor(x) -> Array: ...
def round(x, rounding_method=RoundingMethod.AWAY_FROM_ZERO) -> Array: ...

# Complex number operations
def complex(real, imag) -> Array:
    """Create complex array from real and imaginary parts."""

def conj(x) -> Array:
    """Complex conjugate."""

def real(x) -> Array:
    """Extract real part of complex array."""

def imag(x) -> Array:
    """Extract imaginary part of complex array."""
```
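
This minimal sketch covers two easily confused entries. First, `lax.atan2(x, y)` computes `atan(x / y)`, which is the same argument order as `numpy.arctan2`. Second, complex values can be assembled and taken apart with `complex`, `conj`, `real`, and `imag`. It assumes `float32` scalars.

```python
import jax.numpy as jnp
from jax import lax

# atan(1 / -1), placed in the second quadrant by the signs of the arguments.
angle = lax.atan2(jnp.float32(1.0), jnp.float32(-1.0))  # 3*pi/4

z = lax.complex(jnp.float32(3.0), jnp.float32(4.0))  # 3+4j (complex64)
zc = lax.conj(z)                                      # 3-4j
modulus = lax.sqrt(lax.real(lax.mul(z, zc)))          # sqrt(|z|^2) = 5

half = lax.logistic(jnp.float32(0.0))  # 1 / (1 + exp(0)) = 0.5
```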

### Comparison Operations

Element-wise comparison operations returning boolean arrays.

```python { .api }
def eq(x, y) -> Array:
    """Element-wise equality."""

def ne(x, y) -> Array:
    """Element-wise inequality."""

def lt(x, y) -> Array:
    """Element-wise less than."""

def le(x, y) -> Array:
    """Element-wise less than or equal."""

def gt(x, y) -> Array:
    """Element-wise greater than."""

def ge(x, y) -> Array:
    """Element-wise greater than or equal."""

def is_finite(x) -> Array:
    """Element-wise finite number test."""
```
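
The comparison primitives return `bool` arrays with the input's shape. The small sketch below uses `float32` input. Under IEEE rules every comparison involving NaN is false except `ne`, which makes `lax.ne(x, x)` a NaN test.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, jnp.inf, jnp.nan, -2.0], dtype=jnp.float32)

finite = lax.is_finite(x)                # [True, False, False, True]
positive = lax.gt(x, jnp.zeros_like(x))  # NaN > 0 is False: [True, True, False, False]
is_nan = lax.ne(x, x)                    # only NaN differs from itself
```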

### Bitwise Operations

Bitwise operations on integer arrays. `bitwise_and`, `bitwise_or`, `bitwise_xor`, and `bitwise_not` also accept boolean arrays, where they act as logical operations.

```python { .api }
# Bitwise operations
def bitwise_and(x, y) -> Array: ...
def bitwise_or(x, y) -> Array: ...
def bitwise_xor(x, y) -> Array: ...
def bitwise_not(x) -> Array: ...

# Bit shifting
def shift_left(x, y) -> Array: ...
def shift_right_logical(x, y) -> Array: ...     # fills vacated bits with zeros
def shift_right_arithmetic(x, y) -> Array: ...  # replicates the sign bit

# Bit manipulation
def clz(x) -> Array:
    """Count leading zeros."""

def population_count(x) -> Array:
    """Count set bits."""
```
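
This small sketch works on `uint32` and `int32` values. Both shift operands share a dtype because `lax` does not promote mismatched integer types. The arithmetic and logical right shifts give different results only for negative inputs.

```python
import jax.numpy as jnp
from jax import lax

a = jnp.array([0b1100, 0b1010], dtype=jnp.uint32)
b = jnp.array([0b1010, 0b0110], dtype=jnp.uint32)

both = lax.bitwise_and(a, b)    # [0b1000, 0b0010] -> [8, 2]
diff = lax.bitwise_xor(a, b)    # [0b0110, 0b1100] -> [6, 12]
ones = lax.population_count(a)  # set bits per element: [2, 2]
leading = lax.clz(a)            # leading zeros in a 32-bit word: [28, 28]

neg = jnp.array([-8], dtype=jnp.int32)
one = jnp.array([1], dtype=jnp.int32)
arith = lax.shift_right_arithmetic(neg, one)  # sign bit copied in: [-4]
logical = lax.shift_right_logical(neg, one)   # zero copied in: [2147483644]
left = lax.shift_left(one, jnp.array([4], dtype=jnp.int32))  # [16]
```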

### Array Operations

Shape manipulation, broadcasting, and array transformation operations.

```python { .api }
def broadcast(operand, sizes) -> Array:
    """Broadcast array by prepending leading dimensions of the given sizes."""

def broadcast_in_dim(operand, shape, broadcast_dimensions) -> Array:
    """Broadcast array into target shape; broadcast_dimensions[i] is the output dimension for operand dimension i."""

def reshape(operand, new_sizes, dimensions=None) -> Array:
    """Reshape array to new dimensions; if given, dimensions permutes the operand before reshaping."""

def transpose(operand, permutation) -> Array:
    """Transpose array axes."""

def rev(operand, dimensions) -> Array:
    """Reverse array along specified dimensions."""

def concatenate(operands, dimension) -> Array:
    """Concatenate arrays along dimension."""

def pad(operand, padding_value, padding_config) -> Array:
    """Pad array with a constant value; padding_config holds one (low, high, interior) triple per dimension."""

def squeeze(array, dimensions) -> Array:
    """Remove the specified size-1 dimensions."""

def expand_dims(array, dimensions) -> Array:
    """Insert size-1 dimensions at the given positions."""
```
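
The sketch below composes the shape operations on a small `float32` matrix. `pad` takes one `(low, high, interior)` triple per dimension. `broadcast_in_dim` maps each input dimension to an output dimension.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

b = lax.broadcast(x, (4,))                                 # prepends a dim: (4, 2, 3)
bd = lax.broadcast_in_dim(jnp.arange(3.0), (2, 3), (1,))  # input dim 0 -> output dim 1
t = lax.transpose(x, (1, 0))                               # (3, 2)
r = lax.reshape(x, (3, 2))                                 # row-major reshape

# dim 0: one row of padding before and after; dim 1: one zero between elements.
p = lax.pad(x, jnp.float32(0), [(1, 1, 0), (0, 0, 1)])     # (4, 5)

e = lax.expand_dims(x, (0,))  # (1, 2, 3)
s = lax.squeeze(e, (0,))      # back to (2, 3)
```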

### Indexing and Slicing

Advanced indexing operations for array access and updates.

```python { .api }
def slice(operand, start_indices, limit_indices, strides=None) -> Array:
    """Extract a static slice from array."""

def slice_in_dim(operand, start_index, limit_index, stride=1, axis=0) -> Array:
    """Slice array along a single dimension."""

def dynamic_slice(operand, start_indices, slice_sizes) -> Array:
    """Extract slice at dynamic start indices (clamped to stay in bounds); slice_sizes must be static."""

def dynamic_slice_in_dim(operand, start_index, slice_size, axis=0) -> Array:
    """Dynamic slice along a single dimension."""

def dynamic_update_slice(operand, update, start_indices) -> Array:
    """Overwrite the slice starting at dynamic start indices with update."""

def dynamic_update_slice_in_dim(operand, update, start_index, axis) -> Array:
    """Dynamic update slice along a single dimension."""

def gather(
    operand,
    start_indices,
    dimension_numbers,
    slice_sizes,
    *,
    indices_are_sorted=False,
    unique_indices=False,
    mode=None,
    fill_value=None
) -> Array:
    """General gather operation for advanced indexing."""

def scatter(
    operand,
    scatter_indices,
    updates,
    dimension_numbers,
    *,
    indices_are_sorted=False,
    unique_indices=False,
    mode=None
) -> Array:
    """General scatter operation for advanced updates."""

# Scatter variants for different operations
def scatter_add(operand, scatter_indices, updates, dimension_numbers, **kwargs) -> Array: ...
def scatter_sub(operand, scatter_indices, updates, dimension_numbers, **kwargs) -> Array: ...
def scatter_mul(operand, scatter_indices, updates, dimension_numbers, **kwargs) -> Array: ...
def scatter_max(operand, scatter_indices, updates, dimension_numbers, **kwargs) -> Array: ...
def scatter_min(operand, scatter_indices, updates, dimension_numbers, **kwargs) -> Array: ...

def index_in_dim(operand, index, axis=0, keepdims=True) -> Array:
    """Index array along a single dimension."""

def index_take(src, idxs, axes) -> Array:
    """Take elements using multi-dimensional indices."""
```
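
Static slices use `slice`. `dynamic_slice` accepts start indices that may be traced values under `jax.jit`, but its slice sizes must be static. Out-of-range starts are clamped so the slice stays in bounds. The `window` function below is a hypothetical helper written only for this sketch.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(10, dtype=jnp.int32)

static = lax.slice(x, (2,), (8,), (2,))  # start 2, limit 8, stride 2: [2, 4, 6]

@jax.jit
def window(x, start):
    # Hypothetical helper: `start` is traced, the slice size (3) is static.
    return lax.dynamic_slice(x, (start,), (3,))

w = window(x, 4)        # [4, 5, 6]
clamped = window(x, 9)  # start clamped to 7 so the slice fits: [7, 8, 9]

patch = jnp.array([-1, -1], dtype=jnp.int32)
updated = lax.dynamic_update_slice(x, patch, (5,))  # [0, 1, 2, 3, 4, -1, -1, 7, 8, 9]
```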

### Reduction Operations

Reduce arrays along specified axes using various operations.

```python { .api }
def reduce(
    operand,
    init_value,
    computation,
    dimensions
) -> Array:
    """
    General reduction operation.

    Args:
        operand: Array to reduce
        init_value: Initial value; should be the identity of computation
            (e.g. 0 for addition, -inf for max)
        computation: Binary function for reduction
        dimensions: Axes to reduce over

    Returns:
        Reduced array
    """

# Specialized reductions
def reduce_sum(operand, axes) -> Array: ...
def reduce_prod(operand, axes) -> Array: ...
def reduce_max(operand, axes) -> Array: ...
def reduce_min(operand, axes) -> Array: ...
def reduce_and(operand, axes) -> Array: ...
def reduce_or(operand, axes) -> Array: ...
def reduce_xor(operand, axes) -> Array: ...

# Windowed reductions
def reduce_window(
    operand,
    init_value,
    computation,
    window_dimensions,
    window_strides,
    padding,
    base_dilation=None,
    window_dilation=None
) -> Array:
    """
    Sliding window reduction.

    Args:
        operand: Input array
        init_value: Initial value for reduction
        computation: Binary reduction function
        window_dimensions: Size of sliding window
        window_strides: Stride of sliding window along each dimension
        padding: 'SAME', 'VALID', or a sequence of (low, high) pairs
        base_dilation: Base dilation factor
        window_dilation: Window dilation factor

    Returns:
        Reduced array with window operation applied
    """
```

### Control Flow

Structured control-flow primitives. Unlike Python `if` and `for` statements, they stay traceable under `jax.jit`, so branch predicates and loop conditions may depend on runtime values.

```python { .api }
def cond(pred, true_fun, false_fun, *operands) -> Any:
    """
    Conditional execution based on predicate.

    Args:
        pred: Boolean scalar predicate
        true_fun: Function to execute if pred is True
        false_fun: Function to execute if pred is False
        operands: Arguments to pass to the selected function

    Both branches must return outputs with identical structure, shapes, and dtypes.

    Returns:
        Result of executing the selected function
    """

def select(pred, on_true, on_false) -> Array:
    """Element-wise selection: on_true where pred is True, otherwise on_false."""

def select_n(which, *cases) -> Array:
    """Multi-way element-wise selection: takes cases[which] at each position."""

def while_loop(cond_fun, body_fun, init_val) -> Any:
    """
    While loop with condition and body functions.

    Args:
        cond_fun: Function mapping the loop state to a boolean scalar
        body_fun: Function mapping the loop state to the next state (same structure and types)
        init_val: Initial loop state

    Returns:
        Final loop state after termination
    """

def fori_loop(lower, upper, body_fun, init_val) -> Any:
    """
    For loop over range with body function.

    Args:
        lower: Loop start index
        upper: Loop end index (exclusive)
        body_fun: Function (i, state) -> new state
        init_val: Initial loop state

    Returns:
        Final loop state
    """

def scan(f, init, xs=None, length=None, reverse=False, unroll=1) -> tuple[Any, Array]:
    """
    Scan a function over the leading axis of xs while carrying state.

    Args:
        f: Function (carry, x) -> (new_carry, y)
        init: Initial carry value
        xs: Input sequence (arrays or a pytree, scanned along the leading axis)
        length: Number of iterations; inferred from xs if None (required when xs is None)
        reverse: Whether to scan in reverse
        unroll: Number of iterations to unroll per loop step

    Returns:
        Tuple of (final_carry, stacked_ys)
    """

def associative_scan(fn, elems, reverse=False, axis=0) -> Array:
    """
    Parallel prefix scan with an associative binary function.

    Args:
        fn: Associative binary function
        elems: Input sequence
        reverse: Whether to scan in reverse
        axis: Axis to scan along

    Returns:
        Scanned results
    """

def switch(index, branches, *operands) -> Any:
    """
    Switch statement for multi-way branching.

    Args:
        index: Integer index selecting the branch; out-of-range values are clamped
            to [0, len(branches) - 1]
        branches: List of functions (branches)
        operands: Arguments to pass to the selected branch

    Returns:
        Result of executing the selected branch
    """

def map(f, xs) -> Array:
    """Apply f sequentially over the leading axis of xs and stack the results."""
```
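
Here is a sketch of the main control-flow primitives, assuming 32-bit types. The helpers `safe_log` and `running_total` are hypothetical names used only here. Loop carries get explicit dtypes so that the state type stays identical across iterations, which `lax` loops require.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def safe_log(x):
    # Hypothetical helper: both branches return a float32 scalar; only one runs.
    return lax.cond(x > 0, jnp.log, jnp.zeros_like, x)

def running_total(xs):
    # Hypothetical helper: scan carries the sum and emits it at every step.
    def step(carry, x):
        carry = carry + x
        return carry, carry
    return lax.scan(step, jnp.float32(0), xs)

total, partials = running_total(jnp.arange(1.0, 5.0))  # 10.0, [1., 3., 6., 10.]

# fori_loop: the body receives (index, state); sums 0 + 1 + ... + 9.
s = lax.fori_loop(0, 10, lambda i, acc: acc + i, jnp.int32(0))  # 45

# while_loop: keep doubling while the value is <= 100.
w = lax.while_loop(lambda v: v <= 100, lambda v: v * 2, jnp.int32(1))  # 128

# switch clamps the out-of-range index 5 to the last branch.
branch = lax.switch(5, [lambda v: v + 1, lambda v: v * 10], jnp.int32(3))  # 30
```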

### Cumulative Operations

Cumulative operations along array axes.

```python { .api }
def cumsum(operand, axis=0, reverse=False) -> Array:
    """Cumulative sum along axis."""

def cumprod(operand, axis=0, reverse=False) -> Array:
    """Cumulative product along axis."""

def cummax(operand, axis=0, reverse=False) -> Array:
    """Cumulative maximum along axis."""

def cummin(operand, axis=0, reverse=False) -> Array:
    """Cumulative minimum along axis."""

def cumlogsumexp(operand, axis=0, reverse=False) -> Array:
    """Cumulative log-sum-exp along axis."""
```
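
This sketch runs the cumulative operations on a 2x3 `float32` array. `axis` sets the direction of accumulation, and `reverse=True` accumulates from the end of that axis.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]], dtype=jnp.float32)

down = lax.cumsum(x, axis=0)                      # [[1, 2, 3], [5, 7, 9]]
across = lax.cumsum(x, axis=1)                    # [[1, 3, 6], [4, 9, 15]]
from_right = lax.cumsum(x, axis=1, reverse=True)  # [[6, 5, 3], [15, 11, 6]]

running_max = lax.cummax(jnp.array([3, 1, 4, 1, 5], dtype=jnp.int32))  # [3, 3, 4, 4, 5]
```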

### Linear Algebra

Matrix operations and linear algebra primitives.

```python { .api }
def dot(lhs, rhs, precision=None, preferred_element_type=None) -> Array:
    """Vector dot product, matrix-vector product, or matrix product for 1D and 2D arrays."""

def dot_general(
    lhs,
    rhs,
    dimension_numbers,
    precision=None,
    preferred_element_type=None
) -> Array:
    """
    General matrix multiplication with custom contractions.

    Args:
        lhs: Left-hand side array
        rhs: Right-hand side array
        dimension_numbers: Contraction and batch dimensions (see DotDimensionNumbers)
        precision: Computation precision
        preferred_element_type: Preferred output element type

    Returns:
        Result of general matrix multiplication
    """

def batch_matmul(
    lhs,
    rhs,
    precision=None,
    preferred_element_type=None
) -> Array:
    """Batched matrix multiplication."""

# DotDimensionNumbers is a nested tuple rather than a class:
#   ((lhs_contracting_dimensions, rhs_contracting_dimensions),
#    (lhs_batch_dimensions, rhs_batch_dimensions))
DotDimensionNumbers = tuple[
    tuple[Sequence[int], Sequence[int]],
    tuple[Sequence[int], Sequence[int]],
]
```
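
`dot_general` is driven entirely by `dimension_numbers`, written as `((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))`. The sketch below uses this form for a batched matrix product and a plain matrix product. The output lists batch dimensions first, then the remaining lhs dimensions, then the remaining rhs dimensions.

```python
import jax.numpy as jnp
from jax import lax

a = jnp.ones((2, 3, 4), dtype=jnp.float32)  # batch of two 3x4 matrices
b = jnp.ones((2, 4, 5), dtype=jnp.float32)  # batch of two 4x5 matrices

# Contract a's dim 2 with b's dim 1; treat dim 0 of both as a batch dimension.
batched = lax.dot_general(a, b, (((2,), (1,)), ((0,), (0,))))  # shape (2, 3, 5), all 4.0

# Plain matrix product: contract lhs dim 1 with rhs dim 0, no batch dimensions.
single = lax.dot_general(a[0], b[0], (((1,), (0,)), ((), ())))  # shape (3, 5)
```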

### Advanced Linear Algebra (lax.linalg)

Advanced linear algebra operations from `jax.lax.linalg`.

```python { .api }
def cholesky(a, *, symmetrize_input: bool = True) -> Array:
    """
    Cholesky decomposition of positive definite matrix.

    Args:
        a: Positive definite matrix
        symmetrize_input: Whether to symmetrize input

    Returns:
        Lower triangular Cholesky factor
    """

def cholesky_update(r, u, *, alpha: float = 1.0) -> Array:
    """
    Rank-1 update to Cholesky factorization.

    Args:
        r: Cholesky factor
        u: Update vector
        alpha: Update coefficient

    Returns:
        Updated Cholesky factor
    """

def eig(a, *, compute_left_eigenvectors: bool = True, compute_right_eigenvectors: bool = True) -> tuple[Array, Array, Array]:
    """
    Eigenvalue decomposition of general matrix.

    Args:
        a: Input matrix
        compute_left_eigenvectors: Whether to compute left eigenvectors
        compute_right_eigenvectors: Whether to compute right eigenvectors

    Returns:
        Tuple of (eigenvalues, left_eigenvectors, right_eigenvectors); eigenvectors
        that are not requested are omitted
    """

def eigh(a, *, lower: bool = True, symmetrize_input: bool = True, sort_eigenvalues: bool = True) -> tuple[Array, Array]:
    """
    Eigenvalue decomposition of Hermitian matrix.

    Args:
        a: Hermitian matrix
        lower: Whether to use lower triangle
        symmetrize_input: Whether to symmetrize input
        sort_eigenvalues: Whether to sort eigenvalues in ascending order

    Returns:
        Tuple of (eigenvectors, eigenvalues), with eigenvectors stored as columns.
        This is the reverse of the order returned by jnp.linalg.eigh.
    """

def lu(a) -> tuple[Array, Array, Array]:
    """
    LU decomposition with partial pivoting.

    Args:
        a: Input matrix

    Returns:
        Tuple of (lu_factors, pivots, permutation)
    """

def qr(a, *, full_matrices: bool = True) -> tuple[Array, Array]:
    """
    QR decomposition.

    Args:
        a: Input matrix
        full_matrices: Whether to return full or reduced QR

    Returns:
        Tuple of (q, r) matrices
    """

def svd(a, *, full_matrices: bool = True, compute_uv: bool = True, hermitian: bool = False) -> tuple[Array, Array, Array]:
    """
    Singular value decomposition.

    Args:
        a: Input matrix
        full_matrices: Whether to return full or reduced SVD
        compute_uv: Whether to compute U and V matrices
        hermitian: Whether matrix is Hermitian

    Returns:
        Tuple of (u, s, vh) where A = U @ diag(s) @ Vh (only s when compute_uv=False)
    """

def schur(a, *, compute_schur_vectors: bool = True, sort_eigs: bool = False, select_callable=None) -> tuple[Array, Array]:
    """
    Schur decomposition.

    Args:
        a: Input matrix
        compute_schur_vectors: Whether to compute Schur vectors
        sort_eigs: Whether to sort eigenvalues
        select_callable: Selection function for eigenvalues

    Returns:
        Tuple of (schur_form, schur_vectors)
    """

def hessenberg(a) -> tuple[Array, Array]:
    """
    Hessenberg decomposition.

    Args:
        a: Input matrix

    Returns:
        Tuple of (a, taus): the upper triangle and first subdiagonal of a hold the
        Hessenberg matrix, the entries below hold the Householder reflectors, and
        taus holds their scaling factors
    """

def triangular_solve(a, b, *, left_side: bool = False, lower: bool = False, transpose_a: bool = False, conjugate_a: bool = False, unit_diagonal: bool = False) -> Array:
    """
    Solve triangular system of equations.

    Args:
        a: Triangular matrix
        b: Right-hand side
        left_side: If True, solve op(a) @ x = b; if False, solve x @ op(a) = b
        lower: Whether a is lower triangular
        transpose_a: Whether op transposes a
        conjugate_a: Whether op conjugates a
        unit_diagonal: Whether a has unit diagonal

    Returns:
        Solution to triangular system
    """

def tridiagonal(a, *, lower: bool = True) -> tuple[Array, Array, Array, Array]:
    """
    Tridiagonal reduction of a symmetric (Hermitian) matrix.

    Args:
        a: Symmetric or Hermitian matrix
        lower: Whether to use lower triangle

    Returns:
        Tuple of (a, d, e, taus): a holds the tridiagonal form plus the Householder
        reflectors, d the diagonal, e the off-diagonal, and taus the reflector
        scaling factors
    """

def tridiagonal_solve(dl, d, du, b) -> Array:
    """
    Solve tridiagonal system using Thomas algorithm.

    Args:
        dl: Lower diagonal
        d: Main diagonal
        du: Upper diagonal
        b: Right-hand side

    Returns:
        Solution to tridiagonal system
    """

def qdwh(a, *, is_hermitian: bool = False, max_iterations: int = None, dynamic_shape: bool = False) -> tuple[Array, Array]:
    """
    QDWH polar decomposition: A = UP where U is unitary, P is positive semidefinite.

    Args:
        a: Input matrix
        is_hermitian: Whether matrix is Hermitian
        max_iterations: Maximum number of iterations
        dynamic_shape: Whether to handle dynamic shapes

    Returns:
        Tuple of (unitary_factor, positive_factor)
    """

def householder_product(a, taus) -> Array:
    """
    Compute product of Householder reflectors.

    Args:
        a: Matrix containing Householder vectors
        taus: Householder scaling factors

    Returns:
        Product of Householder reflectors
    """

def lu_pivots_to_permutation(pivots, permutation_size) -> Array:
    """
    Convert LU pivots (row swaps) to a permutation.

    Args:
        pivots: Pivot indices from LU decomposition
        permutation_size: Length of the permutation

    Returns:
        Integer array of row indices describing the permutation (not a matrix)
    """
```
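
This sketch solves a small symmetric positive-definite system with a Cholesky factor and two triangular solves, assuming a CPU backend. It passes `left_side=True` and `lower=True` explicitly instead of relying on defaults. It also unpacks `eigh` as `(eigenvectors, eigenvalues)`.

```python
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

a = jnp.array([[4.0, 2.0],
               [2.0, 3.0]], dtype=jnp.float32)  # symmetric positive definite
b = jnp.array([[2.0], [1.0]], dtype=jnp.float32)

L = lax_linalg.cholesky(a)  # lower-triangular factor, a = L @ L.T

# Forward solve L @ y = b, then back solve L.T @ x = y.
y = lax_linalg.triangular_solve(L, b, left_side=True, lower=True)
x = lax_linalg.triangular_solve(L, y, left_side=True, lower=True, transpose_a=True)
# x is approximately [[0.5], [0.0]]

v, w = lax_linalg.eigh(a)  # columns of v are eigenvectors; w is ascending
```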

### Convolution Operations

Convolution operations for neural networks and signal processing.

```python { .api }
def conv(
    lhs,
    rhs,
    window_strides,
    padding,
    precision=None,
    preferred_element_type=None
) -> Array:
    """Convolution with the default NCHW/OIHW layouts and no dilation."""

def conv_general_dilated(
    lhs,
    rhs,
    window_strides,
    padding,
    lhs_dilation=None,
    rhs_dilation=None,
    dimension_numbers=None,
    feature_group_count=1,
    batch_group_count=1,
    precision=None,
    preferred_element_type=None
) -> Array:
    """
    General dilated convolution with full configuration options.

    Args:
        lhs: Input array; layout set by dimension_numbers (default: batch, feature,
            then spatial dims, e.g. NCHW)
        rhs: Kernel array (default layout: out features, in features, then spatial
            dims, e.g. OIHW)
        window_strides: Convolution strides
        padding: 'SAME', 'VALID', or a sequence of (low, high) pairs per spatial dimension
        lhs_dilation: Input dilation
        rhs_dilation: Kernel dilation (atrous convolution)
        dimension_numbers: Dimension layout specification (None, a ConvDimensionNumbers,
            or layout strings such as ('NHWC', 'HWIO', 'NHWC'))
        feature_group_count: Number of feature groups
        batch_group_count: Number of batch groups
        precision: Computation precision
        preferred_element_type: Preferred output type

    Returns:
        Convolution result
    """

def conv_transpose(
    lhs,
    rhs,
    strides,
    padding,
    rhs_dilation=None,
    dimension_numbers=None,
    transpose_kernel=False,
    precision=None,
    preferred_element_type=None
) -> Array:
    """Transposed convolution (sometimes called deconvolution)."""

class ConvDimensionNumbers(NamedTuple):
    """Convolution dimension number specification."""
    lhs_spec: tuple[int, ...]  # Input: (batch dim, feature dim, spatial dims...)
    rhs_spec: tuple[int, ...]  # Kernel: (out-feature dim, in-feature dim, spatial dims...)
    out_spec: tuple[int, ...]  # Output: (batch dim, feature dim, spatial dims...)
```
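
Here is a 1-D sketch using the default dimension numbers: input `(N, C, W)` and kernel `(O, I, W)`. XLA convolution does not flip the kernel, so it computes a cross-correlation. `rhs_dilation` spaces the kernel taps apart.

```python
import jax.numpy as jnp
from jax import lax

signal = jnp.array([[[1.0, 2.0, 3.0, 4.0, 5.0]]], dtype=jnp.float32)  # (N=1, C=1, W=5)
kernel = jnp.array([[[1.0, 0.0, -1.0]]], dtype=jnp.float32)           # (O=1, I=1, W=3)

# out[i] = x[i] * 1 + x[i + 1] * 0 + x[i + 2] * (-1)  ->  [-2, -2, -2]
out = lax.conv_general_dilated(signal, kernel, window_strides=(1,), padding="VALID")

# Taps spaced 2 apart: out[i] = x[i] - x[i + 4]  ->  [-4]
dilated = lax.conv_general_dilated(signal, kernel, (1,), "VALID", rhs_dilation=(2,))
```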

### FFT Operations

Fast Fourier Transform operations.

```python { .api }
def fft(a, fft_type, fft_lengths) -> Array:
    """
    Fast Fourier Transform.

    Args:
        a: Input array
        fft_type: Type of FFT (a member of FftType)
        fft_lengths: Lengths of the transformed trailing dimensions
            (for IRFFT, the lengths of the real-valued output)

    Returns:
        FFT result
    """

class FftType(enum.IntEnum):
    """FFT type enumeration."""
    FFT: FftType    # forward complex-to-complex
    IFFT: FftType   # inverse complex-to-complex
    RFFT: FftType   # forward real-to-complex
    IRFFT: FftType  # inverse complex-to-real
```

### Parallel Operations

Collective communication primitives for multi-device computation. They must be called inside a function mapped over a named axis (for example with `jax.pmap`, `jax.vmap(..., axis_name=...)`, or `shard_map`); `axis_name` refers to that axis.

```python { .api }
def all_gather(x, axis_name, *, axis_index_groups=None, tiled=False) -> Array:
    """Gather values from all devices."""

def all_to_all(x, axis_name, split_axis, concat_axis, *, axis_index_groups=None, tiled=False) -> Array:
    """All-to-all communication between devices."""

def psum(x, axis_name, *, axis_index_groups=None) -> Array:
    """Parallel sum reduction across devices."""

def pmean(x, axis_name, *, axis_index_groups=None) -> Array:
    """Parallel mean reduction across devices."""

def pmax(x, axis_name, *, axis_index_groups=None) -> Array:
    """Parallel max reduction across devices."""

def pmin(x, axis_name, *, axis_index_groups=None) -> Array:
    """Parallel min reduction across devices."""

def ppermute(x, axis_name, perm, *, axis_index_groups=None) -> Array:
    """Permute data between devices."""

def axis_index(axis_name) -> Array:
    """Get device index along named axis."""

def axis_size(axis_name) -> int:
    """Get number of devices along named axis."""

def pbroadcast(x, axis_name, *, axis_index_groups=None) -> Array:
    """Broadcast from first device to all others."""
```
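
Collectives need a named axis to operate over. This minimal sketch runs on a single device by naming a `jax.vmap` axis. The same calls work inside `jax.pmap` or `shard_map`, where the axis spans devices. `normalize` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp
from jax import lax

def normalize(x):
    # Hypothetical helper: divide each element by the sum over the named axis.
    return x / lax.psum(x, axis_name="batch")

xs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=jnp.float32)

probs = jax.vmap(normalize, axis_name="batch")(xs)                        # [0.1, 0.2, 0.3, 0.4]
means = jax.vmap(lambda x: lax.pmean(x, "batch"), axis_name="batch")(xs)  # every entry 2.5
ids = jax.vmap(lambda x: lax.axis_index("batch"), axis_name="batch")(xs)  # [0, 1, 2, 3]
```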

### Special Functions

Special mathematical functions and probability distributions.

```python { .api }
# Error functions
def erf(x) -> Array: ...
def erfc(x) -> Array: ...
def erf_inv(x) -> Array: ...

# Gamma functions
def lgamma(x) -> Array: ...
def digamma(x) -> Array: ...
def polygamma(m, x) -> Array: ...

# Bessel functions
def bessel_i0e(x) -> Array: ...
def bessel_i1e(x) -> Array: ...

# Other special functions
def betainc(a, b, x) -> Array: ...
def igamma(a, x) -> Array: ...
def igammac(a, x) -> Array: ...
def zeta(x, q=None) -> Array: ...
```
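
This sketch checks a few special functions against closed forms, using `float32` inputs. Γ(n) = (n - 1)! for positive integers. The regularized lower incomplete gamma `igamma(a, x)` equals 1 - exp(-x) when a = 1.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 5.0], dtype=jnp.float32)

lg = lax.lgamma(x)                           # log((n - 1)!): [0, 0, log(24)]
p = lax.igamma(jnp.ones_like(x), x)          # P(1, x) = 1 - exp(-x)
e = lax.erf_inv(lax.erf(jnp.float32(0.5)))   # round trip, approximately 0.5
```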

### Type Conversion and Manipulation

Array type conversion and data manipulation operations.

```python { .api }
def convert_element_type(operand, new_dtype) -> Array:
    """Convert array element type."""

def bitcast_convert_type(operand, new_dtype) -> Array:
    """Bitcast array to new type without changing bit representation."""

def dtype(x) -> numpy.dtype:
    """Get array data type."""

def full(shape, fill_value, dtype=None) -> Array:
    """Create array filled with constant value."""

def full_like(x, fill_value, dtype=None, shape=None) -> Array:
    """Create filled array with same properties as input."""

def iota(dtype, size) -> Array:
    """Create array with sequential values (0, 1, 2, ...)."""

def broadcasted_iota(dtype, shape, dimension) -> Array:
    """Create iota array broadcasted to shape."""
```
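
`convert_element_type` converts values, while `bitcast_convert_type` reinterprets the raw bits. The sketch below contrasts the two on `float32` input and builds coordinate grids with `broadcasted_iota`. The float-to-int conversion shown assumes in-range values, which truncate toward zero.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.7, -1.7], dtype=jnp.float32)
as_int = lax.convert_element_type(x, jnp.int32)  # value conversion: [1, -1]

one = jnp.array([1.0], dtype=jnp.float32)
bits = lax.bitcast_convert_type(one, jnp.uint32)  # IEEE bits of 1.0f: [0x3F800000]

grid_rows = lax.broadcasted_iota(jnp.int32, (2, 3), 0)  # [[0, 0, 0], [1, 1, 1]]
grid_cols = lax.broadcasted_iota(jnp.int32, (2, 3), 1)  # [[0, 1, 2], [0, 1, 2]]
filled = lax.full((2, 2), 7, jnp.int32)                 # [[7, 7], [7, 7]]
```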

### Sorting Operations

Sorting and selection operations.

```python { .api }
def sort(operand, dimension=-1, is_stable=True) -> Array:
    """Sort array along dimension."""

def sort_key_val(keys, values, dimension=-1, is_stable=True) -> tuple[Array, Array]:
    """Sort key-value pairs by key."""

def top_k(operand, k) -> tuple[Array, Array]:
    """Find the k largest elements along the last axis; returns (values, indices)."""

def argmax(operand, axis, index_dtype) -> Array:
    """Indices of maximum values along axis (axis and index_dtype are required)."""

def argmin(operand, axis, index_dtype) -> Array:
    """Indices of minimum values along axis (axis and index_dtype are required)."""
```
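
This sketch exercises the sorting helpers on small `int32`/`float32` arrays:

- `sort_key_val` permutes `values` alongside `keys`.
- `top_k` works along the last axis and returns values, then indices.
- `lax.argmax` takes the axis and the index dtype explicitly.

```python
import jax.numpy as jnp
from jax import lax

keys = jnp.array([3, 1, 2], dtype=jnp.int32)
vals = jnp.array([30.0, 10.0, 20.0], dtype=jnp.float32)
sorted_keys, sorted_vals = lax.sort_key_val(keys, vals)  # [1, 2, 3], [10., 20., 30.]

x = jnp.array([0.5, 2.0, -1.0, 3.0], dtype=jnp.float32)
top_vals, top_idx = lax.top_k(x, 2)  # [3., 2.], [3, 1]
best = lax.argmax(x, 0, jnp.int32)   # 3
```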

### Miscellaneous Operations

Additional utility operations and performance primitives.

```python { .api }
def stop_gradient(x) -> Array:
    """Stop gradient computation at this point."""

def optimization_barrier(x) -> Array:
    """Prevent optimization across this point."""

def nextafter(x1, x2) -> Array:
    """Next representable value after x1 in direction of x2."""

def reduce_precision(operand, exponent_bits, mantissa_bits) -> Array:
    """Reduce floating-point precision."""

def create_token() -> Array:
    """Create execution token for ordering side effects."""

def after_all(*tokens) -> Array:
    """Create token that depends on all input tokens."""

# Random number generation primitives
def rng_uniform(a, b, shape, dtype=None) -> Array:
    """Low-level uniform random number generation."""

def rng_bit_generator(key, shape, dtype=None, algorithm=None) -> tuple[Array, Array]:
    """Low-level random bit generation."""
```
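
This sketch covers three utilities on `float32` scalars:

- `stop_gradient` hides a term from differentiation.
- `reduce_precision` emulates bfloat16 rounding (8 exponent bits, 7 mantissa bits) while keeping the `float32` dtype.
- `nextafter` steps to the adjacent representable value.

`loss` is a hypothetical function used only for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax

def loss(w):
    # Hypothetical loss: the w ** 3 term is treated as a constant by jax.grad.
    return w ** 2 + lax.stop_gradient(w ** 3)

g = jax.grad(loss)(jnp.float32(2.0))  # only d(w**2)/dw = 2w contributes: 4.0

x = jnp.float32(1.0 + 2.0 ** -10)
bf16_like = lax.reduce_precision(x, exponent_bits=8, mantissa_bits=7)  # rounds to 1.0

step = lax.nextafter(jnp.float32(1.0), jnp.float32(2.0))  # 1 + 2**-23
```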