
# Low-Level Operations

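`jax.lax` is JAX's library of primitive operations. Most of its functions correspond closely to XLA operations, which gives precise control over the computation that gets compiled, and they serve as the building blocks for higher-level APIs such as `jax.numpy`. Because they sit so close to XLA, `lax` functions are stricter than their `jax.numpy` counterparts. For example, they do not implicitly promote arguments of mixed dtypes.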

## Core Imports

```python
import jax.lax as lax
from jax.lax import add, mul, dot_general, cond, scan
```

## Capabilities

### Arithmetic Operations

Element-wise arithmetic operations that map directly to XLA primitives. Binary operations expect operands of the same dtype.

```python { .api }
def add(x, y) -> Array:
    """Element-wise addition."""

def sub(x, y) -> Array:
    """Element-wise subtraction."""

def mul(x, y) -> Array:
    """Element-wise multiplication."""

def div(x, y) -> Array:
    """Element-wise division (integer division rounds toward zero)."""

def rem(x, y) -> Array:
    """Element-wise remainder; the result takes the sign of the dividend x."""

def max(x, y) -> Array:
    """Element-wise maximum."""

def min(x, y) -> Array:
    """Element-wise minimum."""

def abs(x) -> Array:
    """Element-wise absolute value."""

def neg(x) -> Array:
    """Element-wise negation."""

def sign(x) -> Array:
    """Element-wise sign function."""

def pow(x, y) -> Array:
    """Element-wise power operation."""

def integer_pow(x, y: int) -> Array:
    """Element-wise power with a static Python integer exponent y."""

def reciprocal(x) -> Array:
    """Element-wise reciprocal (1/x)."""

def square(x) -> Array:
    """Element-wise square."""

def sqrt(x) -> Array:
    """Element-wise square root."""

def rsqrt(x) -> Array:
    """Element-wise reciprocal square root (1/√x)."""

def cbrt(x) -> Array:
    """Element-wise cube root."""

def clamp(min, x, max) -> Array:
    """
    Clamp values between minimum and maximum.

    Args:
        min: Minimum value
        x: Input array
        max: Maximum value

    Returns:
        Array with values clamped to [min, max]
    """
```

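A short sketch of a few of these operations, with values chosen only for illustration. Note the argument order of `clamp` and the C-style integer semantics of `div` and `rem`, which differ from Python's `//` and `%`:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([-1.5, 0.25, 4.0], dtype=jnp.float32)

# clamp takes (min, x, max) in that order.
clamped = lax.clamp(jnp.float32(0.0), x, jnp.float32(1.0))

# integer_pow needs a static Python int exponent: x ** 3.
cubed = lax.integer_pow(x, 3)

# rsqrt(4.0) is 1 / sqrt(4.0) = 0.5
inv_root = lax.rsqrt(jnp.float32(4.0))

# Integer division truncates toward zero, and rem keeps the dividend's sign.
q = lax.div(jnp.int32(-7), jnp.int32(2))
r = lax.rem(jnp.int32(-7), jnp.int32(2))
```
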
### Mathematical Functions

Transcendental and special mathematical functions.

```python { .api }
# Trigonometric functions
def sin(x) -> Array: ...
def cos(x) -> Array: ...
def tan(x) -> Array: ...
def asin(x) -> Array: ...
def acos(x) -> Array: ...
def atan(x) -> Array: ...
def atan2(x, y) -> Array: ...  # arctangent of x / y, using both signs to pick the quadrant

# Hyperbolic functions
def sinh(x) -> Array: ...
def cosh(x) -> Array: ...
def tanh(x) -> Array: ...
def asinh(x) -> Array: ...
def acosh(x) -> Array: ...
def atanh(x) -> Array: ...

# Exponential and logarithmic
def exp(x) -> Array: ...
def exp2(x) -> Array: ...
def expm1(x) -> Array: ...     # exp(x) - 1, accurate for small x
def log(x) -> Array: ...
def log1p(x) -> Array: ...     # log(1 + x), accurate for small x
def logistic(x) -> Array: ...  # 1 / (1 + exp(-x))

# Rounding operations
def ceil(x) -> Array: ...
def floor(x) -> Array: ...
def round(x, rounding_method=RoundingMethod.AWAY_FROM_ZERO) -> Array: ...

# Complex number operations
def complex(real, imag) -> Array:
    """Create complex array from real and imaginary parts."""

def conj(x) -> Array:
    """Complex conjugate."""

def real(x) -> Array:
    """Extract real part of complex array."""

def imag(x) -> Array:
    """Extract imaginary part of complex array."""
```

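The sketch below contrasts the two rounding modes on halfway values and round-trips a complex number. It assumes `lax.RoundingMethod` is available, as in recent JAX releases:

```python
import jax.numpy as jnp
from jax import lax

half = jnp.array([0.5, 1.5, 2.5, -2.5], dtype=jnp.float32)

# Default rounding sends halfway cases away from zero...
away = lax.round(half)
# ...while TO_NEAREST_EVEN implements banker's rounding.
even = lax.round(half, lax.RoundingMethod.TO_NEAREST_EVEN)

# logistic(0) = 1 / (1 + exp(0)) = 0.5
sig = lax.logistic(jnp.float32(0.0))

# Build a complex value, then take it apart again.
z = lax.complex(jnp.float32(3.0), jnp.float32(4.0))
re_part = lax.real(z)
conj_im = lax.imag(lax.conj(z))  # conjugation flips the imaginary sign
```
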
### Comparison Operations

Element-wise comparison operations returning boolean arrays.

```python { .api }
def eq(x, y) -> Array:
    """Element-wise equality."""

def ne(x, y) -> Array:
    """Element-wise inequality."""

def lt(x, y) -> Array:
    """Element-wise less than."""

def le(x, y) -> Array:
    """Element-wise less than or equal."""

def gt(x, y) -> Array:
    """Element-wise greater than."""

def ge(x, y) -> Array:
    """Element-wise greater than or equal."""

def is_finite(x) -> Array:
    """Element-wise finite number test."""
```

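A small sketch showing how the comparisons treat IEEE special values (NaN and infinity):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, jnp.inf, jnp.nan, -2.0], dtype=jnp.float32)
y = jnp.array([1.0, jnp.inf, jnp.nan, 3.0], dtype=jnp.float32)

same = lax.eq(x, y)        # NaN never compares equal, even to itself
finite = lax.is_finite(x)  # False for both inf and NaN
smaller = lax.lt(x, y)     # every comparison involving NaN is False
```
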
### Bitwise Operations

Bitwise operations on integer arrays.

```python { .api }
# Bitwise operations
def bitwise_and(x, y) -> Array: ...
def bitwise_or(x, y) -> Array: ...
def bitwise_xor(x, y) -> Array: ...
def bitwise_not(x) -> Array: ...

# Bit shifting
def shift_left(x, y) -> Array: ...
def shift_right_logical(x, y) -> Array: ...     # fills vacated high bits with zeros
def shift_right_arithmetic(x, y) -> Array: ...  # replicates the sign bit

# Bit manipulation
def clz(x) -> Array:
    """Count leading zeros."""

def population_count(x) -> Array:
    """Count set bits."""
```

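The two right shifts differ only for negative signed values. A sketch on 32-bit integers:

```python
import jax.numpy as jnp
from jax import lax

neg = jnp.int32(-8)  # bit pattern 0xFFFFFFF8
one = jnp.int32(1)

arith = lax.shift_right_arithmetic(neg, one)  # sign bit replicated: -4
logical = lax.shift_right_logical(neg, one)   # zero filled: 0x7FFFFFFC
bits = lax.population_count(jnp.int32(7))     # 0b111 has three set bits
lead = lax.clz(jnp.int32(1))                  # 31 leading zeros in a 32-bit word
```
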
### Array Operations

Shape manipulation, broadcasting, and array transformation operations.

```python { .api }
def broadcast(operand, sizes) -> Array:
    """Broadcast by prepending new leading dimensions of the given sizes."""

def broadcast_in_dim(operand, shape, broadcast_dimensions) -> Array:
    """Broadcast into shape; broadcast_dimensions[i] is the output axis for operand axis i."""

def reshape(operand, new_sizes, dimensions=None) -> Array:
    """Reshape to new_sizes, optionally permuting the axes by dimensions first."""

def transpose(operand, permutation) -> Array:
    """Permute array axes according to permutation."""

def rev(operand, dimensions) -> Array:
    """Reverse array along specified dimensions."""

def concatenate(operands, dimension) -> Array:
    """Concatenate arrays along dimension."""

def pad(operand, padding_value, padding_config) -> Array:
    """Pad with padding_value; padding_config holds one (low, high, interior) triple per dimension."""

def squeeze(array, dimensions) -> Array:
    """Remove the given size-1 dimensions."""

def expand_dims(array, dimensions) -> Array:
    """Insert size-1 dimensions at the given positions."""
```

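A sketch of the broadcasting and padding conventions, which are explicit in `lax` where `jax.numpy` would infer them:

```python
import jax.numpy as jnp
from jax import lax

v = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)

# broadcast prepends new leading axes: shape (3,) -> (2, 3)
lead = lax.broadcast(v, (2,))

# broadcast_in_dim maps operand axis 0 to output axis 0: (3,) -> (3, 4)
cols = lax.broadcast_in_dim(v, (3, 4), (0,))

# One (low, high, interior) entry per dimension: 1 before, 2 after, 1 between.
padded = lax.pad(v, jnp.float32(0.0), [(1, 2, 1)])

# transpose takes an explicit permutation of all axes.
t = lax.transpose(jnp.zeros((2, 3, 4)), (2, 0, 1))
```
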
### Indexing and Slicing

Advanced indexing operations for array access and updates.

```python { .api }
def slice(operand, start_indices, limit_indices, strides=None) -> Array:
    """Extract slice from array."""

def slice_in_dim(operand, start_index, limit_index, stride=1, axis=0) -> Array:
    """Slice array along single dimension."""

def dynamic_slice(operand, start_indices, slice_sizes) -> Array:
    """Extract a slice with dynamic start indices; slice_sizes must be static,
    and start indices are clamped so the slice stays in bounds."""

def dynamic_slice_in_dim(operand, start_index, slice_size, axis=0) -> Array:
    """Dynamic slice along single dimension."""

def dynamic_update_slice(operand, update, start_indices) -> Array:
    """Update slice with dynamic start indices."""

def dynamic_update_slice_in_dim(operand, update, start_index, axis) -> Array:
    """Dynamic update slice along single dimension."""

def gather(
    operand,
    start_indices,
    dimension_numbers,
    slice_sizes,
    *,
    indices_are_sorted=False,
    unique_indices=False,
    mode=None,
    fill_value=None,
) -> Array:
    """General gather operation for advanced indexing.

    mode controls out-of-bounds handling (e.g. 'clip' or 'fill');
    fill_value is used with 'fill'.
    """

def scatter(
    operand,
    scatter_indices,
    updates,
    dimension_numbers,
    *,
    indices_are_sorted=False,
    unique_indices=False,
    mode=None,
) -> Array:
    """General scatter operation for advanced updates."""

# Scatter variants for different operations
def scatter_add(operand, scatter_indices, updates, dimension_numbers, **kwargs) -> Array: ...
def scatter_sub(operand, scatter_indices, updates, dimension_numbers, **kwargs) -> Array: ...
def scatter_mul(operand, scatter_indices, updates, dimension_numbers, **kwargs) -> Array: ...
def scatter_max(operand, scatter_indices, updates, dimension_numbers, **kwargs) -> Array: ...
def scatter_min(operand, scatter_indices, updates, dimension_numbers, **kwargs) -> Array: ...

def index_in_dim(operand, index, axis=0, keepdims=True) -> Array:
    """Index array along single dimension."""

def index_take(src, idxs, axes) -> Array:
    """Take elements using multi-dimensional indices."""
```

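The sketch below shows a traced start index under `jax.jit`, the clamping behaviour of `dynamic_slice`, and a `scatter_add` that accumulates duplicate indices. The `ScatterDimensionNumbers` values are the usual configuration for scattering single elements into a 1-D operand; treat them as an illustrative assumption rather than a general recipe:

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(10, dtype=jnp.int32)

@jax.jit
def window(x, start):
    # slice_sizes must be static; the start index may be a traced value.
    return lax.dynamic_slice(x, (start,), (4,))

inside = window(x, 2)
# Out-of-range starts are clamped so the slice stays in bounds (8 -> 6).
clamped = window(x, 8)

# Overwrite two elements beginning at index 3.
updated = lax.dynamic_update_slice(x, jnp.array([-1, -1], dtype=jnp.int32), (3,))

# scatter_add with duplicate indices accumulates: index 0 receives 1 + 3.
dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(),
    inserted_window_dims=(0,),
    scatter_dims_to_operand_dims=(0,),
)
acc = lax.scatter_add(
    jnp.zeros(5, dtype=jnp.float32),
    jnp.array([[0], [2], [0]], dtype=jnp.int32),  # shape (num_updates, 1)
    jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32),
    dnums,
)
```
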
### Reduction Operations

Reduce arrays along specified axes using various operations.

```python { .api }
def reduce(
    operand,
    init_value,
    computation,
    dimensions,
) -> Array:
    """
    General reduction operation.

    Args:
        operand: Array to reduce
        init_value: Initial value, normally the identity of computation
            (e.g. 0 for addition)
        computation: Binary function for reduction
        dimensions: Axes to reduce over

    Returns:
        Reduced array
    """

# Specialized reductions
def reduce_sum(operand, axes) -> Array: ...
def reduce_prod(operand, axes) -> Array: ...
def reduce_max(operand, axes) -> Array: ...
def reduce_min(operand, axes) -> Array: ...
def reduce_and(operand, axes) -> Array: ...
def reduce_or(operand, axes) -> Array: ...
def reduce_xor(operand, axes) -> Array: ...

# Windowed reductions
def reduce_window(
    operand,
    init_value,
    computation,
    window_dimensions,
    window_strides=None,
    padding="VALID",
    base_dilation=None,
    window_dilation=None,
) -> Array:
    """
    Sliding window reduction.

    Args:
        operand: Input array
        init_value: Initial value for reduction
        computation: Binary reduction function
        window_dimensions: Size of sliding window
        window_strides: Stride of sliding window
        padding: 'VALID', 'SAME', or a sequence of (low, high) pairs
        base_dilation: Base dilation factor
        window_dilation: Window dilation factor

    Returns:
        Reduced array with window operation applied
    """
```

### Control Flow

Structured control-flow primitives. Unlike Python `if` and `for`, these are staged into the compiled XLA program, so branch selection and loop termination can depend on traced values under `jax.jit`.

```python { .api }
def cond(pred, true_fun, false_fun, *operands) -> Any:
    """
    Conditional execution based on predicate.

    Args:
        pred: Boolean scalar predicate
        true_fun: Function to execute if pred is True
        false_fun: Function to execute if pred is False
        operands: Arguments to pass to the selected function

    Returns:
        Result of the selected function; both branches must return outputs
        with the same structure, shapes, and dtypes
    """

def select(pred, on_true, on_false) -> Array:
    """Element-wise selection: on_true where pred is True, else on_false."""

def select_n(which, *cases) -> Array:
    """Multi-way element-wise selection: picks cases[which] at each position."""

def while_loop(cond_fun, body_fun, init_val) -> Any:
    """
    While loop with condition and body functions.

    Args:
        cond_fun: Function mapping the loop state to a boolean scalar
        body_fun: Function mapping the loop state to the next state
        init_val: Initial loop state

    Returns:
        Final loop state after termination (forward-mode differentiable,
        but not reverse-mode differentiable)
    """

def fori_loop(lower, upper, body_fun, init_val) -> Any:
    """
    For loop over range with body function.

    Args:
        lower: Loop start index
        upper: Loop end index (exclusive)
        body_fun: Function (i, state) -> new state
        init_val: Initial loop state

    Returns:
        Final loop state
    """

def scan(f, init, xs, length=None, reverse=False, unroll=1) -> tuple[Any, Any]:
    """
    Scan a function over the leading axis of xs while carrying state.

    Args:
        f: Function (carry, x) -> (new_carry, y)
        init: Initial carry value
        xs: Input sequence, scanned along its leading axis
        length: Number of iterations (inferred from xs if None)
        reverse: Whether to scan in reverse
        unroll: Number of iterations to unroll

    Returns:
        Tuple (final_carry, ys), where the per-step outputs y are stacked
        along a new leading axis
    """

def associative_scan(fn, elems, reverse=False, axis=0) -> Array:
    """
    Parallel prefix scan with an associative binary function.

    Args:
        fn: Associative binary function
        elems: Input sequence
        reverse: Whether to scan in reverse
        axis: Axis to scan along

    Returns:
        Scanned results
    """

def switch(index, branches, *operands) -> Any:
    """
    Switch statement for multi-way branching.

    Args:
        index: Integer index selecting the branch (clamped to [0, len(branches) - 1])
        branches: Sequence of functions (branches)
        operands: Arguments to pass to the selected branch

    Returns:
        Result of the selected branch
    """

def map(f, xs) -> Any:
    """Apply f sequentially over the leading axis of xs and stack the results."""
```

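A sketch of the main constructs. Loop carries are given explicit `int32` and `float32` dtypes so the carry type stays identical across iterations, which these primitives require:

```python
import jax
import jax.numpy as jnp
from jax import lax

# cond: both branches must return outputs with the same structure and dtypes.
def safe_abs(x):
    return lax.cond(x >= 0, lambda v: v, lambda v: -v, x)

# fori_loop: body_fun receives (index, carry); sums 0 + 1 + ... + 9.
total = lax.fori_loop(0, 10, lambda i, acc: acc + i, jnp.int32(0))

# while_loop: keep doubling until the value exceeds 100.
doubled = lax.while_loop(lambda v: v <= 100, lambda v: v * 2, jnp.int32(1))

# scan: carry a running sum and emit it at every step (a cumulative sum).
def step(carry, x):
    new = carry + x
    return new, new

final, running = lax.scan(step, jnp.float32(0.0), jnp.arange(1.0, 5.0, dtype=jnp.float32))

# switch: the index selects one of several branches.
picked = lax.switch(1, [lambda v: v + 1, lambda v: v * 10, lambda v: v - 1], jnp.int32(3))
```
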
### Cumulative Operations

Cumulative operations along array axes.

```python { .api }
def cumsum(operand, axis=0, reverse=False) -> Array:
    """Cumulative sum along axis."""

def cumprod(operand, axis=0, reverse=False) -> Array:
    """Cumulative product along axis."""

def cummax(operand, axis=0, reverse=False) -> Array:
    """Cumulative maximum along axis."""

def cummin(operand, axis=0, reverse=False) -> Array:
    """Cumulative minimum along axis."""

def cumlogsumexp(operand, axis=0, reverse=False) -> Array:
    """Cumulative log-sum-exp along axis."""
```

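A sketch of forward and reversed accumulation on a 1-D array:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 3.0, 2.0, 5.0], dtype=jnp.float32)

forward = lax.cumsum(x)                 # prefix sums; axis defaults to 0
backward = lax.cumsum(x, reverse=True)  # suffix sums
running_max = lax.cummax(x)

# The last entry of cumlogsumexp(log(x)) equals log(sum(x)), computed stably.
log_total = lax.cumlogsumexp(jnp.log(x))[-1]
```
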
### Linear Algebra

Matrix operations and linear algebra primitives.

```python { .api }
def dot(lhs, rhs, precision=None, preferred_element_type=None) -> Array:
    """Vector and matrix products for 1D and 2D arrays."""

def dot_general(
    lhs,
    rhs,
    dimension_numbers,
    precision=None,
    preferred_element_type=None,
) -> Array:
    """
    General tensor contraction with explicit contracting and batch dimensions.

    Args:
        lhs: Left-hand side array
        rhs: Right-hand side array
        dimension_numbers: Contracting and batch dimensions (see DotDimensionNumbers)
        precision: Computation precision
        preferred_element_type: Preferred output element type

    Returns:
        Array whose axes are the batch dimensions, then the remaining lhs
        dimensions, then the remaining rhs dimensions
    """

def batch_matmul(
    lhs,
    rhs,
    precision=None,
    preferred_element_type=None,
) -> Array:
    """Batched matrix multiplication."""

# dot_general takes dimension_numbers as a nested tuple rather than a class instance:
#   ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))
DotDimensionNumbers = tuple[
    tuple[Sequence[int], Sequence[int]],  # contracting dimensions
    tuple[Sequence[int], Sequence[int]],  # batch dimensions
]
```

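A sketch of a batched matrix multiply written with `dot_general`, checked against `jnp.einsum`. The shapes and random data are arbitrary:

```python
import jax.numpy as jnp
import numpy as np
from jax import lax

rng = np.random.default_rng(0)
a = jnp.asarray(rng.standard_normal((2, 3, 4)), dtype=jnp.float32)  # (batch, m, k)
b = jnp.asarray(rng.standard_normal((2, 4, 5)), dtype=jnp.float32)  # (batch, k, n)

# Contract a's axis 2 with b's axis 1; treat axis 0 of both as a batch axis.
dims = (((2,), (1,)), ((0,), (0,)))
out = lax.dot_general(a, b, dims)  # batch dims first, then free dims: (2, 3, 5)

expected = jnp.einsum("bmk,bkn->bmn", a, b)
```
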
### Advanced Linear Algebra (lax.linalg)

Advanced linear algebra operations from `jax.lax.linalg`.

```python { .api }
def cholesky(a, *, symmetrize_input: bool = True) -> Array:
    """
    Cholesky decomposition of a positive definite matrix.

    Args:
        a: Positive definite matrix
        symmetrize_input: Whether to symmetrize the input first

    Returns:
        Lower triangular Cholesky factor
    """

def cholesky_update(r, u, *, alpha: float = 1.0) -> Array:
    """
    Rank-1 update to a Cholesky factorization.

    Args:
        r: Cholesky factor
        u: Update vector
        alpha: Update coefficient

    Returns:
        Updated Cholesky factor
    """

def eig(a, *, compute_left_eigenvectors: bool = True, compute_right_eigenvectors: bool = True) -> list[Array]:
    """
    Eigenvalue decomposition of a general matrix.

    Args:
        a: Input matrix
        compute_left_eigenvectors: Whether to compute left eigenvectors
        compute_right_eigenvectors: Whether to compute right eigenvectors

    Returns:
        Eigenvalues followed by the requested eigenvectors; with both flags set,
        (eigenvalues, left_eigenvectors, right_eigenvectors)
    """

def eigh(a, *, lower: bool = True, symmetrize_input: bool = True, sort_eigenvalues: bool = True) -> tuple[Array, Array]:
    """
    Eigenvalue decomposition of a Hermitian matrix.

    Args:
        a: Hermitian matrix
        lower: Whether to use the lower triangle
        symmetrize_input: Whether to symmetrize the input first
        sort_eigenvalues: Whether to sort eigenvalues in ascending order

    Returns:
        Tuple (v, w) of eigenvectors and eigenvalues, where v[..., :, i]
        is the eigenvector for eigenvalue w[..., i]
    """

def lu(a) -> tuple[Array, Array, Array]:
    """
    LU decomposition with partial pivoting.

    Args:
        a: Input matrix

    Returns:
        Tuple of (lu_factors, pivots, permutation)
    """

def qr(a, *, full_matrices: bool = True) -> tuple[Array, Array]:
    """
    QR decomposition.

    Args:
        a: Input matrix
        full_matrices: Whether to return full or reduced QR

    Returns:
        Tuple of (q, r) matrices
    """

def svd(a, *, full_matrices: bool = True, compute_uv: bool = True, hermitian: bool = False) -> tuple[Array, Array, Array]:
    """
    Singular value decomposition.

    Args:
        a: Input matrix
        full_matrices: Whether to return full or reduced SVD
        compute_uv: Whether to compute U and V matrices
        hermitian: Whether matrix is Hermitian

    Returns:
        Tuple of (u, s, vh) where A = U @ diag(s) @ Vh; if compute_uv is
        False, only s is returned
    """

def schur(a, *, compute_schur_vectors: bool = True, sort_eigs: bool = False, select_callable=None) -> tuple[Array, Array]:
    """
    Schur decomposition.

    Args:
        a: Input matrix
        compute_schur_vectors: Whether to compute Schur vectors
        sort_eigs: Whether to sort eigenvalues
        select_callable: Selection function for eigenvalues

    Returns:
        Tuple of (schur_form, schur_vectors)
    """

def hessenberg(a) -> tuple[Array, Array]:
    """
    Hessenberg decomposition.

    Args:
        a: Square input matrix

    Returns:
        Tuple (a, taus): the upper triangle and first subdiagonal of a hold the
        Hessenberg form, the entries below hold the Householder reflectors, and
        taus holds their scale factors
    """

def triangular_solve(a, b, *, left_side: bool = False, lower: bool = False, transpose_a: bool = False, conjugate_a: bool = False, unit_diagonal: bool = False) -> Array:
    """
    Solve a triangular system of equations.

    Args:
        a: Triangular matrix
        b: Right-hand side
        left_side: If True, solve op(a) @ x = b; if False, solve x @ op(a) = b
        lower: Whether a is lower triangular
        transpose_a: Whether op(a) transposes a
        conjugate_a: Whether op(a) conjugates a
        unit_diagonal: Whether a has unit diagonal

    Returns:
        Solution to the triangular system
    """

def tridiagonal(a, *, lower: bool = True) -> tuple[Array, Array, Array, Array]:
    """
    Tridiagonal reduction of a symmetric (Hermitian) matrix.

    Args:
        a: Symmetric or Hermitian matrix
        lower: Whether to use the lower triangle

    Returns:
        Tuple (a, d, e, taus): Householder reflectors stored in a, the diagonal d,
        the off-diagonal e, and the Householder scale factors taus
    """

def tridiagonal_solve(dl, d, du, b) -> Array:
    """
    Solve a tridiagonal linear system A @ x = b.

    Args:
        dl: Lower diagonal (dl[0] must be 0)
        d: Main diagonal
        du: Upper diagonal (du[-1] must be 0)
        b: Right-hand side

    Returns:
        Solution to the tridiagonal system
    """

def qdwh(a, *, is_hermitian: bool = False, max_iterations: int = None, dynamic_shape: bool = False) -> tuple[Array, Array]:
    """
    QDWH polar decomposition: A = UP where U is unitary, P is positive semidefinite.

    Args:
        a: Input matrix
        is_hermitian: Whether matrix is Hermitian
        max_iterations: Maximum number of iterations
        dynamic_shape: Whether to handle dynamic shapes

    Returns:
        Tuple of (unitary_factor, positive_factor)
    """

def householder_product(a, taus) -> Array:
    """
    Compute product of Householder reflectors.

    Args:
        a: Matrix containing Householder vectors
        taus: Householder scaling factors

    Returns:
        Product of Householder reflectors
    """

def lu_pivots_to_permutation(pivots, permutation_size) -> Array:
    """
    Convert LU pivot indices to a permutation.

    Args:
        pivots: Pivot indices from LU decomposition
        permutation_size: Length of the permutation

    Returns:
        Permutation as an integer array of row indices
    """
```

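A sketch that solves a symmetric positive definite system through a Cholesky factor and two triangular solves, and rebuilds the matrix from `eigh`. The flags are passed explicitly rather than relying on defaults, and the test matrix is made well conditioned by adding a multiple of the identity:

```python
import jax.numpy as jnp
import numpy as np
from jax.lax import linalg

rng = np.random.default_rng(0)
m = rng.standard_normal((4, 4))
a = jnp.asarray(m @ m.T + 4 * np.eye(4), dtype=jnp.float32)  # symmetric positive definite
b = jnp.asarray(rng.standard_normal((4, 1)), dtype=jnp.float32)  # (m, n) right-hand side

# Solve a @ x = b using a = L @ L.T.
L = linalg.cholesky(a)
y = linalg.triangular_solve(L, b, left_side=True, lower=True)                    # L @ y = b
x = linalg.triangular_solve(L, y, left_side=True, lower=True, transpose_a=True)  # L.T @ x = y

# eigh returns eigenvectors first, then eigenvalues.
v, w = linalg.eigh(a)
rebuilt = (v * w) @ v.T  # V @ diag(w) @ V.T
```
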
### Convolution Operations

Convolution operations for neural networks and signal processing.

```python { .api }
def conv(
    lhs,
    rhs,
    window_strides,
    padding,
    precision=None,
    preferred_element_type=None,
) -> Array:
    """Basic convolution with default dimension numbers (NCHW input, OIHW kernel)."""

def conv_general_dilated(
    lhs,
    rhs,
    window_strides,
    padding,
    lhs_dilation=None,
    rhs_dilation=None,
    dimension_numbers=None,
    feature_group_count=1,
    batch_group_count=1,
    precision=None,
    preferred_element_type=None,
) -> Array:
    """
    General dilated convolution with full configuration options.

    Args:
        lhs: Input array; layout given by dimension_numbers (default NCHW-style:
            batch, feature, then spatial dimensions)
        rhs: Kernel array (default OIHW-style: output feature, input feature,
            then spatial dimensions)
        window_strides: Convolution strides
        padding: 'VALID', 'SAME', or a sequence of (low, high) pairs per spatial dimension
        lhs_dilation: Input dilation
        rhs_dilation: Kernel dilation (atrous convolution)
        dimension_numbers: ConvDimensionNumbers or a tuple of layout strings,
            e.g. ('NHWC', 'HWIO', 'NHWC')
        feature_group_count: Number of feature groups (grouped or depthwise convolution)
        batch_group_count: Number of batch groups
        precision: Computation precision
        preferred_element_type: Preferred output type

    Returns:
        Convolution result
    """

def conv_transpose(
    lhs,
    rhs,
    strides,
    padding,
    rhs_dilation=None,
    dimension_numbers=None,
    transpose_kernel=False,
    precision=None,
    preferred_element_type=None,
) -> Array:
    """Transposed convolution (often called deconvolution)."""

class ConvDimensionNumbers:
    """Convolution dimension number specification (a NamedTuple)."""
    lhs_spec: tuple[int, ...]  # Input: (batch dim, feature dim, spatial dims...)
    rhs_spec: tuple[int, ...]  # Kernel: (out feature dim, in feature dim, spatial dims...)
    out_spec: tuple[int, ...]  # Output: (batch dim, feature dim, spatial dims...)
```

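A sketch of a 2-D convolution in the channels-last layout common in image models, using layout strings for `dimension_numbers`. An all-ones input and kernel make the output easy to check by hand:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((1, 4, 4, 1), dtype=jnp.float32)       # NHWC: batch, height, width, channels
kernel = jnp.ones((3, 3, 1, 1), dtype=jnp.float32)  # HWIO: height, width, in, out
layout = ("NHWC", "HWIO", "NHWC")

# VALID: no padding, so each output pixel sums a full 3x3 window of ones.
valid = lax.conv_general_dilated(x, kernel, (1, 1), "VALID", dimension_numbers=layout)

# SAME: zero padding keeps the 4x4 spatial size; corners see only a 2x2 window.
same = lax.conv_general_dilated(x, kernel, (1, 1), "SAME", dimension_numbers=layout)
```
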
### FFT Operations

Fast Fourier Transform operations.

```python { .api }
def fft(a, fft_type, fft_lengths) -> Array:
    """
    Fast Fourier Transform over the trailing dimensions.

    Args:
        a: Input array (complex for FFT, IFFT, and IRFFT; real for RFFT)
        fft_type: Type of FFT (a FftType member)
        fft_lengths: Sizes of the trailing dimensions to transform (at most 3)

    Returns:
        FFT result
    """

class FftType(enum.IntEnum):
    """FFT type enumeration."""
    FFT = 0    # forward, complex to complex
    IFFT = 1   # inverse, complex to complex
    RFFT = 2   # forward, real to complex
    IRFFT = 3  # inverse, complex to real
```

### Parallel Operations

Collective communication primitives over named axes. The `axis_name` must be bound by an enclosing transformation such as `jax.pmap`, `shard_map`, or `jax.vmap(..., axis_name=...)`.

```python { .api }
def all_gather(x, axis_name, *, axis_index_groups=None, tiled=False) -> Array:
    """Gather values from all devices."""

def all_to_all(x, axis_name, split_axis, concat_axis, *, axis_index_groups=None, tiled=False) -> Array:
    """All-to-all communication between devices."""

def psum(x, axis_name, *, axis_index_groups=None) -> Array:
    """Parallel sum reduction across devices."""

def pmean(x, axis_name, *, axis_index_groups=None) -> Array:
    """Parallel mean reduction across devices."""

def pmax(x, axis_name, *, axis_index_groups=None) -> Array:
    """Parallel max reduction across devices."""

def pmin(x, axis_name, *, axis_index_groups=None) -> Array:
    """Parallel min reduction across devices."""

def ppermute(x, axis_name, perm, *, axis_index_groups=None) -> Array:
    """Permute data between devices; perm is a list of (source, destination) index pairs."""

def axis_index(axis_name) -> Array:
    """Get device index along named axis."""

def axis_size(axis_name) -> int:
    """Get number of devices along named axis."""

def pbroadcast(x, axis_name, *, axis_index_groups=None) -> Array:
    """Broadcast from first device to all others."""
```

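Because `jax.vmap` can bind an axis name, the collectives can be tried on a single CPU. The sketch below is illustrative only; real multi-device code would run the same functions under `jax.pmap` or `shard_map`:

```python
import jax
import jax.numpy as jnp
from jax import lax

def normalize(x):
    # psum sums x across every element mapped over the named axis "i".
    return x / lax.psum(x, axis_name="i")

def position(x):
    # axis_index gives this element's index along the named axis.
    return lax.axis_index("i")

x = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=jnp.float32)

shares = jax.vmap(normalize, axis_name="i")(x)
indices = jax.vmap(position, axis_name="i")(x)
means = jax.vmap(lambda v: lax.pmean(v, "i"), axis_name="i")(x)
```
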
### Special Functions

Special mathematical functions and probability distributions.

```python { .api }
# Error functions
def erf(x) -> Array: ...
def erfc(x) -> Array: ...     # 1 - erf(x)
def erf_inv(x) -> Array: ...  # inverse of erf

# Gamma functions
def lgamma(x) -> Array: ...        # log of the absolute value of the gamma function
def digamma(x) -> Array: ...       # derivative of lgamma
def polygamma(m, x) -> Array: ...  # m-th derivative of digamma

# Bessel functions
def bessel_i0e(x) -> Array: ...  # exponentially scaled modified Bessel function, order 0
def bessel_i1e(x) -> Array: ...  # exponentially scaled modified Bessel function, order 1

# Other special functions
def betainc(a, b, x) -> Array: ...  # regularized incomplete beta function
def igamma(a, x) -> Array: ...      # regularized lower incomplete gamma P(a, x)
def igammac(a, x) -> Array: ...     # regularized upper incomplete gamma Q(a, x)
def zeta(x, q=None) -> Array: ...   # Hurwitz zeta function
```

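A sketch checking a few identities that these functions satisfy:

```python
import math

import jax.numpy as jnp
from jax import lax

x = jnp.array([-0.5, 0.0, 0.9], dtype=jnp.float32)
round_trip = lax.erf_inv(lax.erf(x))  # erf_inv inverts erf on (-1, 1)

# lgamma(5) = log(4!) = log(24)
log_fact = lax.lgamma(jnp.float32(5.0))

# The regularized incomplete gammas are complementary: P(a, x) + Q(a, x) == 1.
a, t = jnp.float32(2.0), jnp.float32(1.5)
total = lax.igamma(a, t) + lax.igammac(a, t)
```
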
### Type Conversion and Manipulation

Array type conversion and data manipulation operations.

```python { .api }
def convert_element_type(operand, new_dtype) -> Array:
    """Convert array element type."""

def bitcast_convert_type(operand, new_dtype) -> Array:
    """Bitcast array to new type without changing bit representation."""

def dtype(x) -> numpy.dtype:
    """Get array data type."""

def full(shape, fill_value, dtype=None) -> Array:
    """Create array filled with constant value."""

def full_like(x, fill_value, dtype=None, shape=None) -> Array:
    """Create filled array with same properties as input."""

def iota(dtype, size) -> Array:
    """Create array with sequential values (0, 1, 2, ...)."""

def broadcasted_iota(dtype, shape, dimension) -> Array:
    """Create iota array broadcasted to shape."""
```

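A sketch of the difference between converting values and reinterpreting bits, plus the index-array constructors:

```python
import jax.numpy as jnp
from jax import lax

# Value conversion: float to int truncates toward zero, like NumPy's astype.
ints = lax.convert_element_type(jnp.array([2.7, -2.7], dtype=jnp.float32), jnp.int32)

# Bit reinterpretation: float32 1.0 has the IEEE-754 bit pattern 0x3F800000.
raw = lax.bitcast_convert_type(jnp.float32(1.0), jnp.int32)

# iota and broadcasted_iota build index arrays directly on the device.
idx = lax.iota(jnp.int32, 4)
cols = lax.broadcasted_iota(jnp.int32, (2, 3), 1)
```
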
### Sorting Operations

Sorting and selection operations.

```python { .api }
def sort(operand, dimension=-1, is_stable=True) -> Array:
    """Sort array along dimension."""

def sort_key_val(keys, values, dimension=-1, is_stable=True) -> tuple[Array, Array]:
    """Sort keys along dimension and reorder values to match."""

def top_k(operand, k) -> tuple[Array, Array]:
    """Find the k largest elements along the last axis; returns (values, indices)."""

def argmax(operand, axis, index_dtype) -> Array:
    """Indices of maximum values along axis (axis and index_dtype are required)."""

def argmin(operand, axis, index_dtype) -> Array:
    """Indices of minimum values along axis (axis and index_dtype are required)."""
```

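A sketch of selection and key-value sorting on small arrays:

```python
import jax.numpy as jnp
from jax import lax

scores = jnp.array([3.0, 1.0, 4.0, 1.5, 5.0], dtype=jnp.float32)

# top_k works along the last axis and returns (values, indices).
values, indices = lax.top_k(scores, 2)

# sort_key_val reorders values according to the sorted keys.
keys, payload = lax.sort_key_val(jnp.array([3, 1, 2]), jnp.array([30, 10, 20]))

# lax.argmax needs the axis and the index dtype spelled out.
best = lax.argmax(scores, 0, jnp.int32)
```
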
### Miscellaneous Operations

Additional utility operations and performance primitives.

```python { .api }
def stop_gradient(x) -> Array:
    """Identity in the forward pass; automatic differentiation treats the result as a constant."""

def optimization_barrier(x) -> Array:
    """Prevent XLA from moving operations across this point."""

def nextafter(x1, x2) -> Array:
    """Next representable value after x1 in the direction of x2."""

def reduce_precision(operand, exponent_bits, mantissa_bits) -> Array:
    """Round floating-point values to a narrower format while keeping the dtype."""

def create_token() -> Array:
    """Create execution token for ordering side effects."""

def after_all(*tokens) -> Array:
    """Create token that depends on all input tokens."""

# Random number generation primitives
def rng_uniform(a, b, shape, dtype=None) -> Array:
    """Low-level uniform random number generation."""

def rng_bit_generator(key, shape, dtype=None, algorithm=None) -> tuple[Array, Array]:
    """Low-level random bit generation; returns (new_key, random_bits)."""
```
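
A sketch of three of these utilities. The `reduce_precision` call uses bfloat16's bit widths (8 exponent bits, 7 mantissa bits) as an illustrative choice:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax

# stop_gradient: the second factor is treated as a constant, so
# d/dx [x * stop_gradient(x)] evaluates to x, i.e. 3.0 at x = 3.0.
grad = jax.grad(lambda x: x * lax.stop_gradient(x))(3.0)

# nextafter steps to the adjacent representable float32 value.
step = lax.nextafter(jnp.float32(1.0), jnp.float32(2.0)) - jnp.float32(1.0)

# Emulate bfloat16 rounding: 1 + 2**-10 is below half a bfloat16 ulp above 1,
# so it rounds down to 1.0, and the dtype stays float32.
coarse = lax.reduce_precision(jnp.float32(1.0 + 2.0 ** -10), exponent_bits=8, mantissa_bits=7)
```
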