0
# Wavelet Packet Transform
1
2
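Wavelet packet decomposition splits both the approximation and the detail coefficients at every level, producing a complete tree of subbands. This gives finer frequency localization than the standard DWT and enables adaptive (best-basis) selection. Implementations are provided for 1D, 2D, and nD data.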
3
4
## Capabilities
5
6
### 1D Wavelet Packet Transform
7
8
Complete binary tree decomposition for 1D signals with node-based navigation.
9
10
```python { .api }
11
class WaveletPacket:
    """Root of a one-dimensional wavelet packet tree.

    Every node of the tree, including this root, exposes the attributes
    annotated below; descendants are addressed with path strings (``wp[path]``).
    """

    def __init__(self, data, wavelet, mode: str = 'symmetric', maxlevel=None):
        """
        Build a 1D wavelet packet tree whose root holds ``data``.

        Parameters
        ----------
        data
            Input 1D signal.
        wavelet
            Wavelet specification.
        mode : str, default 'symmetric'
            Signal extension mode.
        maxlevel : optional
            Maximum decomposition level; determined automatically when None.
        """

    # ---- node attributes ----------------------------------------------
    data: np.ndarray  # values held by this node
    path: str  # '' for the root; each level appends 'a' (approximation) or 'd' (detail)
    level: int  # decomposition level of this node
    maxlevel: int  # deepest level the tree may be decomposed to
    mode: str  # signal extension mode
    wavelet: Wavelet  # wavelet object used by the transform

    def decompose(self):
        """Split this node into its approximation ('a') and detail ('d') children."""

    def reconstruct(self, update: bool = False):
        """
        Rebuild this node's data from its child nodes.

        Parameters
        ----------
        update : bool, default False
            When True, the node's data is overwritten with the result.

        Returns
        -------
        Reconstructed data array.
        """
        # NOTE(review): PyWavelets' own WaveletPacket.reconstruct appears to
        # default to update=True; confirm which default this spec intends.

    def __getitem__(self, path: str):
        """
        Look up the node addressed by ``path``.

        Parameters
        ----------
        path : str
            Path string such as '', 'a', 'd', 'aa', 'ad', 'da', 'dd', ...

        Returns
        -------
        The node at ``path``.
        """

    def __setitem__(self, path: str, data):
        """Assign ``data`` to the node addressed by ``path``."""

    def get_level(self, level: int, order: str = 'natural', decompose: bool = True):
        """
        Collect every node at one decomposition level.

        Parameters
        ----------
        level : int
            Decomposition level.
        order : {'natural', 'freq'}, default 'natural'
            Ordering of the returned nodes.
        decompose : bool, default True
            Whether missing nodes are created by decomposing their parents.

        Returns
        -------
        List of the nodes at ``level``.
        """
71
```
72
73
#### Usage Examples
74
75
```python
76
import pywt
77
import numpy as np
78
import matplotlib.pyplot as plt
79
80
# Build a test signal: three tones (8, 32 and 64 Hz) plus Gaussian noise.
np.random.seed(42)
t = np.linspace(0, 1, 1024)
tone_8hz = np.sin(2 * np.pi * 8 * t)
tone_32hz = 0.7 * np.sin(2 * np.pi * 32 * t)
tone_64hz = 0.5 * np.sin(2 * np.pi * 64 * t)
noise = 0.2 * np.random.randn(len(t))
signal = tone_8hz + tone_32hz + tone_64hz + noise

print(f"Signal length: {len(signal)}")

# Root of the packet tree.
wp = pywt.WaveletPacket(signal, 'db4', mode='symmetric', maxlevel=6)
print(f"Wavelet packet created with maxlevel: {wp.maxlevel}")
print(f"Root node path: '{wp.path}', level: {wp.level}")

# Navigate the tree structure.
print("\nTree navigation:")
root_shape = wp.data.shape
print(f"Root data shape: {root_shape}")
print(f"Root node: '{wp.path}' -> {root_shape}")

# First level: indexing by path decomposes the parent when needed.
approx_child, detail_child = wp['a'], wp['d']
print(f"Approximation child: '{approx_child.path}' -> {approx_child.data.shape}")
print(f"Detail child: '{detail_child.path}' -> {detail_child.data.shape}")

# Second level: approximation/detail of each first-level child.
level_2_paths = ['aa', 'ad', 'da', 'dd']
aa_node, ad_node, da_node, dd_node = (wp[p] for p in level_2_paths)

print("Level 2 nodes:")
for path in level_2_paths:
    node = wp[path]
    print(f" '{path}': level {node.level}, shape {node.data.shape}")

# Every node of level 3, in natural and in frequency order.
level_3_nodes = wp.get_level(3, order='natural')
print(f"\nLevel 3 nodes: {len(level_3_nodes)}")
for node in level_3_nodes:
    print(f" Path: '{node.path}', shape: {node.data.shape}")

level_3_freq = wp.get_level(3, order='freq')
print(f"Level 3 (frequency order): {[n.path for n in level_3_freq]}")

# Reconstruct the signal from the packet decomposition.
reconstructed = wp.reconstruct()
reconstruction_error = np.abs(signal - reconstructed).max()
print(f"\nReconstruction error: {reconstruction_error:.2e}")
131
132
# Best basis selection (energy-based)
def calculate_node_energy(node):
    """Return the energy of a node: the sum of its squared data values."""
    squared = node.data ** 2
    return squared.sum()
136
137
def find_best_basis_energy(wp, max_level):
    """Select a basis by comparing each node's energy with its children's.

    The tree is walked depth-first down to ``max_level``. At each internal
    node, its energy is compared with the summed energy returned for its two
    children. The node itself is kept when its energy is at least as large;
    otherwise the bases chosen for its children are kept. Nodes of ``wp``
    are decomposed as a side effect.

    Note: for orthogonal wavelets such as 'db4', a split (approximately)
    preserves total energy, so this criterion is mostly illustrative. An
    additive cost such as entropy is the usual choice.

    Parameters
    ----------
    wp : WaveletPacket
        Root of the tree to search.
    max_level : int
        Deepest level to consider; clamped to ``wp.maxlevel``.

    Returns
    -------
    list of str
        Paths of the selected nodes.
    """
    # Nodes at wp.maxlevel cannot be decomposed further, so never search deeper.
    max_level = min(max_level, wp.maxlevel)
    best_basis = []

    def traverse_node(node):
        if node.level >= max_level:
            best_basis.append(node.path)
            return calculate_node_energy(node)

        # Decompose and evaluate both children first.
        node.decompose()
        approx_child = wp[node.path + 'a']
        detail_child = wp[node.path + 'd']
        child_energy = traverse_node(approx_child) + traverse_node(detail_child)
        current_energy = calculate_node_energy(node)

        if current_energy >= child_energy:
            # Mutate the shared list in place. Rebinding ``best_basis`` here
            # would make it local to traverse_node, and the append above
            # would then raise UnboundLocalError.
            child_prefixes = (node.path + 'a', node.path + 'd')
            best_basis[:] = [p for p in best_basis if not p.startswith(child_prefixes)]
            best_basis.append(node.path)
            return current_energy
        return child_energy

    traverse_node(wp)
    return best_basis
164
165
best_basis = find_best_basis_energy(wp, 4)
print(f"\nBest basis (energy): {best_basis}")

# Visualize the signal and the first three level-2 packets.
fig, axes = plt.subplots(4, 1, figsize=(15, 12))

signal_ax = axes[0]
signal_ax.plot(t, signal)
signal_ax.set_title('Original Signal')
signal_ax.set_ylabel('Amplitude')
signal_ax.grid(True)

level_2_nodes = wp.get_level(2)
for ax, node in zip(axes[1:], level_2_nodes[:3]):
    ax.plot(node.data)
    ax.set_title(f"Node '{node.path}' (Level {node.level})")
    ax.set_ylabel('Amplitude')
    ax.grid(True)

plt.tight_layout()
plt.show()
188
189
# Packet-based denoising
def packet_denoise(wp, threshold_factor=0.1):
    """Denoise by soft-thresholding the deepest level of a wavelet packet tree.

    Each leaf at ``wp.maxlevel`` is soft-thresholded with its own threshold,
    ``threshold_factor * std(node.data)``. The pure-approximation leaf
    ('a' * maxlevel) holds the lowest-frequency content and is left
    untouched, mirroring standard DWT denoising, which keeps the
    approximation coefficients. When maxlevel is 0, that path is the root,
    so the root is never thresholded.

    The leaf data of ``wp`` is modified in place.

    Parameters
    ----------
    wp : WaveletPacket
        Tree to denoise.
    threshold_factor : float, default 0.1
        Multiplier applied to each leaf's standard deviation.

    Returns
    -------
    Signal reconstructed from the thresholded leaves.
    """
    approximation_path = 'a' * wp.maxlevel
    for node in wp.get_level(wp.maxlevel):
        if node.path == approximation_path:
            continue
        threshold = threshold_factor * np.std(node.data)
        node.data = pywt.threshold(node.data, threshold, mode='soft')

    return wp.reconstruct()
202
203
# Create noisy version and denoise
noisy_signal = signal + 0.4 * np.random.randn(len(signal))
wp_noisy = pywt.WaveletPacket(noisy_signal, 'db4', maxlevel=5)
denoised_signal = packet_denoise(wp_noisy, threshold_factor=0.2)

# Baseline: classic DWT denoising. The approximation band is kept, and each
# detail band is soft-thresholded with the same 0.2 * std rule.
approx_coeffs, *detail_coeffs = pywt.wavedec(noisy_signal, 'db4', level=5)
regular_thresh = [approx_coeffs] + [
    pywt.threshold(band, 0.2 * np.std(band), mode='soft') for band in detail_coeffs
]
regular_denoised = pywt.waverec(regular_thresh, 'db4')

print("\nDenoising comparison:")
signal_power = np.var(signal)
packet_snr = 10 * np.log10(signal_power / np.var(signal - denoised_signal))
regular_snr = 10 * np.log10(signal_power / np.var(signal - regular_denoised))
print(f"Wavelet packet denoising SNR: {packet_snr:.2f} dB")
print(f"Regular wavelet denoising SNR: {regular_snr:.2f} dB")
221
```
222
223
### 2D Wavelet Packet Transform
224
225
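Complete quadtree decomposition for 2D data such as images: every node splits into four subbands (approximation, horizontal detail, vertical detail, and diagonal detail).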
226
227
```python { .api }
228
class WaveletPacket2D:
    """Root of a two-dimensional wavelet packet tree (a quadtree)."""

    def __init__(self, data, wavelet, mode: str = 'symmetric', maxlevel=None):
        """
        Build a 2D wavelet packet tree whose root holds ``data``.

        Parameters
        ----------
        data
            Input 2D array.
        wavelet
            Wavelet specification.
        mode : str, default 'symmetric'
            Signal extension mode.
        maxlevel : optional
            Maximum decomposition level.
        """

    # Same interface as WaveletPacket (decompose, reconstruct, __getitem__,
    # __setitem__, get_level), but every node splits into FOUR children, and
    # each level appends ONE path character:
    #   'a' - approximation
    #   'h' - horizontal detail
    #   'v' - vertical detail
    #   'd' - diagonal detail
    # Level 1 paths: 'a', 'h', 'v', 'd'
    # Level 2 paths: 'aa', 'ah', 'av', 'ad', 'ha', ..., 'dd' (4**2 = 16 nodes)
    # NOTE: these tokens follow PyWavelets' Node2D naming.
244
```
245
246
#### Usage Examples
247
248
```python
249
import pywt
250
import numpy as np
251
import matplotlib.pyplot as plt
252
253
# Create test image
size = 256
x, y = np.mgrid[0:size, 0:size]
image = np.zeros((size, size))

# Add geometric patterns
image[64:192, 64:192] = 1.0
image[96:160, 96:160] = 0.5
# Add some texture
image += 0.1 * np.sin(2 * np.pi * x / 32) * np.cos(2 * np.pi * y / 32)

print(f"Image shape: {image.shape}")

# Create 2D wavelet packet
wp2d = pywt.WaveletPacket2D(image, 'db2', maxlevel=4)
print(f"2D Wavelet packet maxlevel: {wp2d.maxlevel}")

# Each 2D node has four children, and each level adds ONE path character:
# 'a' approximation, 'h' horizontal detail, 'v' vertical detail,
# 'd' diagonal detail. Two-character paths such as 'aa' or 'ad' are
# level-2 nodes, not level-1 ones.
print("\n2D packet navigation:")
level_1_paths = ('a', 'h', 'v', 'd')
a_node, h_node, v_node, d_node = (wp2d[p] for p in level_1_paths)

print("Level 1 2D nodes:")
for path in level_1_paths:
    node = wp2d[path]
    print(f" '{path}': level {node.level}, shape {node.data.shape}")

# Perfect reconstruction
reconstructed_2d = wp2d.reconstruct()
print(f"2D reconstruction error: {np.max(np.abs(image - reconstructed_2d)):.2e}")

# Visualize the level-1 subbands (detail magnitudes shown for contrast).
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
panels = [
    (axes[0, 0], image, 'Original Image'),
    (axes[0, 1], reconstructed_2d, 'Reconstructed'),
    (axes[0, 2], a_node.data, "'a' - Approximation"),
    (axes[1, 0], np.abs(h_node.data), "'h' - Horizontal detail"),
    (axes[1, 1], np.abs(v_node.data), "'v' - Vertical detail"),
    (axes[1, 2], np.abs(d_node.data), "'d' - Diagonal detail"),
]
for ax, panel_data, title in panels:
    ax.imshow(panel_data, cmap='gray')
    ax.set_title(title)
    ax.axis('off')

plt.tight_layout()
plt.show()

# 2D packet-based image analysis: level 2 holds 4**2 = 16 subbands.
level_2_nodes = wp2d.get_level(2)
print(f"\nLevel 2 has {len(level_2_nodes)} nodes")

# Analyze energy distribution in frequency subbands
energy_analysis = {node.path: np.sum(node.data**2) for node in level_2_nodes}

total_energy = sum(energy_analysis.values())
print("\nEnergy distribution at level 2:")
for path, energy in sorted(energy_analysis.items()):
    percentage = energy / total_energy * 100
    print(f" '{path}': {percentage:.2f}%")
331
```
332
333
### nD Wavelet Packet Transform
334
335
Wavelet packet transform for n-dimensional data.
336
337
```python { .api }
338
class WaveletPacketND:
    """Root of an n-dimensional wavelet packet tree."""

    def __init__(self, data, wavelet, mode: str = 'symmetric', maxlevel=None, axes=None):
        """
        Build an nD wavelet packet tree whose root holds ``data``.

        Parameters
        ----------
        data
            Input nD array.
        wavelet
            Wavelet specification.
        mode : str, default 'symmetric'
            Signal extension mode.
        maxlevel : optional
            Maximum decomposition level.
        axes : optional
            Axes along which the transform is performed.
        """
350
```
351
352
### Utility Functions
353
354
Helper functions for working with wavelet packet paths and ordering.
355
356
```python { .api }
357
def get_graycode_order(level: int, x: str = 'a', y: str = 'd') -> list:
    """
    Return the wavelet packet node paths of one level in Gray-code order.

    Parameters
    ----------
    level : int
        Decomposition level.
    x : str, default 'a'
        Token for the first branch (approximation).
    y : str, default 'd'
        Token for the second branch (detail).

    Returns
    -------
    list
        Path strings in Gray-code order.
    """
369
```
370
371
### Base Node Classes
372
373
```python { .api }
374
class BaseNode:
    """Base class for wavelet packet nodes."""

    def __init__(self, parent, path, data):
        """
        Initialize a node.

        Parameters
        ----------
        parent
            Parent node.
        path
            Node path string.
        data
            Node data.
        """
        # NOTE(review): PyWavelets' BaseNode.__init__ appears to take
        # (parent, data, node_name); confirm the order documented here.


# The concrete node types derive from BaseNode, as its docstring states.
class Node(BaseNode):
    """1D wavelet packet node."""

    def __init__(self, parent, path, data):
        """Initialize a 1D node."""


class Node2D(BaseNode):
    """2D wavelet packet node."""

    def __init__(self, parent, path, data):
        """Initialize a 2D node."""


class NodeND(BaseNode):
    """nD wavelet packet node."""

    def __init__(self, parent, path, data):
        """Initialize an nD node."""
404
```
405
406
#### Advanced Usage Examples
407
408
```python
409
import pywt
410
import numpy as np
411
412
# Advanced wavelet packet analysis
def analyze_packet_tree(wp, max_level=None):
    """Summarize every level of a wavelet packet tree.

    Levels 0..max_level are inspected; ``max_level`` defaults to
    ``wp.maxlevel``. The result maps each level to a dict with the node
    count ('num_nodes'), node paths ('paths'), data shapes ('shapes'), and
    energies as sums of squares ('energies').
    """
    deepest = wp.maxlevel if max_level is None else max_level

    def summarize(nodes):
        # Per-level statistics, in the order get_level returned the nodes.
        return {
            'num_nodes': len(nodes),
            'paths': [n.path for n in nodes],
            'shapes': [n.data.shape for n in nodes],
            'energies': [np.sum(n.data**2) for n in nodes],
        }

    return {level: summarize(wp.get_level(level)) for level in range(deepest + 1)}
431
432
# Test with complex signal: four consecutive segments, each a single tone
# at a different frequency (freq, start, end).
segments = [(5, 0, 512), (20, 512, 1024), (50, 1024, 1536), (100, 1536, 2048)]
complex_signal = np.zeros(2048)
for freq, start, end in segments:
    t_segment = np.linspace(0, 1, end - start)
    complex_signal[start:end] = np.sin(2 * np.pi * freq * t_segment)

wp_complex = pywt.WaveletPacket(complex_signal, 'db8', maxlevel=6)
tree_analysis = analyze_packet_tree(wp_complex, max_level=4)

print("Wavelet Packet Tree Analysis:")
for level, info in tree_analysis.items():
    print(f"\nLevel {level}:")
    print(f" Number of nodes: {info['num_nodes']}")
    print(f" Paths: {info['paths']}")
    total_energy = sum(info['energies'])
    print(f" Total energy: {total_energy:.2f}")

    if not (total_energy > 0):
        continue
    # Energy distribution: report only components above 5% of the level total.
    for path, energy in zip(info['paths'], info['energies']):
        percentage = energy / total_energy * 100
        if percentage > 5:
            print(f" '{path}': {percentage:.1f}%")
457
458
# Best basis selection using cost function
def entropy_cost(data, total_energy=None):
    """Shannon-entropy cost of ``data`` for best-basis selection.

    The cost is ``-sum(p * log2(p))`` with ``p = data**2 / E``; zero
    entries are skipped, following the 0 * log(0) = 0 convention.

    Parameters
    ----------
    data : array_like
        Coefficients of one node.
    total_energy : float, optional
        Normalizing energy ``E``. It defaults to the energy of ``data``
        itself, which gives the entropy of this node's own energy
        distribution. Pass one value shared by all nodes (e.g. the signal
        energy) to obtain an additive cost; comparing a parent's cost with
        the sum of its children's costs is only meaningful in that case.

    Returns
    -------
    The entropy cost; 0 when the normalizing energy is zero.
    """
    energy = np.asarray(data) ** 2
    norm = np.sum(energy) if total_energy is None else total_energy
    if norm == 0:
        return 0

    p = energy / norm
    p = p[p > 0]  # Avoid log(0)
    return -np.sum(p * np.log2(p))


def find_best_basis_entropy(wp, max_level):
    """Find the best basis of ``wp`` under the additive entropy cost.

    Bottom-up best-basis search: a node replaces its children in the basis
    when its own cost is no larger than the sum of its children's best
    costs. All costs share one normalizing energy (the root's), which keeps
    them additive, so that comparison is valid. Nodes of ``wp`` are
    decomposed as a side effect.

    Parameters
    ----------
    wp : WaveletPacket
        Root of the tree to search.
    max_level : int
        Deepest level to consider; clamped to ``wp.maxlevel``.

    Returns
    -------
    list of str
        Paths of the selected nodes.
    """
    # Nodes at wp.maxlevel cannot be decomposed further.
    max_level = min(max_level, wp.maxlevel)
    reference_energy = np.sum(np.asarray(wp.data) ** 2)

    def node_cost(node):
        return entropy_cost(node.data, total_energy=reference_energy)

    best_basis = []

    def select_basis(node):
        if node.level >= max_level:
            best_basis.append(node.path)
            return node_cost(node)

        # Calculate cost for current node
        current_cost = node_cost(node)

        # Calculate cost for children
        node.decompose()
        child_paths = (node.path + 'a', node.path + 'd')
        child_cost = sum(select_basis(wp[path]) for path in child_paths)

        if current_cost <= child_cost:
            # Remove descendants from the basis (in place) and keep this node.
            best_basis[:] = [path for path in best_basis
                             if not path.startswith(child_paths)]
            best_basis.append(node.path)
            return current_cost
        return child_cost

    select_basis(wp)
    return best_basis
502
503
entropy_basis = find_best_basis_entropy(wp_complex, 5)
basis_size = len(entropy_basis)
print(f"\nBest basis (entropy): {entropy_basis}")
print(f"Number of basis functions: {basis_size}")
506
507
# Adaptive compression using packet transform
def adaptive_compress(wp, compression_ratio=0.1):
    """Compress a signal by keeping only its largest wavelet packet coefficients.

    Only the leaves at ``wp.maxlevel`` are considered. They form a complete
    representation and are what reconstruction reads. Pooling coefficients
    from every level would count the same content several times and skew
    the threshold.

    Parameters
    ----------
    wp : WaveletPacket
        Source tree. Its node data is not changed, although ``get_level``
        may decompose it.
    compression_ratio : float, default 0.1
        Fraction of leaf coefficients to keep. Coefficients tied at the
        threshold magnitude are all kept, so slightly more may survive. A
        ratio that keeps no coefficient yields an all-zero reconstruction.

    Returns
    -------
    Signal reconstructed from the retained coefficients.
    """
    leaves = wp.get_level(wp.maxlevel)
    magnitudes = np.abs(np.concatenate([np.ravel(node.data) for node in leaves]))

    # Find threshold for desired compression ratio
    num_keep = int(len(magnitudes) * compression_ratio)
    if num_keep <= 0:
        # np.sort(...)[-0] would pick the *smallest* magnitude and keep everything.
        threshold = np.inf
    else:
        threshold = np.sort(magnitudes)[-min(num_keep, len(magnitudes))]

    # Fresh tree over the same signal; replace only its leaf coefficients.
    compressed_wp = pywt.WaveletPacket(wp.data, wp.wavelet, mode=wp.mode,
                                       maxlevel=wp.maxlevel)
    for node in leaves:
        compressed_wp[node.path].data = np.where(np.abs(node.data) >= threshold,
                                                 node.data, 0)

    return compressed_wp.reconstruct()
539
540
compressed_signal = adaptive_compress(wp_complex, compression_ratio=0.05)
residual = complex_signal - compressed_signal
compression_error = np.sqrt(np.mean(residual**2))
print("\nCompression (5% coefficients):")
print(f"RMSE: {compression_error:.4f}")
snr_db = 20 * np.log10(np.std(complex_signal) / compression_error)
print(f"SNR: {snr_db:.2f} dB")
545
```
546
547
## Types
548
549
```python { .api }
550
# Wavelet packet path specification: '' is the root, and each level appends
# one token. 1D trees use 'a'/'d' ('a', 'd', 'aa', 'ad', 'da', 'dd', ...);
# 2D trees use 'a'/'h'/'v'/'d' ('a', 'h', 'v', 'd', 'aa', 'ah', ...).
PacketPath = str

# Node ordering accepted by get_level(order=...)
NodeOrder = Literal['natural', 'freq']

# Wavelet packet node types, including the 1D, 2D and nD tree roots
WaveletPacketNode = Union[WaveletPacket, WaveletPacket2D, WaveletPacketND,
                          Node, Node2D, NodeND]
558
```