# Stationary Wavelet Transform

Stationary (undecimated) wavelet transforms for translation-invariant analysis. No downsampling is performed, so at every decomposition level each coefficient array keeps the length of the input and no coefficients are discarded; circularly shifting the input simply shifts the coefficients by the same amount.

## Capabilities

### 1D Stationary Wavelet Transform

Translation-invariant decomposition and reconstruction for one-dimensional signals.

```python { .api }
def swt(data, wavelet, level: int = None, start_level: int = 0, axis: int = -1,
        trim_approx: bool = False, norm: bool = False):
    """
    1D stationary wavelet transform.

    Parameters:
    - data: Input array; its length along `axis` must be divisible by 2**level
    - wavelet: Wavelet specification
    - level: Number of decomposition levels (default: maximum possible,
      see swt_max_level)
    - start_level: Level at which to start the decomposition (default: 0)
    - axis: Axis along which to perform SWT (default: -1)
    - trim_approx: If True, discard the approximation coefficients of every
      level except the coarsest one (see the flat output format below)
    - norm: If True, scale the transform so that the total energy of the
      coefficients equals the energy of `data` (exact for orthogonal wavelets
      when trim_approx=True)

    Returns:
    If trim_approx=False: list of (cA, cD) pairs ordered from the coarsest
    level n to the finest level 1,
    [(cAn, cDn), ..., (cA2, cD2), (cA1, cD1)]
    If trim_approx=True: flat list [cAn, cDn, ..., cD2, cD1]
    Every coefficient array has the same length as `data` along `axis`.
    """

def iswt(coeffs, wavelet, norm: bool = False, axis: int = -1):
    """
    1D inverse stationary wavelet transform.

    Parameters:
    - coeffs: Output of swt in either format: the list of pairs
      [(cAn, cDn), ..., (cA1, cD1)], or the flat list [cAn, cDn, ..., cD1]
      produced with trim_approx=True. Only the coarsest approximation cAn
      is used for reconstruction.
    - wavelet: Wavelet specification matching forward transform
    - norm: Normalization flag matching forward transform
    - axis: Axis along which to perform ISWT (should match forward transform)

    Returns:
    Reconstructed signal with the same shape as the data passed to swt
    """
```

#### Usage Examples

```python
import pywt
import numpy as np
import matplotlib.pyplot as plt

# Create test signal (length must be divisible by 2**level)
n = 1024  # 2**10
t = np.linspace(0, 1, n)
signal = (np.sin(2 * np.pi * 5 * t) +          # Low frequency
          0.5 * np.sin(2 * np.pi * 20 * t) +   # Medium frequency
          0.3 * np.cos(2 * np.pi * 100 * t))   # High frequency
noise = 0.2 * np.random.randn(n)
noisy_signal = signal + noise

# Stationary wavelet transform
level = 6
coeffs = pywt.swt(noisy_signal, 'db4', level=level)
print(f"Number of decomposition levels: {len(coeffs)}")
print(f"Each coefficient array length: {len(coeffs[0][0])}")  # Same as input length

# Access coefficients. swt returns the levels from coarsest to finest:
# coeffs[0] is level `level`, coeffs[-1] is level 1.
for i, (cA, cD) in enumerate(coeffs):
    print(f"Level {level - i}: Approximation shape {cA.shape}, Detail shape {cD.shape}")

# Perfect reconstruction
reconstructed = pywt.iswt(coeffs, 'db4')
print(f"Reconstruction error: {np.max(np.abs(noisy_signal - reconstructed))}")

# SWT preserves signal length at all levels - good for analysis
# Compare with regular DWT
dwt_coeffs = pywt.wavedec(noisy_signal, 'db4', level=level)
print(f"DWT coefficient lengths: {[len(c) for c in dwt_coeffs]}")
print(f"SWT coefficient lengths: {[len(approx) for approx, _ in coeffs]}")

# Translation invariance demonstration
shift = 100
shifted_signal = np.roll(noisy_signal, shift)  # Circular shift
coeffs_shifted = pywt.swt(shifted_signal, 'db4', level=level)

# Compare detail coefficients at level 3. Levels are stored coarsest first,
# so level j lives at index -j.
detail_orig = coeffs[-3][1]             # Detail at level 3
detail_shifted = coeffs_shifted[-3][1]  # Detail at level 3 for shifted signal

# SWT coefficients are circularly shifted by the same amount as the input
is_shifted = np.allclose(detail_shifted, np.roll(detail_orig, shift))
print(f"Level-3 details equal the shifted original details: {is_shifted}")

# Recover the shift from the cross-correlation. np.correlate(a, v) peaks at
# lag k when a[m] ~ v[m - k], so the shifted coefficients go first.
correlation = np.correlate(detail_shifted, detail_orig, mode='full')
shift_detected = np.argmax(correlation) - (len(detail_orig) - 1)
print(f"Detected shift in coefficients: {shift_detected % n}")

# Visualization
fig, axes = plt.subplots(3, 2, figsize=(15, 12))

axes[0, 0].plot(t, signal, 'b-', label='Clean')
axes[0, 0].plot(t, noisy_signal, 'r-', alpha=0.7, label='Noisy')
axes[0, 0].set_title('Original vs Noisy Signal')
axes[0, 0].legend()

axes[0, 1].plot(t, reconstructed, 'g-', label='Reconstructed')
axes[0, 1].plot(t, noisy_signal, 'r--', alpha=0.7, label='Original')
axes[0, 1].set_title('SWT Reconstruction')
axes[0, 1].legend()

# Plot details of the four finest levels. swt returns levels coarsest first,
# so level j is stored at coeffs[-j].
for i in range(4):
    row = (i // 2) + 1
    col = i % 2
    if i < len(coeffs):
        cD = coeffs[-(i + 1)][1]
        axes[row, col].plot(t, cD)
        axes[row, col].set_title(f'Detail Level {i+1}')

plt.tight_layout()
plt.show()
```

### 2D Stationary Wavelet Transform

Translation-invariant decomposition and reconstruction for two-dimensional data.

```python { .api }
def swt2(data, wavelet, level: int, start_level: int = 0, axes=(-2, -1),
         trim_approx: bool = False, norm: bool = False):
    """
    2D stationary wavelet transform.

    Parameters:
    - data: Input array; its size along both `axes` must be divisible by
      2**level
    - wavelet: Wavelet specification
    - level: Number of decomposition levels (required)
    - start_level: Level at which to start the decomposition (default: 0)
    - axes: Pair of axes for 2D transform (default: last two axes)
    - trim_approx: If True, discard the approximation coefficients of every
      level except the coarsest one (see the flat output format below)
    - norm: If True, scale the transform so that the total energy of the
      coefficients equals the energy of `data` (exact for orthogonal wavelets
      when trim_approx=True)

    Returns:
    If trim_approx=False: list ordered from the coarsest level n to the
    finest level 1, [(cAn, (cHn, cVn, cDn)), ..., (cA1, (cH1, cV1, cD1))]
    If trim_approx=True: [cAn, (cHn, cVn, cDn), ..., (cH1, cV1, cD1)]
    Every coefficient array has the same shape as `data`.
    """

def iswt2(coeffs, wavelet, norm: bool = False, axes=(-2, -1)):
    """
    2D inverse stationary wavelet transform.

    Parameters:
    - coeffs: Output of swt2 in either format:
      [(cAn, (cHn, cVn, cDn)), ..., (cA1, (cH1, cV1, cD1))], or the
      trim_approx=True form [cAn, (cHn, cVn, cDn), ..., (cH1, cV1, cD1)]
    - wavelet: Wavelet specification matching forward transform
    - norm: Normalization flag matching forward transform
    - axes: Pair of axes for 2D transform (should match forward transform)

    Returns:
    Reconstructed array with the same shape as the data passed to swt2
    """
```

#### Usage Examples

```python
import pywt
import numpy as np
import matplotlib.pyplot as plt

# Create test image (each dimension must be divisible by 2**level)
size = 256  # 2**8
x, y = np.mgrid[0:size, 0:size]
image = (np.sin(2 * np.pi * x / 64) * np.cos(2 * np.pi * y / 64) +
         0.5 * np.sin(2 * np.pi * x / 16) * np.cos(2 * np.pi * y / 16))
noise = 0.3 * np.random.randn(size, size)
noisy_image = image + noise

print(f"Image shape: {noisy_image.shape}")

# 2D stationary wavelet transform
level = 4
coeffs = pywt.swt2(noisy_image, 'db2', level=level)
print(f"Number of decomposition levels: {len(coeffs)}")

# Each level preserves the original image size. Levels are ordered from the
# coarsest (coeffs[0], level `level`) to the finest (coeffs[-1], level 1).
for i, (cA, (cH, cV, cD)) in enumerate(coeffs):
    print(f"Level {level - i}: All coefficients shape {cA.shape}")

# Perfect reconstruction
reconstructed = pywt.iswt2(coeffs, 'db2')
print(f"2D SWT reconstruction error: {np.max(np.abs(noisy_image - reconstructed))}")

# Translation invariance in 2D: circularly shifting the image circularly
# shifts every coefficient array by the same amount.
row_shift, col_shift = 50, 30
shifted_image = np.roll(noisy_image, (row_shift, col_shift), axis=(0, 1))
coeffs_shifted = pywt.swt2(shifted_image, 'db2', level=level)
cH_fine = coeffs[-1][1][0]
cH_fine_shifted = coeffs_shifted[-1][1][0]
is_shifted = np.allclose(cH_fine_shifted,
                         np.roll(cH_fine, (row_shift, col_shift), axis=(0, 1)))
print(f"Level-1 horizontal details follow the image shift: {is_shifted}")

# Image denoising using SWT
def swt_denoise(image, wavelet, level, threshold):
    """Soft-threshold the 2D SWT detail bands of *image* and reconstruct it.

    The horizontal, vertical and diagonal bands of every level are shrunk
    with the same soft threshold; approximation coefficients pass through
    unchanged.
    """
    def shrink(band):
        return pywt.threshold(band, threshold, mode='soft')

    decomposition = pywt.swt2(image, wavelet, level=level)
    shrunk = [(cA, (shrink(cH), shrink(cV), shrink(cD)))
              for cA, (cH, cV, cD) in decomposition]
    return pywt.iswt2(shrunk, wavelet)

# Apply denoising. The same wavelet, depth and threshold are reused for the
# DWT comparison below so both methods are compared like for like.
denoise_wavelet = 'db4'
denoise_level = 3
denoise_threshold = 0.1
denoised_image = swt_denoise(noisy_image, denoise_wavelet,
                             level=denoise_level, threshold=denoise_threshold)

# Visualization
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

top_row = [(image, 'Original Clean Image'),
           (noisy_image, 'Noisy Image'),
           (denoised_image, 'SWT Denoised')]

# Show coefficients at finest level
cA_1, (cH_1, cV_1, cD_1) = coeffs[-1]  # Level 1 (finest)
bottom_row = [(np.abs(cH_1), 'Horizontal Details (Level 1)'),
              (np.abs(cV_1), 'Vertical Details (Level 1)'),
              (np.abs(cD_1), 'Diagonal Details (Level 1)')]

for r, panels in enumerate((top_row, bottom_row)):
    for c, (panel, title) in enumerate(panels):
        ax = axes[r, c]
        ax.imshow(panel, cmap='gray')
        ax.set_title(title)
        ax.axis('off')

plt.tight_layout()
plt.show()

# Compare with regular DWT for edge preservation: keep the approximation,
# soft-threshold every detail band.
regular_dwt = pywt.wavedec2(noisy_image, denoise_wavelet, level=denoise_level)
regular_dwt_thresh = [regular_dwt[0]] + [
    tuple(pywt.threshold(band, denoise_threshold, mode='soft')
          for band in (cH, cV, cD))
    for cH, cV, cD in regular_dwt[1:]
]

regular_denoised = pywt.waverec2(regular_dwt_thresh, denoise_wavelet)

# Edge detection to compare preservation
def edge_strength(img):
    """Return the mean gradient magnitude of a 2D image (forward differences)."""
    # Column-wise (x) and row-wise (y) forward differences, both cropped to
    # the common (rows - 1, cols - 1) region so they combine element-wise.
    dx = np.diff(img, axis=1)[:-1, :]
    dy = np.diff(img, axis=0)[:, :-1]
    magnitude = np.sqrt(dx ** 2 + dy ** 2)
    return np.mean(magnitude)

for label, img in (("Original", image),
                   ("SWT denoised", denoised_image),
                   ("DWT denoised", regular_denoised)):
    print(f"Edge strength - {label}: {edge_strength(img):.4f}")
```

### nD Stationary Wavelet Transform

Translation-invariant decomposition and reconstruction for n-dimensional data.

```python { .api }
def swtn(data, wavelet, level: int, start_level: int = 0, axes=None,
         trim_approx: bool = False, norm: bool = False):
    """
    nD stationary wavelet transform.

    Parameters:
    - data: Input nD array; its size along every transformed axis must be
      divisible by 2**level
    - wavelet: Wavelet specification
    - level: Number of decomposition levels (required)
    - start_level: Level at which to start the decomposition (default: 0)
    - axes: Axes along which to perform transform (default: all axes)
    - trim_approx: If True, discard the approximation coefficients of every
      level except the coarsest one
    - norm: If True, scale the transform so that the total energy of the
      coefficients equals the energy of `data` (exact for orthogonal wavelets
      when trim_approx=True)

    Returns:
    If trim_approx=False: list of dictionaries ordered from the coarsest
    level n to the finest level 1, [{coeffs_level_n}, ..., {coeffs_level_1}].
    Each key has one character per transformed axis, 'a' (approximation) or
    'd' (detail); e.g. in 2D the keys are 'aa', 'ad', 'da', 'dd'. Every array
    has the same shape as `data`.
    If trim_approx=True: [cAn, {details_level_n}, ..., {details_level_1}],
    where the detail dictionaries omit the all-'a' approximation key.
    """

def iswtn(coeffs, wavelet, axes=None, norm: bool = False):
    """
    nD inverse stationary wavelet transform.

    Parameters:
    - coeffs: Output of swtn in either format: the list of coefficient
      dictionaries [{coeffs_level_n}, ..., {coeffs_level_1}], or the
      trim_approx=True form [cAn, {details_level_n}, ..., {details_level_1}]
    - wavelet: Wavelet specification matching forward transform
    - axes: Axes along which to perform transform (should match forward)
    - norm: Normalization flag matching forward transform

    Returns:
    Reconstructed nD array
    """
```

### Utility Functions

```python { .api }
def swt_max_level(input_len: int) -> int:
    """
    Return the deepest decomposition level usable by the stationary
    wavelet transform.

    Parameters:
    - input_len: Length of the input signal along the transformed axis

    Returns:
    The largest level such that input_len is divisible by 2**level, i.e. how
    many times input_len can be halved exactly (64 -> 6, 1000 -> 3,
    any odd length -> 0)
    """
```

#### Usage Examples

```python
import pywt
import numpy as np

# 3D volume processing with SWT
volume = np.random.randn(64, 64, 64)  # Each dimension must be divisible by 2**level
print(f"3D volume shape: {volume.shape}")

# Check maximum level
max_level = pywt.swt_max_level(64)  # 64 = 2**6, so max level is 6
print(f"Maximum SWT level for size 64: {max_level}")

# 3D SWT
level = 3
coeffs_3d = pywt.swtn(volume, 'haar', level=level)
print(f"Number of 3D SWT levels: {len(coeffs_3d)}")

# Each level has same size as input. Levels are ordered coarsest first, and
# each dict is keyed by one 'a'/'d' character per transformed axis.
for i, coeff_dict in enumerate(coeffs_3d):
    print(f"Level {level - i} coefficients:")
    for key, coeff in coeff_dict.items():
        print(f" '{key}': {coeff.shape}")

# Perfect reconstruction
reconstructed_3d = pywt.iswtn(coeffs_3d, 'haar')
print(f"3D SWT reconstruction error: {np.max(np.abs(volume - reconstructed_3d))}")

# Example: 1D signal length analysis
for length in [128, 256, 512, 1000]:
    max_level = pywt.swt_max_level(length)
    print(f"Length {length}: max SWT level = {max_level}")

# SWT with trimmed approximation: only the coarsest approximation is kept,
# giving the flat list [cA4, cD4, cD3, cD2, cD1] (level + 1 arrays)
signal_1d = np.random.randn(512)
trimmed = pywt.swt(signal_1d, 'db2', level=4, trim_approx=True)
print(f"Trimmed output - number of arrays: {len(trimmed)}")
print(f"Each array shape: {trimmed[0].shape}")

# iswt accepts the trimmed format directly
reconstructed_trimmed = pywt.iswt(trimmed, 'db2')
print("Reconstruction error from trimmed coefficients: "
      f"{np.max(np.abs(signal_1d - reconstructed_trimmed))}")

# Reconstruction from details only: replace the coarsest approximation (the
# first element) with zeros and keep the detail arrays as they are
details_only = [np.zeros_like(trimmed[0])] + trimmed[1:]
reconstructed_details = pywt.iswt(details_only, 'db2')
print(f"Reconstruction from details only shape: {reconstructed_details.shape}")
```

## Types

```python { .api }
from typing import Dict, List, Tuple, Union

import numpy as np

# SWT coefficient formats (trim_approx=False); levels are ordered from the
# coarsest (level n) to the finest (level 1)
SWTCoeffs1D = List[Tuple[np.ndarray, np.ndarray]]  # [(cAn, cDn), ..., (cA1, cD1)]
SWTCoeffs2D = List[Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray, np.ndarray]]]  # [(cAn, (cHn, cVn, cDn)), ..., (cA1, (cH1, cV1, cD1))]
SWTCoeffsND = List[Dict[str, np.ndarray]]  # [{level n}, ..., {level 1}], keys like 'aa', 'ad', 'da', 'dd'

# Trimmed formats (trim_approx=True): the coarsest approximation is kept as the
# first element, followed by the details of each level (coarsest first)
SWTDetails1D = List[np.ndarray]  # [cAn, cDn, ..., cD1]
SWTDetails2D = List[Union[np.ndarray, Tuple[np.ndarray, np.ndarray, np.ndarray]]]  # [cAn, (cHn, cVn, cDn), ..., (cH1, cV1, cD1)]
SWTDetailsND = List[Union[np.ndarray, Dict[str, np.ndarray]]]  # [cAn, {details n}, ..., {details 1}]; dicts omit the approximation key