0
# SciPy Compatibility
1
2
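JAX provides SciPy-compatible functions through `jax.scipy` for scientific computing, including linear algebra, signal processing, special functions, statistics, and sparse operations. These functions are differentiable and can be JIT-compiled. Coverage is a subset of SciPy: not every SciPy function is implemented, and supported functions may accept fewer arguments than their SciPy counterparts. Treat the signatures below as an overview, and check the JAX API reference for exact availability and parameters.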
3
4
## Core Imports
5
6
```python
7
import jax.scipy as jsp
8
import jax.scipy.linalg as jla
9
import jax.scipy.special as jss
10
import jax.scipy.stats as jst
11
```
12
13
## Capabilities
14
15
### Linear Algebra (`jax.scipy.linalg`)
16
17
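Advanced linear algebra operations for matrix computations and decompositions. Not every function listed here exists in `jax.scipy.linalg`. Helpers such as `norm`, `pinv`, `lstsq`, `matrix_rank`, and `slogdet` are provided by `jax.numpy.linalg` instead. As in SciPy, `jax.scipy.linalg.cholesky` defaults to `lower=False` and returns the upper-triangular factor `U` with `a = U.conj().T @ U`. Pass `lower=True` to obtain the lower factor `L` with `a = L @ L.conj().T`.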
18
19
```python { .api }
20
# Matrix decompositions
21
def cholesky(a, lower=True) -> Array:
22
"""
23
Cholesky decomposition of positive definite matrix.
24
25
Args:
26
a: Positive definite matrix to decompose
27
lower: Whether to return lower triangular factor
28
29
Returns:
30
Cholesky factor L such that a = L @ L.T (or U.T @ U if upper)
31
"""
32
33
def qr(a, mode='reduced') -> tuple[Array, Array]:
34
"""
35
QR decomposition of matrix.
36
37
Args:
38
a: Matrix to decompose
39
mode: 'reduced' or 'complete' decomposition
40
41
Returns:
42
Tuple (Q, R) where Q is orthogonal and R is upper triangular
43
"""
44
45
def svd(a, full_matrices=True, compute_uv=True, hermitian=False) -> tuple[Array, Array, Array]:
46
"""
47
Singular Value Decomposition.
48
49
Args:
50
a: Matrix to decompose
51
full_matrices: Whether to compute full or reduced SVD
52
compute_uv: Whether to compute U and V matrices
53
hermitian: Whether matrix is Hermitian
54
55
Returns:
56
Tuple (U, s, Vh) where a = U @ diag(s) @ Vh
57
"""
58
59
def eig(a, b=None, left=False, right=True, overwrite_a=False, overwrite_b=False,
60
check_finite=True, homogeneous_eigvals=False) -> tuple[Array, Array]:
61
"""
62
Eigenvalues and eigenvectors of general matrix.
63
64
Args:
65
a: Square matrix
66
b: Optional matrix for generalized eigenvalue problem
67
left: Whether to compute left eigenvectors
68
right: Whether to compute right eigenvectors
69
overwrite_a: Whether input can be overwritten
70
overwrite_b: Whether b can be overwritten
71
check_finite: Whether to check for finite values
72
homogeneous_eigvals: Whether to return homogeneous eigenvalues
73
74
Returns:
75
Tuple (eigenvalues, eigenvectors)
76
"""
77
78
def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
79
overwrite_b=False, turbo=True, eigvals=None, type=1,
80
check_finite=True) -> tuple[Array, Array]:
81
"""
82
Eigenvalues and eigenvectors of Hermitian matrix.
83
84
Args:
85
a: Hermitian matrix
86
b: Optional matrix for generalized problem
87
lower: Whether to use lower triangle
88
eigvals_only: Whether to compute eigenvalues only
89
overwrite_a: Whether input can be overwritten
90
overwrite_b: Whether b can be overwritten
91
turbo: Whether to use turbo algorithm
92
eigvals: Range of eigenvalue indices to compute
93
type: Type of generalized eigenvalue problem
94
check_finite: Whether to check for finite values
95
96
Returns:
97
Eigenvalues (and eigenvectors if eigvals_only=False)
98
"""
99
100
def eigvals(a, b=None, overwrite_a=False, check_finite=True,
101
homogeneous_eigvals=False) -> Array:
102
"""Eigenvalues of general matrix."""
103
104
def eigvalsh(a, b=None, lower=True, overwrite_a=False, overwrite_b=False,
105
turbo=True, eigvals=None, type=1, check_finite=True) -> Array:
106
"""Eigenvalues of Hermitian matrix."""
107
108
# Matrix properties and functions
109
def det(a) -> Array:
110
"""Matrix determinant."""
111
112
def slogdet(a) -> tuple[Array, Array]:
113
"""Sign and log determinant of matrix."""
114
115
def logdet(a) -> Array:
116
"""Log determinant of matrix."""
117
118
def matrix_rank(M, tol=None, hermitian=False) -> Array:
119
"""Matrix rank computation."""
120
121
def trace(a, offset=0, axis1=0, axis2=1) -> Array:
122
"""Matrix trace."""
123
124
def norm(a, ord=None, axis=None, keepdims=False) -> Array:
125
"""Matrix or vector norm."""
126
127
def cond(x, p=None) -> Array:
128
"""Condition number of matrix."""
129
130
# Matrix solutions
131
def solve(a, b, assume_a='gen', lower=False, overwrite_a=False,
132
overwrite_b=False, debug=None, check_finite=True) -> Array:
133
"""
134
Solve linear system Ax = b.
135
136
Args:
137
a: Coefficient matrix
138
b: Right-hand side vector/matrix
139
assume_a: Properties of matrix a ('gen', 'sym', 'her', 'pos')
140
lower: Whether to use lower triangle for triangular matrices
141
overwrite_a: Whether input can be overwritten
142
overwrite_b: Whether b can be overwritten
143
debug: Debug information level
144
check_finite: Whether to check for finite values
145
146
Returns:
147
Solution x such that Ax = b
148
"""
149
150
def solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False,
151
overwrite_b=False, debug=None, check_finite=True) -> Array:
152
"""Solve triangular linear system."""
153
154
def inv(a, overwrite_a=False, check_finite=True) -> Array:
155
"""Matrix inverse."""
156
157
def pinv(a, rcond=None, hermitian=False, return_rank=False) -> Array:
158
"""Moore-Penrose pseudoinverse."""
159
160
def lstsq(a, b, rcond=None, lapack_driver=None) -> tuple[Array, Array, Array, Array]:
161
"""
162
Least-squares solution to linear system.
163
164
Args:
165
a: Coefficient matrix
166
b: Dependent variable values
167
rcond: Cutoff ratio for small singular values
168
lapack_driver: LAPACK driver to use
169
170
Returns:
171
Tuple (solution, residuals, rank, singular_values)
172
"""
173
174
# Matrix functions
175
def expm(A) -> Array:
176
"""Matrix exponential."""
177
178
def funm(A, func, disp=True) -> Array:
179
"""General matrix function evaluation."""
180
181
def sqrtm(A, disp=True, blocksize=64) -> Array:
182
"""Matrix square root."""
183
184
def logm(A, disp=True) -> Array:
185
"""Matrix logarithm."""
186
187
def fractional_matrix_power(A, t) -> Array:
188
"""Fractional matrix power A^t."""
189
190
def matrix_power(A, n) -> Array:
191
"""Integer matrix power A^n."""
192
193
# Schur decomposition
194
def schur(a, output='real') -> tuple[Array, Array]:
195
"""Schur decomposition of matrix."""
196
197
def rsf2csf(T, Z) -> tuple[Array, Array]:
198
"""Convert real Schur form to complex Schur form."""
199
200
# Polar decomposition
201
def polar(a, side='right') -> tuple[Array, Array]:
202
"""Polar decomposition of matrix."""
203
```
204
205
### Special Functions (`jax.scipy.special`)
206
207
Special mathematical functions including error functions, gamma functions, and Bessel functions.
208
209
```python { .api }
210
# Error functions
211
def erf(z) -> Array:
212
"""Error function."""
213
214
def erfc(x) -> Array:
215
"""Complementary error function."""
216
217
def erfinv(y) -> Array:
218
"""Inverse error function."""
219
220
def erfcinv(y) -> Array:
221
"""Inverse complementary error function."""
222
223
def wofz(z) -> Array:
224
"""Faddeeva function."""
225
226
# Gamma functions
227
def gamma(z) -> Array:
228
"""Gamma function."""
229
230
def gammaln(x) -> Array:
231
"""Log gamma function."""
232
233
def digamma(x) -> Array:
234
"""Digamma (psi) function."""
235
236
def polygamma(n, x) -> Array:
237
"""Polygamma function."""
238
239
def gammainc(a, x) -> Array:
240
"""Lower incomplete gamma function."""
241
242
def gammaincc(a, x) -> Array:
243
"""Upper incomplete gamma function."""
244
245
def gammasgn(x) -> Array:
246
"""Sign of gamma function."""
247
248
def rgamma(x) -> Array:
249
"""Reciprocal gamma function."""
250
251
# Beta functions
252
def beta(a, b) -> Array:
253
"""Beta function."""
254
255
def betaln(a, b) -> Array:
256
"""Log beta function."""
257
258
def betainc(a, b, x) -> Array:
259
"""Incomplete beta function."""
260
261
# Bessel functions
262
def j0(x) -> Array:
263
"""Bessel function of the first kind of order 0."""
264
265
def j1(x) -> Array:
266
"""Bessel function of the first kind of order 1."""
267
268
def jn(n, x) -> Array:
269
"""Bessel function of the first kind of order n."""
270
271
def y0(x) -> Array:
272
"""Bessel function of the second kind of order 0."""
273
274
def y1(x) -> Array:
275
"""Bessel function of the second kind of order 1."""
276
277
def yn(n, x) -> Array:
278
"""Bessel function of the second kind of order n."""
279
280
def i0(x) -> Array:
281
"""Modified Bessel function of the first kind of order 0."""
282
283
def i0e(x) -> Array:
284
"""Exponentially scaled modified Bessel function i0."""
285
286
def i1(x) -> Array:
287
"""Modified Bessel function of the first kind of order 1."""
288
289
def i1e(x) -> Array:
290
"""Exponentially scaled modified Bessel function i1."""
291
292
def iv(v, z) -> Array:
293
"""Modified Bessel function of the first kind of real order."""
294
295
def k0(x) -> Array:
296
"""Modified Bessel function of the second kind of order 0."""
297
298
def k0e(x) -> Array:
299
"""Exponentially scaled modified Bessel function k0."""
300
301
def k1(x) -> Array:
302
"""Modified Bessel function of the second kind of order 1."""
303
304
def k1e(x) -> Array:
305
"""Exponentially scaled modified Bessel function k1."""
306
307
def kv(v, z) -> Array:
308
"""Modified Bessel function of the second kind of real order."""
309
310
# Exponential integrals
311
def expi(x) -> Array:
312
"""Exponential integral Ei."""
313
314
def expn(n, x) -> Array:
315
"""Generalized exponential integral."""
316
317
# Log-sum-exp and related
318
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False) -> Array:
319
"""
320
Compute log(sum(exp(a))) in numerically stable way.
321
322
Args:
323
a: Input array
324
axis: Axis to sum over
325
b: Multiplier for each element
326
keepdims: Whether to keep reduced dimensions
327
return_sign: Whether to return sign separately
328
329
Returns:
330
Log-sum-exp result
331
"""
332
333
def softmax(x, axis=None) -> Array:
334
"""Softmax function."""
335
336
def log_softmax(x, axis=None) -> Array:
337
"""Log softmax function."""
338
339
# Combinatorial functions
340
def factorial(n, exact=False) -> Array:
341
"""Factorial function."""
342
343
def factorial2(n, exact=False) -> Array:
344
"""Double factorial function."""
345
346
def factorialk(n, k, exact=False) -> Array:
347
"""Multifactorial function."""
348
349
def comb(N, k, exact=False, repetition=False) -> Array:
350
"""Binomial coefficient."""
351
352
def perm(N, k, exact=False) -> Array:
353
"""Permutation coefficient."""
354
355
# Elliptic integrals
356
def ellipk(m) -> Array:
357
"""Complete elliptic integral of the first kind."""
358
359
def ellipe(m) -> Array:
360
"""Complete elliptic integral of the second kind."""
361
362
def ellipkinc(phi, m) -> Array:
363
"""Incomplete elliptic integral of the first kind."""
364
365
def ellipeinc(phi, m) -> Array:
366
"""Incomplete elliptic integral of the second kind."""
367
368
# Zeta and related functions
369
def zeta(x, q=None) -> Array:
370
"""Riemann or Hurwitz zeta function."""
371
372
def zetac(x) -> Array:
373
"""Riemann zeta function minus 1."""
374
375
# Hypergeometric functions
376
def hyp1f1(a, b, x) -> Array:
377
"""Confluent hypergeometric function 1F1."""
378
379
def hyp2f1(a, b, c, z) -> Array:
380
"""Gaussian hypergeometric function 2F1."""
381
382
def hyperu(a, b, x) -> Array:
383
"""Confluent hypergeometric function U."""
384
385
# Legendre functions
386
def legendre(n, x) -> Array:
387
"""Legendre polynomial."""
388
389
def lpmv(m, v, x) -> Array:
390
"""Associated Legendre function."""
391
392
# Spherical functions
393
def sph_harm(m, n, theta, phi) -> Array:
394
"""Spherical harmonics."""
395
396
# Other special functions
397
def lambertw(z, k=0, tol=1e-8) -> Array:
398
"""Lambert W function."""
399
400
def spence(z) -> Array:
401
"""Spence function."""
402
403
def multigammaln(a, d) -> Array:
404
"""Log of multivariate gamma function."""
405
406
def entr(x) -> Array:
407
"""Elementwise function -x*log(x)."""
408
409
def kl_div(x, y) -> Array:
410
"""Elementwise function x*log(x/y) - x + y."""
411
412
def rel_entr(x, y) -> Array:
413
"""Elementwise function x*log(x/y)."""
414
415
def huber(delta, r) -> Array:
416
"""Huber loss function."""
417
418
def pseudo_huber(delta, r) -> Array:
419
"""Pseudo-Huber loss function."""
420
```
421
422
### Statistics (`jax.scipy.stats`)
423
424
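Probability distributions, exposed as functions such as `pdf`/`logpdf`, `pmf`/`logpmf`, and `cdf`, plus a small set of statistical utilities such as `mode` and `rankdata`. Not every distribution method or statistical test from `scipy.stats` is available.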
425
426
```python { .api }
427
# Continuous distributions
428
class norm:
429
"""Normal distribution."""
430
@staticmethod
431
def pdf(x, loc=0, scale=1) -> Array: ...
432
@staticmethod
433
def logpdf(x, loc=0, scale=1) -> Array: ...
434
@staticmethod
435
def cdf(x, loc=0, scale=1) -> Array: ...
436
@staticmethod
437
def logcdf(x, loc=0, scale=1) -> Array: ...
438
@staticmethod
439
def sf(x, loc=0, scale=1) -> Array: ...
440
@staticmethod
441
def logsf(x, loc=0, scale=1) -> Array: ...
442
@staticmethod
443
def ppf(q, loc=0, scale=1) -> Array: ...
444
@staticmethod
445
def isf(q, loc=0, scale=1) -> Array: ...
446
447
class multivariate_normal:
448
"""Multivariate normal distribution."""
449
@staticmethod
450
def pdf(x, mean=None, cov=1, allow_singular=False) -> Array: ...
451
@staticmethod
452
def logpdf(x, mean=None, cov=1, allow_singular=False) -> Array: ...
453
454
class uniform:
455
"""Uniform distribution."""
456
@staticmethod
457
def pdf(x, loc=0, scale=1) -> Array: ...
458
@staticmethod
459
def logpdf(x, loc=0, scale=1) -> Array: ...
460
@staticmethod
461
def cdf(x, loc=0, scale=1) -> Array: ...
462
@staticmethod
463
def logcdf(x, loc=0, scale=1) -> Array: ...
464
@staticmethod
465
def sf(x, loc=0, scale=1) -> Array: ...
466
@staticmethod
467
def logsf(x, loc=0, scale=1) -> Array: ...
468
@staticmethod
469
def ppf(q, loc=0, scale=1) -> Array: ...
470
471
class beta:
472
"""Beta distribution."""
473
@staticmethod
474
def pdf(x, a, b, loc=0, scale=1) -> Array: ...
475
@staticmethod
476
def logpdf(x, a, b, loc=0, scale=1) -> Array: ...
477
@staticmethod
478
def cdf(x, a, b, loc=0, scale=1) -> Array: ...
479
480
class gamma:
481
"""Gamma distribution."""
482
@staticmethod
483
def pdf(x, a, loc=0, scale=1) -> Array: ...
484
@staticmethod
485
def logpdf(x, a, loc=0, scale=1) -> Array: ...
486
@staticmethod
487
def cdf(x, a, loc=0, scale=1) -> Array: ...
488
489
class chi2:
490
"""Chi-square distribution."""
491
@staticmethod
492
def pdf(x, df, loc=0, scale=1) -> Array: ...
493
@staticmethod
494
def logpdf(x, df, loc=0, scale=1) -> Array: ...
495
@staticmethod
496
def cdf(x, df, loc=0, scale=1) -> Array: ...
497
498
class t:
499
"""Student's t-distribution."""
500
@staticmethod
501
def pdf(x, df, loc=0, scale=1) -> Array: ...
502
@staticmethod
503
def logpdf(x, df, loc=0, scale=1) -> Array: ...
504
@staticmethod
505
def cdf(x, df, loc=0, scale=1) -> Array: ...
506
507
class f:
508
"""F-distribution."""
509
@staticmethod
510
def pdf(x, dfn, dfd, loc=0, scale=1) -> Array: ...
511
@staticmethod
512
def logpdf(x, dfn, dfd, loc=0, scale=1) -> Array: ...
513
@staticmethod
514
def cdf(x, dfn, dfd, loc=0, scale=1) -> Array: ...
515
516
class laplace:
517
"""Laplace distribution."""
518
@staticmethod
519
def pdf(x, loc=0, scale=1) -> Array: ...
520
@staticmethod
521
def logpdf(x, loc=0, scale=1) -> Array: ...
522
@staticmethod
523
def cdf(x, loc=0, scale=1) -> Array: ...
524
525
class logistic:
526
"""Logistic distribution."""
527
@staticmethod
528
def pdf(x, loc=0, scale=1) -> Array: ...
529
@staticmethod
530
def logpdf(x, loc=0, scale=1) -> Array: ...
531
@staticmethod
532
def cdf(x, loc=0, scale=1) -> Array: ...
533
534
class pareto:
535
"""Pareto distribution."""
536
@staticmethod
537
def pdf(x, b, loc=0, scale=1) -> Array: ...
538
@staticmethod
539
def logpdf(x, b, loc=0, scale=1) -> Array: ...
540
@staticmethod
541
def cdf(x, b, loc=0, scale=1) -> Array: ...
542
543
class expon:
544
"""Exponential distribution."""
545
@staticmethod
546
def pdf(x, loc=0, scale=1) -> Array: ...
547
@staticmethod
548
def logpdf(x, loc=0, scale=1) -> Array: ...
549
@staticmethod
550
def cdf(x, loc=0, scale=1) -> Array: ...
551
552
class lognorm:
553
"""Log-normal distribution."""
554
@staticmethod
555
def pdf(x, s, loc=0, scale=1) -> Array: ...
556
@staticmethod
557
def logpdf(x, s, loc=0, scale=1) -> Array: ...
558
@staticmethod
559
def cdf(x, s, loc=0, scale=1) -> Array: ...
560
561
class truncnorm:
562
"""Truncated normal distribution."""
563
@staticmethod
564
def pdf(x, a, b, loc=0, scale=1) -> Array: ...
565
@staticmethod
566
def logpdf(x, a, b, loc=0, scale=1) -> Array: ...
567
@staticmethod
568
def cdf(x, a, b, loc=0, scale=1) -> Array: ...
569
570
# Discrete distributions
571
class bernoulli:
572
"""Bernoulli distribution."""
573
@staticmethod
574
def pmf(k, p, loc=0) -> Array: ...
575
@staticmethod
576
def logpmf(k, p, loc=0) -> Array: ...
577
@staticmethod
578
def cdf(k, p, loc=0) -> Array: ...
579
580
class binom:
581
"""Binomial distribution."""
582
@staticmethod
583
def pmf(k, n, p, loc=0) -> Array: ...
584
@staticmethod
585
def logpmf(k, n, p, loc=0) -> Array: ...
586
@staticmethod
587
def cdf(k, n, p, loc=0) -> Array: ...
588
589
class geom:
590
"""Geometric distribution."""
591
@staticmethod
592
def pmf(k, p, loc=0) -> Array: ...
593
@staticmethod
594
def logpmf(k, p, loc=0) -> Array: ...
595
@staticmethod
596
def cdf(k, p, loc=0) -> Array: ...
597
598
class nbinom:
599
"""Negative binomial distribution."""
600
@staticmethod
601
def pmf(k, n, p, loc=0) -> Array: ...
602
@staticmethod
603
def logpmf(k, n, p, loc=0) -> Array: ...
604
@staticmethod
605
def cdf(k, n, p, loc=0) -> Array: ...
606
607
class poisson:
608
"""Poisson distribution."""
609
@staticmethod
610
def pmf(k, mu, loc=0) -> Array: ...
611
@staticmethod
612
def logpmf(k, mu, loc=0) -> Array: ...
613
@staticmethod
614
def cdf(k, mu, loc=0) -> Array: ...
615
616
# Statistical functions
617
def mode(a, axis=0, nan_policy='propagate', keepdims=False) -> Array:
618
"""Mode of array values along axis."""
619
620
def rankdata(a, method='average', axis=None) -> Array:
621
"""Rank data along axis."""
622
623
def kendalltau(x, y, initial_lexsort=None, nan_policy='propagate', method='auto') -> tuple[Array, Array]:
624
"""Kendall's tau correlation coefficient."""
625
626
def pearsonr(x, y) -> tuple[Array, Array]:
627
"""Pearson correlation coefficient."""
628
629
def spearmanr(a, b=None, axis=0, nan_policy='propagate', alternative='two-sided') -> tuple[Array, Array]:
630
"""Spearman correlation coefficient."""
631
```
632
633
### Signal Processing (`jax.scipy.signal`)
634
635
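Signal processing functions for convolution, correlation, detrending, and spectral analysis (for example `convolve`, `fftconvolve`, `correlate`, `welch`, `csd`, `stft`, and `istft`). Several signatures listed below, such as the IIR filtering, resampling, and peak-finding functions, have no `jax.scipy.signal` equivalent. Verify availability in the JAX API reference before relying on them.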
636
637
```python { .api }
638
def convolve(in1, in2, mode='full', method='auto') -> Array:
639
"""N-dimensional convolution."""
640
641
def convolve2d(in1, in2, mode='full', boundary='fill', fillvalue=0) -> Array:
642
"""2D convolution."""
643
644
def correlate(in1, in2, mode='full', method='auto') -> Array:
645
"""Cross-correlation of two arrays."""
646
647
def correlate2d(in1, in2, mode='full', boundary='fill', fillvalue=0) -> Array:
648
"""2D cross-correlation."""
649
650
def fftconvolve(in1, in2, mode='full', axes=None) -> Array:
651
"""FFT-based convolution."""
652
653
def oaconvolve(in1, in2, mode='full', axes=None) -> Array:
654
"""Overlap-add convolution."""
655
656
def lfilter(b, a, x, axis=-1, zi=None) -> Array:
657
"""Linear digital filter."""
658
659
def filtfilt(b, a, x, axis=-1, padtype='odd', padlen=None, method='pad', irlen=None) -> Array:
660
"""Zero-phase digital filtering."""
661
662
def sosfilt(sos, x, axis=-1, zi=None) -> Array:
663
"""Filter using second-order sections."""
664
665
def sosfiltfilt(sos, x, axis=-1, padtype='odd', padlen=None) -> Array:
666
"""Zero-phase filtering with second-order sections."""
667
668
def hilbert(x, N=None, axis=-1) -> Array:
669
"""Hilbert transform."""
670
671
def hilbert2(x, N=None) -> Array:
672
"""2D Hilbert transform."""
673
674
def decimate(x, q, n=None, ftype='iir', axis=-1, zero_phase=True) -> Array:
675
"""Downsample signal by integer factor."""
676
677
def resample(x, num, t=None, axis=0, window=None, domain='time') -> Array:
678
"""Resample signal to new sample rate."""
679
680
def resample_poly(x, up, down, axis=0, window='kaiser', padtype='constant', cval=None) -> Array:
681
"""Resample using polyphase filtering."""
682
683
def upfirdn(h, x, up=1, down=1, axis=-1, mode='constant', cval=0) -> Array:
684
"""Upsample, FIR filter, and downsample."""
685
686
def periodogram(x, fs=1.0, window='boxcar', nfft=None, detrend='constant',
687
return_onesided=True, scaling='density', axis=-1) -> tuple[Array, Array]:
688
"""Periodogram power spectral density."""
689
690
def welch(x, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
691
detrend='constant', return_onesided=True, scaling='density', axis=-1,
692
average='mean') -> tuple[Array, Array]:
693
"""Welch's method for power spectral density."""
694
695
def csd(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
696
detrend='constant', return_onesided=True, scaling='density', axis=-1,
697
average='mean') -> tuple[Array, Array]:
698
"""Cross power spectral density."""
699
700
def coherence(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
701
detrend='constant', axis=-1) -> tuple[Array, Array]:
702
"""Coherence between signals."""
703
704
def spectrogram(x, fs=1.0, window='tukey', nperseg=None, noverlap=None, nfft=None,
705
detrend='constant', return_onesided=True, scaling='density', axis=-1,
706
mode='psd') -> tuple[Array, Array, Array]:
707
"""Spectrogram using short-time Fourier transform."""
708
709
def stft(x, fs=1.0, window='hann', nperseg=256, noverlap=None, nfft=None,
710
detrend=False, return_onesided=True, boundary='zeros', padded=True, axis=-1) -> tuple[Array, Array, Array]:
711
"""Short-time Fourier transform."""
712
713
def istft(Zxx, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
714
input_onesided=True, boundary=True, time_axis=-1, freq_axis=-2) -> tuple[Array, Array]:
715
"""Inverse short-time Fourier transform."""
716
717
def lombscargle(x, y, freqs, precenter=False, normalize=False) -> Array:
718
"""Lomb-Scargle periodogram."""
719
720
def detrend(data, axis=-1, type='linear', bp=0, overwrite_data=False) -> Array:
721
"""Remove linear trend from data."""
722
723
def find_peaks(x, height=None, threshold=None, distance=None, prominence=None,
724
width=None, wlen=None, rel_height=0.5, plateau_size=None) -> tuple[Array, dict]:
725
"""Find peaks in 1D array."""
726
727
def peak_prominences(x, peaks, wlen=None) -> tuple[Array, Array, Array]:
728
"""Calculate peak prominences."""
729
730
def peak_widths(x, peaks, rel_height=0.5, prominence_data=None, wlen=None) -> tuple[Array, Array, Array, Array]:
731
"""Calculate peak widths."""
732
```
733
734
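### Other Submodules

These submodules are much narrower than their SciPy counterparts. For example, `jax.scipy.integrate` provides trapezoidal integration rather than an ODE solver. Consult the JAX API reference before relying on a specific function. The aliases in the listing below are illustrative only: `jss` is also used for `jax.scipy.special`, and `jsi` appears for both `interpolate` and `integrate`. Choose distinct aliases in real code.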
735
736
```python { .api }
737
# Fast Fourier Transform (jax.scipy.fft)
738
import jax.scipy.fft as jfft
739
# Same interface as jax.numpy.fft with additional functions
740
741
# N-dimensional image processing (jax.scipy.ndimage)
742
import jax.scipy.ndimage as jnd
743
# Image filtering, morphology, and measurements
744
745
# Sparse matrix operations (jax.scipy.sparse)
746
import jax.scipy.sparse as jss
747
# Sparse matrix formats and operations
748
749
# Interpolation (jax.scipy.interpolate)
750
import jax.scipy.interpolate as jsi
751
# 1D and multidimensional interpolation
752
753
# Clustering (jax.scipy.cluster)
754
import jax.scipy.cluster as jsc
755
# Hierarchical and k-means clustering
756
757
# Integration and ODE solving (jax.scipy.integrate)
758
import jax.scipy.integrate as jsi
759
# Numerical integration and differential equation solving
760
```
761
762
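## Usage Examples

The example below uses `jax.jit` and `jax.grad`, so it also needs `import jax` in addition to the imports shown. `jla.cholesky(A)` follows SciPy's default of `lower=False` and returns the upper-triangular factor. Call `jla.cholesky(A, lower=True)` if you want `L` such that `A = L @ L.T`.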
763
764
```python
765
import jax.numpy as jnp
766
import jax.scipy as jsp
767
import jax.scipy.linalg as jla
768
import jax.scipy.special as jss
769
import jax.scipy.stats as jst
770
771
# Linear algebra example
772
A = jnp.array([[4.0, 2.0], [2.0, 3.0]])
773
b = jnp.array([1.0, 2.0])
774
775
# Solve linear system
776
x = jla.solve(A, b)
777
778
# Compute eigenvalues and eigenvectors
779
eigenvals, eigenvecs = jla.eigh(A)
780
781
# Matrix decomposition
782
L = jla.cholesky(A) # A = L @ L.T
783
784
# Special functions
785
x = jnp.linspace(-3, 3, 100)
786
erf_vals = jss.erf(x)
787
gamma_vals = jss.gamma(x + 1)
788
789
# Statistical distributions
790
data = jnp.array([1.2, 2.3, 1.8, 3.1, 2.7])
791
log_likelihood = jst.norm.logpdf(data, loc=2.0, scale=1.0).sum()
792
793
# Probability density functions
794
x_vals = jnp.linspace(0, 5, 100)
795
pdf_vals = jst.gamma.pdf(x_vals, a=2.0, scale=1.0)
796
797
# Use in optimization with JAX transformations
798
@jax.jit
799
def neg_log_likelihood(params, data):
800
mu, sigma = params
801
return -jst.norm.logpdf(data, mu, sigma).sum()
802
803
# Compute gradient for maximum likelihood estimation
804
grad_fn = jax.grad(neg_log_likelihood)
805
gradients = grad_fn([2.0, 1.0], data)
806
```