or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

core-transformations.md, device-memory.md, experimental.md, index.md, low-level-ops.md, neural-networks.md, numpy-compatibility.md, random-numbers.md, scipy-compatibility.md, tree-operations.md

docs/scipy-compatibility.md

0

# SciPy Compatibility

1

2

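JAX provides SciPy-compatible functions through `jax.scipy` for scientific computing, including linear algebra, signal processing, special functions, statistics, and iterative sparse solvers. Most of these functions are differentiable and can be JIT-compiled. `jax.scipy` implements only a subset of SciPy, and some names listed on this page are provided by `jax.numpy` or `jax.numpy.linalg` instead. Check the JAX API reference for your installed version before relying on a specific function.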

3

4

## Core Imports

5

6

```python

7

import jax.scipy as jsp

8

import jax.scipy.linalg as jla

9

import jax.scipy.special as jss

10

import jax.scipy.stats as jst

11

```

12

13

## Capabilities

14

15

### Linear Algebra (`jax.scipy.linalg`)

16

17

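Linear algebra operations for matrix computations and decompositions. Several entries below (for example `slogdet`, `norm`, `pinv`, `lstsq`, and `matrix_rank`) correspond to `jax.numpy.linalg` rather than `jax.scipy.linalg`. Confirm the module in the JAX API reference before importing them.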

18

19

```python { .api }

20

# Matrix decompositions

21

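def cholesky(a, lower=False, overwrite_a=False, check_finite=True) -> Array:
    """
    Cholesky decomposition of a Hermitian positive-definite matrix.

    Args:
        a: Hermitian positive-definite matrix to decompose
        lower: If True, return the lower-triangular factor; the default (False)
            returns the upper-triangular factor, matching scipy.linalg.cholesky
        overwrite_a: Accepted for SciPy compatibility; unused by JAX
        check_finite: Accepted for SciPy compatibility; unused by JAX

    Returns:
        Upper factor U with a = U.conj().T @ U (default), or lower factor L
        with a = L @ L.conj().T when lower=True
    """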

31

"""

32

33

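def qr(a, overwrite_a=False, lwork=None, mode='full', pivoting=False,
       check_finite=True) -> tuple[Array, ...]:
    """
    QR decomposition of matrix.

    Args:
        a: Matrix to decompose
        overwrite_a: Accepted for SciPy compatibility; unused by JAX
        lwork: Accepted for SciPy compatibility; unused by JAX
        mode: 'full' (square Q), 'economic' (reduced Q and R), or 'r' (R only).
            These follow scipy.linalg.qr, not numpy.linalg.qr's
            'reduced'/'complete'.
        pivoting: Whether to use column pivoting; see the JAX API reference
            for support in your version
        check_finite: Accepted for SciPy compatibility; unused by JAX

    Returns:
        Tuple (Q, R) where Q is orthogonal (unitary for complex input) and R
        is upper triangular; for mode='r', only R is computed
    """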

43

"""

44

45

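def svd(a, full_matrices=True, compute_uv=True, overwrite_a=False,
        check_finite=True, lapack_driver='gesdd') -> tuple[Array, Array, Array] | Array:
    """
    Singular Value Decomposition.

    Args:
        a: Matrix to decompose
        full_matrices: Whether to compute full or reduced SVD
        compute_uv: Whether to compute U and V matrices
        overwrite_a: Accepted for SciPy compatibility; unused by JAX
        check_finite: Accepted for SciPy compatibility; unused by JAX
        lapack_driver: Accepted for SciPy compatibility

    Returns:
        Tuple (U, s, Vh) where a = U @ diag(s) @ Vh, or only the singular
        values s when compute_uv=False. (A `hermitian` flag exists only on
        jax.numpy.linalg.svd.)
    """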

57

"""

58

59

def eig(a, b=None, left=False, right=True, overwrite_a=False, overwrite_b=False,

60

check_finite=True, homogeneous_eigvals=False) -> tuple[Array, Array]:

61

"""

62

Eigenvalues and eigenvectors of general matrix.

63

64

Args:

65

a: Square matrix

66

b: Optional matrix for generalized eigenvalue problem

67

left: Whether to compute left eigenvectors

68

right: Whether to compute right eigenvectors

69

overwrite_a: Whether input can be overwritten

70

overwrite_b: Whether b can be overwritten

71

check_finite: Whether to check for finite values

72

homogeneous_eigvals: Whether to return homogeneous eigenvalues

73

74

Returns:

75

Tuple (eigenvalues, eigenvectors)

76

"""

77

78

def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,

79

overwrite_b=False, turbo=True, eigvals=None, type=1,

80

check_finite=True) -> tuple[Array, Array]:

81

"""

82

Eigenvalues and eigenvectors of Hermitian matrix.

83

84

Args:

85

a: Hermitian matrix

86

b: Must be None; JAX does not implement the generalized eigenvalue problem

87

lower: Whether to use lower triangle

88

eigvals_only: Whether to compute eigenvalues only

89

overwrite_a: Whether input can be overwritten

90

overwrite_b: Whether b can be overwritten

91

turbo: Whether to use turbo algorithm

92

eigvals: Range of eigenvalue indices to compute

93

type: Type of generalized eigenvalue problem

94

check_finite: Whether to check for finite values

95

96

Returns:

97

Eigenvalues (and eigenvectors if eigvals_only=False)

98

"""

99

100

def eigvals(a, b=None, overwrite_a=False, check_finite=True,

101

homogeneous_eigvals=False) -> Array:

102

"""Eigenvalues of general matrix."""

103

104

def eigvalsh(a, b=None, lower=True, overwrite_a=False, overwrite_b=False,

105

turbo=True, eigvals=None, type=1, check_finite=True) -> Array:

106

"""Eigenvalues of Hermitian matrix."""

107

108

# Matrix properties and functions

109

def det(a) -> Array:

110

"""Matrix determinant."""

111

112

def slogdet(a) -> tuple[Array, Array]:

113

"""Sign and log determinant of matrix."""

114

115

def logdet(a) -> Array:

116

"""Log determinant of matrix."""

117

118

def matrix_rank(M, tol=None, hermitian=False) -> Array:

119

"""Matrix rank computation."""

120

121

def trace(a, offset=0, axis1=0, axis2=1) -> Array:

122

"""Matrix trace."""

123

124

def norm(a, ord=None, axis=None, keepdims=False) -> Array:

125

"""Matrix or vector norm."""

126

127

def cond(x, p=None) -> Array:

128

"""Condition number of matrix."""

129

130

# Matrix solutions

131

def solve(a, b, assume_a='gen', lower=False, overwrite_a=False,

132

overwrite_b=False, debug=None, check_finite=True) -> Array:

133

"""

134

Solve linear system Ax = b.

135

136

Args:

137

a: Coefficient matrix

138

b: Right-hand side vector/matrix

139

assume_a: Properties of matrix a ('gen', 'sym', 'her', 'pos')

140

lower: For assume_a in ('sym', 'her', 'pos'), use only the lower triangle of a (ignored for 'gen')

141

overwrite_a: Whether input can be overwritten

142

overwrite_b: Whether b can be overwritten

143

debug: Debug information level

144

check_finite: Whether to check for finite values

145

146

Returns:

147

Solution x such that Ax = b

148

"""

149

150

def solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False,

151

overwrite_b=False, debug=None, check_finite=True) -> Array:

152

"""Solve triangular linear system."""

153

154

def inv(a, overwrite_a=False, check_finite=True) -> Array:

155

"""Matrix inverse."""

156

157

def pinv(a, rcond=None, hermitian=False, return_rank=False) -> Array:

158

"""Moore-Penrose pseudoinverse."""

159

160

def lstsq(a, b, rcond=None, lapack_driver=None) -> tuple[Array, Array, Array, Array]:

161

"""

162

Least-squares solution to linear system.

163

164

Args:

165

a: Coefficient matrix

166

b: Dependent variable values

167

rcond: Cutoff ratio for small singular values

168

lapack_driver: LAPACK driver to use

169

170

Returns:

171

Tuple (solution, residuals, rank, singular_values)

172

"""

173

174

# Matrix functions

175

def expm(A) -> Array:

176

"""Matrix exponential."""

177

178

def funm(A, func, disp=True) -> Array:

179

"""General matrix function evaluation."""

180

181

def sqrtm(A, disp=True, blocksize=64) -> Array:

182

"""Matrix square root."""

183

184

def logm(A, disp=True) -> Array:

185

"""Matrix logarithm."""

186

187

def fractional_matrix_power(A, t) -> Array:

188

"""Fractional matrix power A^t."""

189

190

def matrix_power(A, n) -> Array:

191

"""Integer matrix power A^n."""

192

193

# Schur decomposition

194

def schur(a, output='real') -> tuple[Array, Array]:

195

"""Schur decomposition of matrix."""

196

197

def rsf2csf(T, Z) -> tuple[Array, Array]:

198

"""Convert real Schur form to complex Schur form."""

199

200

# Polar decomposition

201

def polar(a, side='right') -> tuple[Array, Array]:

202

"""Polar decomposition of matrix."""

203

```

204

205

### Special Functions (`jax.scipy.special`)

206

207

Special mathematical functions including error functions, gamma functions, and Bessel functions.

208

209

```python { .api }

210

# Error functions

211

def erf(z) -> Array:

212

"""Error function."""

213

214

def erfc(x) -> Array:

215

"""Complementary error function."""

216

217

def erfinv(y) -> Array:

218

"""Inverse error function."""

219

220

def erfcinv(y) -> Array:

221

"""Inverse complementary error function."""

222

223

def wofz(z) -> Array:

224

"""Faddeeva function."""

225

226

# Gamma functions

227

def gamma(z) -> Array:

228

"""Gamma function."""

229

230

def gammaln(x) -> Array:

231

"""Log gamma function."""

232

233

def digamma(x) -> Array:

234

"""Digamma (psi) function."""

235

236

def polygamma(n, x) -> Array:

237

"""Polygamma function."""

238

239

def gammainc(a, x) -> Array:

240

"""Lower incomplete gamma function."""

241

242

def gammaincc(a, x) -> Array:

243

"""Upper incomplete gamma function."""

244

245

def gammasgn(x) -> Array:

246

"""Sign of gamma function."""

247

248

def rgamma(x) -> Array:

249

"""Reciprocal gamma function."""

250

251

# Beta functions

252

def beta(a, b) -> Array:

253

"""Beta function."""

254

255

def betaln(a, b) -> Array:

256

"""Log beta function."""

257

258

def betainc(a, b, x) -> Array:

259

"""Incomplete beta function."""

260

261

# Bessel functions

262

def j0(x) -> Array:

263

"""Bessel function of the first kind of order 0."""

264

265

def j1(x) -> Array:

266

"""Bessel function of the first kind of order 1."""

267

268

def jn(n, x) -> Array:

269

"""Bessel function of the first kind of order n."""

270

271

def y0(x) -> Array:

272

"""Bessel function of the second kind of order 0."""

273

274

def y1(x) -> Array:

275

"""Bessel function of the second kind of order 1."""

276

277

def yn(n, x) -> Array:

278

"""Bessel function of the second kind of order n."""

279

280

def i0(x) -> Array:

281

"""Modified Bessel function of the first kind of order 0."""

282

283

def i0e(x) -> Array:

284

"""Exponentially scaled modified Bessel function i0."""

285

286

def i1(x) -> Array:

287

"""Modified Bessel function of the first kind of order 1."""

288

289

def i1e(x) -> Array:

290

"""Exponentially scaled modified Bessel function i1."""

291

292

def iv(v, z) -> Array:

293

"""Modified Bessel function of the first kind of real order."""

294

295

def k0(x) -> Array:

296

"""Modified Bessel function of the second kind of order 0."""

297

298

def k0e(x) -> Array:

299

"""Exponentially scaled modified Bessel function k0."""

300

301

def k1(x) -> Array:

302

"""Modified Bessel function of the second kind of order 1."""

303

304

def k1e(x) -> Array:

305

"""Exponentially scaled modified Bessel function k1."""

306

307

def kv(v, z) -> Array:

308

"""Modified Bessel function of the second kind of real order."""

309

310

# Exponential integrals

311

def expi(x) -> Array:

312

"""Exponential integral Ei."""

313

314

def expn(n, x) -> Array:

315

"""Generalized exponential integral."""

316

317

# Log-sum-exp and related

318

def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False) -> Array:

319

"""

320

Compute log(sum(exp(a))) in numerically stable way.

321

322

Args:

323

a: Input array

324

axis: Axis to sum over

325

b: Multiplier for each element

326

keepdims: Whether to keep reduced dimensions

327

return_sign: Whether to return sign separately

328

329

Returns:

330

Log-sum-exp result

331

"""

332

333

def softmax(x, axis=None) -> Array:

334

"""Softmax function."""

335

336

def log_softmax(x, axis=None) -> Array:

337

"""Log softmax function."""

338

339

# Combinatorial functions

340

def factorial(n, exact=False) -> Array:

341

"""Factorial function."""

342

343

def factorial2(n, exact=False) -> Array:

344

"""Double factorial function."""

345

346

def factorialk(n, k, exact=False) -> Array:

347

"""Multifactorial function."""

348

349

def comb(N, k, exact=False, repetition=False) -> Array:

350

"""Binomial coefficient."""

351

352

def perm(N, k, exact=False) -> Array:

353

"""Permutation coefficient."""

354

355

# Elliptic integrals

356

def ellipk(m) -> Array:

357

"""Complete elliptic integral of the first kind."""

358

359

def ellipe(m) -> Array:

360

"""Complete elliptic integral of the second kind."""

361

362

def ellipkinc(phi, m) -> Array:

363

"""Incomplete elliptic integral of the first kind."""

364

365

def ellipeinc(phi, m) -> Array:

366

"""Incomplete elliptic integral of the second kind."""

367

368

# Zeta and related functions

369

def zeta(x, q=None) -> Array:

370

"""Riemann or Hurwitz zeta function."""

371

372

def zetac(x) -> Array:

373

"""Riemann zeta function minus 1."""

374

375

# Hypergeometric functions

376

def hyp1f1(a, b, x) -> Array:

377

"""Confluent hypergeometric function 1F1."""

378

379

def hyp2f1(a, b, c, z) -> Array:

380

"""Gaussian hypergeometric function 2F1."""

381

382

def hyperu(a, b, x) -> Array:

383

"""Confluent hypergeometric function U."""

384

385

# Legendre functions

386

def legendre(n, x) -> Array:

387

"""Legendre polynomial."""

388

389

def lpmv(m, v, x) -> Array:

390

"""Associated Legendre function."""

391

392

# Spherical functions

393

def sph_harm(m, n, theta, phi) -> Array:

394

"""Spherical harmonics."""

395

396

# Other special functions

397

def lambertw(z, k=0, tol=1e-8) -> Array:

398

"""Lambert W function."""

399

400

def spence(z) -> Array:

401

"""Spence function."""

402

403

def multigammaln(a, d) -> Array:

404

"""Log of multivariate gamma function."""

405

406

def entr(x) -> Array:

407

"""Elementwise function -x*log(x)."""

408

409

def kl_div(x, y) -> Array:

410

"""Elementwise function x*log(x/y) - x + y."""

411

412

def rel_entr(x, y) -> Array:

413

"""Elementwise function x*log(x/y)."""

414

415

def huber(delta, r) -> Array:

416

"""Huber loss function."""

417

418

def pseudo_huber(delta, r) -> Array:

419

"""Pseudo-Huber loss function."""

420

```

421

422

### Statistics (`jax.scipy.stats`)

423

424

Probability distributions (densities, log-densities, cumulative distribution functions, and related methods) and statistical helper functions.

425

426

```python { .api }

427

# Continuous distributions

428

class norm:

429

"""Normal distribution."""

430

@staticmethod

431

def pdf(x, loc=0, scale=1) -> Array: ...

432

@staticmethod

433

def logpdf(x, loc=0, scale=1) -> Array: ...

434

@staticmethod

435

def cdf(x, loc=0, scale=1) -> Array: ...

436

@staticmethod

437

def logcdf(x, loc=0, scale=1) -> Array: ...

438

@staticmethod

439

def sf(x, loc=0, scale=1) -> Array: ...

440

@staticmethod

441

def logsf(x, loc=0, scale=1) -> Array: ...

442

@staticmethod

443

def ppf(q, loc=0, scale=1) -> Array: ...

444

@staticmethod

445

def isf(q, loc=0, scale=1) -> Array: ...

446

447

class multivariate_normal:

448

"""Multivariate normal distribution."""

449

@staticmethod

450

def pdf(x, mean=None, cov=1, allow_singular=False) -> Array: ...

451

@staticmethod

452

def logpdf(x, mean=None, cov=1, allow_singular=False) -> Array: ...

453

454

class uniform:

455

"""Uniform distribution."""

456

@staticmethod

457

def pdf(x, loc=0, scale=1) -> Array: ...

458

@staticmethod

459

def logpdf(x, loc=0, scale=1) -> Array: ...

460

@staticmethod

461

def cdf(x, loc=0, scale=1) -> Array: ...

462

@staticmethod

463

def logcdf(x, loc=0, scale=1) -> Array: ...

464

@staticmethod

465

def sf(x, loc=0, scale=1) -> Array: ...

466

@staticmethod

467

def logsf(x, loc=0, scale=1) -> Array: ...

468

@staticmethod

469

def ppf(q, loc=0, scale=1) -> Array: ...

470

471

class beta:

472

"""Beta distribution."""

473

@staticmethod

474

def pdf(x, a, b, loc=0, scale=1) -> Array: ...

475

@staticmethod

476

def logpdf(x, a, b, loc=0, scale=1) -> Array: ...

477

@staticmethod

478

def cdf(x, a, b, loc=0, scale=1) -> Array: ...

479

480

class gamma:

481

"""Gamma distribution."""

482

@staticmethod

483

def pdf(x, a, loc=0, scale=1) -> Array: ...

484

@staticmethod

485

def logpdf(x, a, loc=0, scale=1) -> Array: ...

486

@staticmethod

487

def cdf(x, a, loc=0, scale=1) -> Array: ...

488

489

class chi2:

490

"""Chi-square distribution."""

491

@staticmethod

492

def pdf(x, df, loc=0, scale=1) -> Array: ...

493

@staticmethod

494

def logpdf(x, df, loc=0, scale=1) -> Array: ...

495

@staticmethod

496

def cdf(x, df, loc=0, scale=1) -> Array: ...

497

498

class t:

499

"""Student's t-distribution."""

500

@staticmethod

501

def pdf(x, df, loc=0, scale=1) -> Array: ...

502

@staticmethod

503

def logpdf(x, df, loc=0, scale=1) -> Array: ...

504

@staticmethod

505

def cdf(x, df, loc=0, scale=1) -> Array: ...

506

507

class f:

508

"""F-distribution."""

509

@staticmethod

510

def pdf(x, dfn, dfd, loc=0, scale=1) -> Array: ...

511

@staticmethod

512

def logpdf(x, dfn, dfd, loc=0, scale=1) -> Array: ...

513

@staticmethod

514

def cdf(x, dfn, dfd, loc=0, scale=1) -> Array: ...

515

516

class laplace:

517

"""Laplace distribution."""

518

@staticmethod

519

def pdf(x, loc=0, scale=1) -> Array: ...

520

@staticmethod

521

def logpdf(x, loc=0, scale=1) -> Array: ...

522

@staticmethod

523

def cdf(x, loc=0, scale=1) -> Array: ...

524

525

class logistic:

526

"""Logistic distribution."""

527

@staticmethod

528

def pdf(x, loc=0, scale=1) -> Array: ...

529

@staticmethod

530

def logpdf(x, loc=0, scale=1) -> Array: ...

531

@staticmethod

532

def cdf(x, loc=0, scale=1) -> Array: ...

533

534

class pareto:

535

"""Pareto distribution."""

536

@staticmethod

537

def pdf(x, b, loc=0, scale=1) -> Array: ...

538

@staticmethod

539

def logpdf(x, b, loc=0, scale=1) -> Array: ...

540

@staticmethod

541

def cdf(x, b, loc=0, scale=1) -> Array: ...

542

543

class expon:

544

"""Exponential distribution."""

545

@staticmethod

546

def pdf(x, loc=0, scale=1) -> Array: ...

547

@staticmethod

548

def logpdf(x, loc=0, scale=1) -> Array: ...

549

@staticmethod

550

def cdf(x, loc=0, scale=1) -> Array: ...

551

552

class lognorm:

553

"""Log-normal distribution."""

554

@staticmethod

555

def pdf(x, s, loc=0, scale=1) -> Array: ...

556

@staticmethod

557

def logpdf(x, s, loc=0, scale=1) -> Array: ...

558

@staticmethod

559

def cdf(x, s, loc=0, scale=1) -> Array: ...

560

561

class truncnorm:

562

"""Truncated normal distribution."""

563

@staticmethod

564

def pdf(x, a, b, loc=0, scale=1) -> Array: ...

565

@staticmethod

566

def logpdf(x, a, b, loc=0, scale=1) -> Array: ...

567

@staticmethod

568

def cdf(x, a, b, loc=0, scale=1) -> Array: ...

569

570

# Discrete distributions

571

class bernoulli:

572

"""Bernoulli distribution."""

573

@staticmethod

574

def pmf(k, p, loc=0) -> Array: ...

575

@staticmethod

576

def logpmf(k, p, loc=0) -> Array: ...

577

@staticmethod

578

def cdf(k, p, loc=0) -> Array: ...

579

580

class binom:

581

"""Binomial distribution."""

582

@staticmethod

583

def pmf(k, n, p, loc=0) -> Array: ...

584

@staticmethod

585

def logpmf(k, n, p, loc=0) -> Array: ...

586

@staticmethod

587

def cdf(k, n, p, loc=0) -> Array: ...

588

589

class geom:

590

"""Geometric distribution."""

591

@staticmethod

592

def pmf(k, p, loc=0) -> Array: ...

593

@staticmethod

594

def logpmf(k, p, loc=0) -> Array: ...

595

@staticmethod

596

def cdf(k, p, loc=0) -> Array: ...

597

598

class nbinom:

599

"""Negative binomial distribution."""

600

@staticmethod

601

def pmf(k, n, p, loc=0) -> Array: ...

602

@staticmethod

603

def logpmf(k, n, p, loc=0) -> Array: ...

604

@staticmethod

605

def cdf(k, n, p, loc=0) -> Array: ...

606

607

class poisson:

608

"""Poisson distribution."""

609

@staticmethod

610

def pmf(k, mu, loc=0) -> Array: ...

611

@staticmethod

612

def logpmf(k, mu, loc=0) -> Array: ...

613

@staticmethod

614

def cdf(k, mu, loc=0) -> Array: ...

615

616

# Statistical functions

617

def mode(a, axis=0, nan_policy='propagate', keepdims=False) -> Array:

618

"""Mode of array values along axis."""

619

620

def rankdata(a, method='average', axis=None) -> Array:

621

"""Rank data along axis."""

622

623

def kendalltau(x, y, initial_lexsort=None, nan_policy='propagate', method='auto') -> tuple[Array, Array]:

624

"""Kendall's tau correlation coefficient."""

625

626

def pearsonr(x, y) -> tuple[Array, Array]:

627

"""Pearson correlation coefficient."""

628

629

def spearmanr(a, b=None, axis=0, nan_policy='propagate', alternative='two-sided') -> tuple[Array, Array]:

630

"""Spearman correlation coefficient."""

631

```

632

633

### Signal Processing (`jax.scipy.signal`)

634

635

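Signal processing functions for convolution, correlation, and spectral analysis. `jax.scipy.signal` implements only part of `scipy.signal` (for example `convolve`, `convolve2d`, `correlate`, `correlate2d`, `fftconvolve`, `welch`, `csd`, `stft`, `istft`, and `detrend`). Several filtering and peak-finding routines listed below may not be available, so confirm them in the JAX API reference.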

636

637

```python { .api }

638

def convolve(in1, in2, mode='full', method='auto') -> Array:

639

"""N-dimensional convolution."""

640

641

def convolve2d(in1, in2, mode='full', boundary='fill', fillvalue=0) -> Array:

642

"""2D convolution."""

643

644

def correlate(in1, in2, mode='full', method='auto') -> Array:

645

"""Cross-correlation of two arrays."""

646

647

def correlate2d(in1, in2, mode='full', boundary='fill', fillvalue=0) -> Array:

648

"""2D cross-correlation."""

649

650

def fftconvolve(in1, in2, mode='full', axes=None) -> Array:

651

"""FFT-based convolution."""

652

653

def oaconvolve(in1, in2, mode='full', axes=None) -> Array:

654

"""Overlap-add convolution."""

655

656

def lfilter(b, a, x, axis=-1, zi=None) -> Array:

657

"""Linear digital filter."""

658

659

def filtfilt(b, a, x, axis=-1, padtype='odd', padlen=None, method='pad', irlen=None) -> Array:

660

"""Zero-phase digital filtering."""

661

662

def sosfilt(sos, x, axis=-1, zi=None) -> Array:

663

"""Filter using second-order sections."""

664

665

def sosfiltfilt(sos, x, axis=-1, padtype='odd', padlen=None) -> Array:

666

"""Zero-phase filtering with second-order sections."""

667

668

def hilbert(x, N=None, axis=-1) -> Array:

669

"""Hilbert transform."""

670

671

def hilbert2(x, N=None) -> Array:

672

"""2D Hilbert transform."""

673

674

def decimate(x, q, n=None, ftype='iir', axis=-1, zero_phase=True) -> Array:

675

"""Downsample signal by integer factor."""

676

677

def resample(x, num, t=None, axis=0, window=None, domain='time') -> Array:

678

"""Resample signal to new sample rate."""

679

680

def resample_poly(x, up, down, axis=0, window='kaiser', padtype='constant', cval=None) -> Array:

681

"""Resample using polyphase filtering."""

682

683

def upfirdn(h, x, up=1, down=1, axis=-1, mode='constant', cval=0) -> Array:

684

"""Upsample, FIR filter, and downsample."""

685

686

def periodogram(x, fs=1.0, window='boxcar', nfft=None, detrend='constant',

687

return_onesided=True, scaling='density', axis=-1) -> tuple[Array, Array]:

688

"""Periodogram power spectral density."""

689

690

def welch(x, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,

691

detrend='constant', return_onesided=True, scaling='density', axis=-1,

692

average='mean') -> tuple[Array, Array]:

693

"""Welch's method for power spectral density."""

694

695

def csd(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,

696

detrend='constant', return_onesided=True, scaling='density', axis=-1,

697

average='mean') -> tuple[Array, Array]:

698

"""Cross power spectral density."""

699

700

def coherence(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,

701

detrend='constant', axis=-1) -> tuple[Array, Array]:

702

"""Coherence between signals."""

703

704

def spectrogram(x, fs=1.0, window='tukey', nperseg=None, noverlap=None, nfft=None,

705

detrend='constant', return_onesided=True, scaling='density', axis=-1,

706

mode='psd') -> tuple[Array, Array, Array]:

707

"""Spectrogram using short-time Fourier transform."""

708

709

def stft(x, fs=1.0, window='hann', nperseg=256, noverlap=None, nfft=None,

710

detrend=False, return_onesided=True, boundary='zeros', padded=True, axis=-1) -> tuple[Array, Array, Array]:

711

"""Short-time Fourier transform."""

712

713

def istft(Zxx, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,

714

input_onesided=True, boundary=True, time_axis=-1, freq_axis=-2) -> tuple[Array, Array]:

715

"""Inverse short-time Fourier transform."""

716

717

def lombscargle(x, y, freqs, precenter=False, normalize=False) -> Array:

718

"""Lomb-Scargle periodogram."""

719

720

def detrend(data, axis=-1, type='linear', bp=0, overwrite_data=False) -> Array:

721

"""Remove linear trend from data."""

722

723

def find_peaks(x, height=None, threshold=None, distance=None, prominence=None,

724

width=None, wlen=None, rel_height=0.5, plateau_size=None) -> tuple[Array, dict]:

725

"""Find peaks in 1D array."""

726

727

def peak_prominences(x, peaks, wlen=None) -> tuple[Array, Array, Array]:

728

"""Calculate peak prominences."""

729

730

def peak_widths(x, peaks, rel_height=0.5, prominence_data=None, wlen=None) -> tuple[Array, Array, Array, Array]:

731

"""Calculate peak widths."""

732

```

733

734

### Other Submodules

735

736

```python { .api }

737

# Each submodule gets its own alias: `jss` is already jax.scipy.special (see Core Imports).

# Fast Fourier Transform (jax.scipy.fft)
import jax.scipy.fft as jfft
# SciPy-style discrete cosine transforms (dct, idct, dctn, idctn);
# general FFTs are in jax.numpy.fft

# N-dimensional image processing (jax.scipy.ndimage)
import jax.scipy.ndimage as jnd
# A small subset of scipy.ndimage, e.g. map_coordinates for interpolation

# Sparse linear algebra (jax.scipy.sparse)
import jax.scipy.sparse as jsparse
# Iterative solvers in jax.scipy.sparse.linalg (cg, gmres, bicgstab);
# sparse array formats live in jax.experimental.sparse

# Interpolation (jax.scipy.interpolate)
import jax.scipy.interpolate as jinterp
# RegularGridInterpolator for data on a regular N-dimensional grid

# Clustering (jax.scipy.cluster)
import jax.scipy.cluster as jsc
# Vector quantization via jax.scipy.cluster.vq.vq (assign observations to a
# given code book); no hierarchical clustering or k-means fitting

# Integration (jax.scipy.integrate)
import jax.scipy.integrate as jint
# Quadrature of sampled data, e.g. trapezoid; for ODE solving see
# jax.experimental.ode.odeint

760

```

761

762

## Usage Examples

763

764

```python

import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.scipy.linalg as jla
import jax.scipy.special as jss
import jax.scipy.stats as jst

# Linear algebra example: a small symmetric positive-definite matrix
A = jnp.array([[4.0, 2.0], [2.0, 3.0]])
b = jnp.array([1.0, 2.0])

# Solve the linear system A @ x = b
x = jla.solve(A, b)

# Eigenvalues and eigenvectors of the symmetric matrix A
eigenvals, eigenvecs = jla.eigh(A)

# Cholesky decomposition. As in SciPy, the default result is the UPPER
# triangular factor U with A = U.T @ U; pass lower=True for L with A = L @ L.T.
U = jla.cholesky(A)
L = jla.cholesky(A, lower=True)

# Special functions. A separate grid name keeps the solution `x` above intact.
grid = jnp.linspace(-3, 3, 100)
erf_vals = jss.erf(grid)
# Gamma has poles at 0, -1, -2, ..., so evaluate it on a strictly positive grid.
pos_grid = jnp.linspace(0.1, 5.0, 100)
gamma_vals = jss.gamma(pos_grid)

# Statistical distributions
data = jnp.array([1.2, 2.3, 1.8, 3.1, 2.7])
log_likelihood = jst.norm.logpdf(data, loc=2.0, scale=1.0).sum()

# Probability density functions
x_vals = jnp.linspace(0, 5, 100)
pdf_vals = jst.gamma.pdf(x_vals, a=2.0, scale=1.0)


# Use in optimization with JAX transformations
@jax.jit
def neg_log_likelihood(params, data):
    """Negative Gaussian log-likelihood of `data` for params = (mu, sigma)."""
    mu, sigma = params
    return -jst.norm.logpdf(data, mu, sigma).sum()


# Gradient with respect to params (the first argument) for maximum likelihood estimation
grad_fn = jax.grad(neg_log_likelihood)
gradients = grad_fn(jnp.array([2.0, 1.0]), data)

```