Discrete and continuous wavelet transforms for signal and image processing with comprehensive 1D, 2D, and nD transform support.
—
Quality
Pending
Does it follow best practices?
Impact
Pending
No eval scenarios have been run
Wavelet packet decomposition providing complete binary tree analysis for detailed frequency localization and adaptive basis selection with both 1D and 2D implementations.
Complete binary tree decomposition for 1D signals with node-based navigation.
class WaveletPacket:
    """Reference stub for the 1D wavelet packet tree (signatures only, no behavior)."""

    def __init__(self, data, wavelet, mode: str = 'symmetric', maxlevel=None):
        """
        1D wavelet packet tree.

        Parameters:
        - data: Input 1D signal
        - wavelet: Wavelet specification
        - mode: Signal extension mode
        - maxlevel: Maximum decomposition level (default: automatic)
        """

    # Properties
    # The annotations are quoted so this stub can be parsed and imported
    # without numpy or the Wavelet class being in scope.
    data: "np.ndarray"  # Node data
    path: str  # Node path ('' for root, 'a' for approximation child, 'd' for detail child)
    level: int  # Decomposition level
    maxlevel: int  # Maximum allowed level
    mode: str  # Signal extension mode
    wavelet: "Wavelet"  # Wavelet object

    def decompose(self):
        """Decompose current node into approximation and detail children."""

    def reconstruct(self, update: bool = False):
        """
        Reconstruct node data from child nodes.

        Parameters:
        - update: If True, update node data with reconstructed values

        Returns:
        Reconstructed data array
        """
        # NOTE(review): PyWavelets appears to default `update` to True on the
        # tree object itself and to False on plain nodes - confirm against the
        # library documentation before relying on the default shown here.

    def __getitem__(self, path: str):
        """
        Get node at specified path.

        Parameters:
        - path: Path string ('', 'a', 'd', 'aa', 'ad', 'da', 'dd', etc.)

        Returns:
        Node at specified path
        """

    def __setitem__(self, path: str, data):
        """Set data for node at specified path."""

    def get_level(self, level: int, order: str = 'natural', decompose: bool = True):
        """
        Get all nodes at specified level.

        Parameters:
        - level: Decomposition level
        - order: Node ordering ('natural' or 'freq')
        - decompose: Whether to decompose nodes if needed

        Returns:
        List of nodes at specified level
        """


import pywt
import numpy as np
import matplotlib.pyplot as plt
# Create test signal with multiple frequency components
np.random.seed(42)  # fixed seed so the noise term is reproducible between runs
t = np.linspace(0, 1, 1024)
signal = (np.sin(2 * np.pi * 8 * t) +          # 8 Hz component
          0.7 * np.sin(2 * np.pi * 32 * t) +   # 32 Hz component
          0.5 * np.sin(2 * np.pi * 64 * t) +   # 64 Hz component
          0.2 * np.random.randn(len(t)))       # Noise
print(f"Signal length: {len(signal)}")

# Create wavelet packet tree
wp = pywt.WaveletPacket(signal, 'db4', mode='symmetric', maxlevel=6)
print(f"Wavelet packet created with maxlevel: {wp.maxlevel}")
print(f"Root node path: '{wp.path}', level: {wp.level}")

# Navigate the tree structure
print("\nTree navigation:")
print(f"Root data shape: {wp.data.shape}")
print(f"Root node: '{wp.path}' -> {wp.data.shape}")

# Access child nodes (automatically decomposes if needed)
approx_child = wp['a']  # Approximation child
detail_child = wp['d']  # Detail child
print(f"Approximation child: '{approx_child.path}' -> {approx_child.data.shape}")
print(f"Detail child: '{detail_child.path}' -> {detail_child.data.shape}")

# Go deeper into the tree: each extra character descends one level
level_2_paths = (
    'aa',  # Approximation of approximation
    'ad',  # Detail of approximation
    'da',  # Approximation of detail
    'dd',  # Detail of detail
)
print("Level 2 nodes:")
for path in level_2_paths:
    node = wp[path]
    print(f" '{path}': level {node.level}, shape {node.data.shape}")

# Get all nodes at a specific level
level_3_nodes = wp.get_level(3, order='natural')
print(f"\nLevel 3 nodes: {len(level_3_nodes)}")
for node in level_3_nodes:
    print(f" Path: '{node.path}', shape: {node.data.shape}")

# Frequency ordering
level_3_freq = wp.get_level(3, order='freq')
print(f"Level 3 (frequency order): {[node.path for node in level_3_freq]}")

# Reconstruct signal from packet decomposition
reconstructed = wp.reconstruct()
reconstruction_error = np.max(np.abs(signal - reconstructed))
print(f"\nReconstruction error: {reconstruction_error:.2e}")
# Best basis selection (energy-based)
def calculate_node_energy(node):
    """Calculate energy of a node (sum of its squared coefficients)."""
    return np.sum(node.data**2)


def find_best_basis_energy(wp, max_level):
    """
    Find best basis based on energy concentration.

    Walks the tree depth-first from `wp`. A node is kept in place of its
    descendants when its own energy is at least the energy returned for its
    two children.

    Parameters:
    - wp: Root of the packet tree; must support `wp[path]` and expose
      `level`, `path`, `data` and `decompose()` on every node
    - max_level: Depth at which nodes are accepted without further splitting

    Returns:
    List of path strings of the selected nodes
    """
    best_basis = []

    def traverse_node(node):
        if node.level >= max_level:
            best_basis.append(node.path)
            return calculate_node_energy(node)

        # Decompose and calculate energy for children
        node.decompose()
        child_paths = (node.path + 'a', node.path + 'd')
        child_energy = sum(traverse_node(wp[path]) for path in child_paths)
        current_energy = calculate_node_energy(node)
        if current_energy >= child_energy:
            # Current node has better energy concentration: drop everything
            # selected beneath it. Slice assignment mutates the shared list;
            # a plain `best_basis = ...` would make the name local to this
            # function and the append above would raise UnboundLocalError.
            best_basis[:] = [path for path in best_basis
                             if not path.startswith(child_paths)]
            best_basis.append(node.path)
            return current_energy
        return child_energy

    traverse_node(wp)
    return best_basis
best_basis = find_best_basis_energy(wp, 4)
print(f"\nBest basis (energy): {best_basis}")

# Visualize wavelet packet decomposition
fig, axes = plt.subplots(4, 1, figsize=(15, 12))

# Original signal goes in the top panel
signal_axis = axes[0]
signal_axis.plot(t, signal)
signal_axis.set_title('Original Signal')
signal_axis.set_ylabel('Amplitude')
signal_axis.grid(True)

# Level 2 decomposition: the first 3 nodes fill the three remaining panels
level_2_nodes = wp.get_level(2)
for panel, node in zip(axes[1:], level_2_nodes[:3]):
    panel.plot(node.data)
    panel.set_title(f"Node '{node.path}' (Level {node.level})")
    panel.set_ylabel('Amplitude')
    panel.grid(True)

plt.tight_layout()
plt.show()
# Packet-based denoising
def packet_denoise(wp, threshold_factor=0.1):
    """
    Denoise using wavelet packet thresholding.

    Every node at `wp.maxlevel` is soft-thresholded in place with a cutoff of
    `threshold_factor` times the standard deviation of that node's data, so
    the tree passed in is modified.

    Parameters:
    - wp: Wavelet packet tree to threshold
    - threshold_factor: Multiplier applied to each node's standard deviation

    Returns:
    Result of `wp.reconstruct()` after thresholding
    """
    for leaf in wp.get_level(wp.maxlevel):
        if leaf.path == '':
            continue  # Don't threshold root
        # Adaptive threshold based on this node's own statistics
        cutoff = threshold_factor * np.std(leaf.data)
        leaf.data = pywt.threshold(leaf.data, cutoff, mode='soft')
    return wp.reconstruct()
# Create noisy version and denoise
noisy_signal = signal + 0.4 * np.random.randn(len(signal))
wp_noisy = pywt.WaveletPacket(noisy_signal, 'db4', maxlevel=5)
denoised_signal = packet_denoise(wp_noisy, threshold_factor=0.2)

# Compare denoising methods: plain multilevel DWT with the same threshold rule
regular_coeffs = pywt.wavedec(noisy_signal, 'db4', level=5)
approximation, *details = regular_coeffs
# Keep approximation untouched, soft-threshold every detail band
regular_thresh = [approximation] + [
    pywt.threshold(band, 0.2 * np.std(band), mode='soft') for band in details
]
regular_denoised = pywt.waverec(regular_thresh, 'db4')

print(f"\nDenoising comparison:")
reference_power = np.var(signal)
packet_snr = 10 * np.log10(reference_power / np.var(signal - denoised_signal))
regular_snr = 10 * np.log10(reference_power / np.var(signal - regular_denoised))
print(f"Wavelet packet denoising SNR: {packet_snr:.2f} dB")
print(f"Regular wavelet denoising SNR: {regular_snr:.2f} dB")
Complete binary tree decomposition for 2D data such as images.
class WaveletPacket2D:
    """Reference stub for the 2D wavelet packet tree (signatures only, no behavior)."""

    def __init__(self, data, wavelet, mode: str = 'symmetric', maxlevel=None):
        """
        2D wavelet packet tree.

        Parameters:
        - data: Input 2D array
        - wavelet: Wavelet specification
        - mode: Signal extension mode
        - maxlevel: Maximum decomposition level
        """

    # Similar interface to WaveletPacket but for 2D data.
    # Each decomposition level adds ONE character to the path:
    #   'a' approximation, 'h' horizontal detail,
    #   'v' vertical detail, 'd' diagonal detail
    # e.g., 'a', 'h', 'v', 'd' for level 1
    #       'aa', 'ah', 'av', 'ad', 'ha', ..., 'dd' (16 nodes) for level 2
    # NOTE(review): PyWavelets appears to document mode='smooth' as the
    # default for the 2D tree - confirm against the library documentation.


import pywt
import numpy as np
import matplotlib.pyplot as plt
# Create test image
size = 256
x, y = np.mgrid[0:size, 0:size]
image = np.zeros((size, size))
# Add geometric patterns
image[64:192, 64:192] = 1.0
image[96:160, 96:160] = 0.5
# Add some texture
image += 0.1 * np.sin(2 * np.pi * x / 32) * np.cos(2 * np.pi * y / 32)
print(f"Image shape: {image.shape}")

# Create 2D wavelet packet
wp2d = pywt.WaveletPacket2D(image, 'db2', maxlevel=4)
print(f"2D Wavelet packet maxlevel: {wp2d.maxlevel}")

# Access 2D packet nodes.
# 2D paths use one character per level ('a', 'h', 'v', 'd'); two-character
# paths such as 'aa' or 'ad' are already level-2 nodes.
print("\n2D packet navigation:")
approx_node = wp2d['a']  # Approximation
horiz_node = wp2d['h']   # Horizontal detail
vert_node = wp2d['v']    # Vertical detail
diag_node = wp2d['d']    # Diagonal detail
print("Level 1 2D nodes:")
for path in ['a', 'h', 'v', 'd']:
    node = wp2d[path]
    print(f" '{path}': shape {node.data.shape}")

# Perfect reconstruction
reconstructed_2d = wp2d.reconstruct()
print(f"2D reconstruction error: {np.max(np.abs(image - reconstructed_2d)):.2e}")

# Visualize 2D packet decomposition
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
panels = [
    (image, 'Original Image'),
    (reconstructed_2d, 'Reconstructed'),
    (approx_node.data, "'a' - Approximation"),
    # Detail bands are shown by magnitude because their sign carries no
    # meaning for a grayscale display
    (np.abs(horiz_node.data), "'h' - Horizontal detail"),
    (np.abs(vert_node.data), "'v' - Vertical detail"),
    (np.abs(diag_node.data), "'d' - Diagonal detail"),
]
for axis, (picture, title) in zip(axes.flat, panels):
    axis.imshow(picture, cmap='gray')
    axis.set_title(title)
    axis.axis('off')
plt.tight_layout()
plt.show()

# 2D packet-based image analysis
level_2_nodes = wp2d.get_level(2)
print(f"\nLevel 2 has {len(level_2_nodes)} nodes")

# Analyze energy distribution in frequency subbands
energy_analysis = {node.path: np.sum(node.data**2) for node in level_2_nodes}
total_energy = sum(energy_analysis.values())
print("\nEnergy distribution at level 2:")
for path, energy in sorted(energy_analysis.items()):
percentage = energy / total_energy * 100
print(f" '{path}': {percentage:.2f}%")
Wavelet packet transform for n-dimensional data.
class WaveletPacketND:
def __init__(self, data, wavelet, mode: str = 'symmetric', maxlevel=None, axes=None):
"""
nD wavelet packet tree.
Parameters:
- data: Input nD array
- wavelet: Wavelet specification
- mode: Signal extension mode
- maxlevel: Maximum decomposition level
- axes: Axes along which to perform transform
"""Helper functions for working with wavelet packet paths and ordering.
def get_graycode_order(level: int, x: str = 'a', y: str = 'd') -> list:
    """
    Generate graycode ordering for wavelet packet nodes.

    Reference stub: signature and documentation only.

    Parameters:
    - level: Decomposition level
    - x: Symbol for first branch (default: 'a' for approximation)
    - y: Symbol for second branch (default: 'd' for detail)

    Returns:
    List of path strings in graycode order
    """


class BaseNode:
    """Base class for wavelet packet nodes."""

    def __init__(self, parent, path, data):
        """
        Initialize base node.

        Parameters:
        - parent: Parent node
        - path: Node path string
        - data: Node data
        """
class Node:
    """1D wavelet packet node."""

    def __init__(self, parent, path, data):
        """Initialize 1D node from its parent node, path string and data."""
class Node2D:
    """2D wavelet packet node."""

    def __init__(self, parent, path, data):
        """Initialize 2D node from its parent node, path string and data."""
class NodeND:
    """nD wavelet packet node."""

    def __init__(self, parent, path, data):
        """Initialize nD node from its parent node, path string and data."""


import pywt
import numpy as np
# Advanced wavelet packet analysis
def analyze_packet_tree(wp, max_level=None):
    """
    Analyze complete wavelet packet tree structure.

    Parameters:
    - wp: Packet tree exposing `maxlevel` and `get_level(level)`
    - max_level: Deepest level to summarize (default: `wp.maxlevel`)

    Returns:
    Dict mapping each level 0..max_level to a dict with the keys
    'num_nodes', 'paths', 'shapes' and 'energies' (sum of squared data)
    """
    deepest = wp.maxlevel if max_level is None else max_level

    def summarize(level_nodes):
        # One summary record per level, lists aligned by node order
        return {
            'num_nodes': len(level_nodes),
            'paths': [entry.path for entry in level_nodes],
            'shapes': [entry.data.shape for entry in level_nodes],
            'energies': [np.sum(entry.data**2) for entry in level_nodes],
        }

    return {depth: summarize(wp.get_level(depth)) for depth in range(deepest + 1)}
# Test with complex signal
complex_signal = np.zeros(2048)
# Add multiple frequency components at different time intervals:
# each entry is (frequency, first sample, one-past-last sample)
segments = [(5, 0, 512), (20, 512, 1024),
            (50, 1024, 1536), (100, 1536, 2048)]
for freq, start, end in segments:
    t_segment = np.linspace(0, 1, end - start)
    complex_signal[start:end] = np.sin(2 * np.pi * freq * t_segment)

wp_complex = pywt.WaveletPacket(complex_signal, 'db8', maxlevel=6)
tree_analysis = analyze_packet_tree(wp_complex, max_level=4)

print("Wavelet Packet Tree Analysis:")
for level, info in tree_analysis.items():
    print(f"\nLevel {level}:")
    print(f" Number of nodes: {info['num_nodes']}")
    print(f" Paths: {info['paths']}")
    total_energy = sum(info['energies'])
    print(f" Total energy: {total_energy:.2f}")
    # Show energy distribution (skipped when there is no energy to divide by)
    if not total_energy > 0:
        continue
    for path, energy in zip(info['paths'], info['energies']):
        percentage = energy / total_energy * 100
        if percentage > 5:  # Show only significant components
            print(f" '{path}': {percentage:.1f}%")
# Best basis selection using cost function
def entropy_cost(data):
    """
    Calculate entropy-based cost for best basis selection.

    The squared values are normalized to sum to one and the base-2 entropy of
    that distribution is returned.

    Parameters:
    - data: Coefficients as a NumPy array or any array-like (list, tuple)

    Returns:
    Entropy in bits as a float; 0.0 when the input is empty or all zeros
    """
    # asarray lets plain Python sequences through; `list ** 2` would raise
    values = np.asarray(data)
    # Normalized energy
    energy = values**2
    total_energy = np.sum(energy)
    if total_energy == 0:
        return 0.0
    p = energy / total_energy
    p = p[p > 0]  # Avoid log(0)
    return float(-np.sum(p * np.log2(p)))
def find_best_basis_entropy(wp, max_level):
    """
    Find best basis using entropy criterion.

    Depth-first walk from `wp`: a node replaces everything selected beneath
    it when its own entropy cost is no larger than the cost returned for its
    two children.

    Parameters:
    - wp: Root of the packet tree; indexed by path string to reach children
    - max_level: Depth at which nodes are accepted without further splitting

    Returns:
    List of path strings of the selected nodes
    """
    chosen = []

    def select_basis(node):
        if node.level >= max_level:
            chosen.append(node.path)
            return entropy_cost(node.data)

        # Cost of keeping this node as it is
        own_cost = entropy_cost(node.data)

        # Cost of splitting it, accumulated over both children
        node.decompose()
        prefixes = (node.path + 'a', node.path + 'd')
        split_cost = 0
        for prefix in prefixes:
            split_cost += select_basis(wp[prefix])

        if own_cost <= split_cost:
            # Remove children from basis and add current node
            chosen[:] = [kept for kept in chosen if not kept.startswith(prefixes)]
            chosen.append(node.path)
            return own_cost
        else:
            return split_cost

    select_basis(wp)
    return chosen
# Select a basis for the test signal with the entropy criterion, down to level 5.
entropy_basis = find_best_basis_entropy(wp_complex, 5)
print(f"\nBest basis (entropy): {entropy_basis}")
# The result is a list of node paths, so its length is the number of selected nodes.
print(f"Number of basis functions: {len(entropy_basis)}")
# Adaptive compression using packet transform
def adaptive_compress(wp, compression_ratio=0.1):
    """
    Compress signal by keeping only largest coefficients.

    Only the deepest-level nodes are thresholded: once the tree is fully
    decomposed, reconstruction is driven by those nodes alone, so pooling
    coefficients from the intermediate levels would distort the threshold
    without changing which values can be kept.

    Parameters:
    - wp: Source wavelet packet tree (left unmodified)
    - compression_ratio: Fraction of deepest-level coefficients to keep;
      values <= 0 keep nothing and values >= 1 keep everything

    Returns:
    Signal reconstructed from the thresholded coefficients
    """
    # Skip root: with maxlevel == 0 there is nothing to threshold
    leaves = [node for node in wp.get_level(wp.maxlevel) if node.path]
    if not leaves:
        return np.array(wp.data, copy=True)

    # Collect the magnitudes of all coefficients that feed the reconstruction
    magnitudes = np.abs(np.concatenate([leaf.data.ravel() for leaf in leaves]))

    # Find threshold for desired compression ratio
    num_keep = min(int(magnitudes.size * compression_ratio), magnitudes.size)
    if num_keep <= 0:
        # Indexing with [-0] would select the smallest value and keep everything
        threshold = np.inf
    else:
        threshold = np.sort(magnitudes)[-num_keep]

    # Apply threshold on a fresh tree so the caller's tree stays untouched
    compressed_wp = pywt.WaveletPacket(wp.data, wp.wavelet, wp.mode, wp.maxlevel)
    for leaf in leaves:
        compressed_wp[leaf.path].data = np.where(np.abs(leaf.data) >= threshold,
                                                 leaf.data, 0)
    return compressed_wp.reconstruct()
# Compress the test signal with a compression ratio of 0.05 and report the error.
compressed_signal = adaptive_compress(wp_complex, compression_ratio=0.05)
# Root-mean-square error between the original and the compressed signal.
compression_error = np.sqrt(np.mean((complex_signal - compressed_signal)**2))
print(f"\nCompression (5% coefficients):")
print(f"RMSE: {compression_error:.4f}")
print(f"SNR: {20 * np.log10(np.std(complex_signal) / compression_error):.2f} dB")
# Wavelet packet path specification
# Literal is needed by NodeOrder; Union by the WaveletPacketNode alias that follows.
from typing import Literal, Union

PacketPath = str  # Path string like '', 'a', 'd', 'aa', 'ad', 'da', 'dd', etc.
# Node ordering for level extraction
NodeOrder = Literal['natural', 'freq']
# Wavelet packet node types
WaveletPacketNode = Union[WaveletPacket, Node, Node2D, NodeND]
Install with Tessl CLI
npx tessl i tessl/pypi-pywavelets