CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-pywavelets

Discrete and continuous wavelet transforms for signal and image processing with comprehensive 1D, 2D, and nD transform support.

Pending

Quality

Pending

Does it follow best practices?

Impact

Pending

No eval scenarios have been run

Overview
Eval results
Files

docs/wavelet-packets.md

Wavelet Packet Transform

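Wavelet packet decomposition splits both the approximation and the detail coefficients at every level. This produces a full decomposition tree (a binary tree in 1D, a quadtree in 2D) for detailed frequency localization and adaptive basis selection. 1D, 2D, and nD implementations are provided.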

Capabilities

1D Wavelet Packet Transform

Complete binary tree decomposition for 1D signals with node-based navigation.

class WaveletPacket:
    def __init__(self, data, wavelet, mode: str = 'symmetric', maxlevel=None, axis: int = -1):
        """
        1D wavelet packet tree; the object itself is the root node.

        Parameters:
        - data: Input signal (for an nD array the 1D transform runs along `axis`)
        - wavelet: Wavelet specification (Wavelet object or wavelet name)
        - mode: Signal extension mode
        - maxlevel: Maximum decomposition level (default: derived from the
          data length and the wavelet's filter length)
        - axis: Axis along which the transform is applied (default: last axis)
        """

    # Properties (annotations are quoted so this stub can be evaluated
    # without numpy / pywt names in scope)
    data: "np.ndarray"  # Node data (coefficients; the input signal at the root)
    path: str  # Node path: '' for the root; each level appends 'a' (approximation) or 'd' (detail)
    level: int  # Decomposition level (0 at the root, i.e. len(path))
    maxlevel: int  # Maximum allowed level
    mode: str  # Signal extension mode
    wavelet: "Wavelet"  # Wavelet object
    axis: int  # Axis along which the transform is applied

    def decompose(self):
        """
        Decompose the current node into approximation ('a') and detail ('d') children.

        Returns:
        The (approximation, detail) child nodes

        Raises:
        ValueError if the node is already at `maxlevel`
        """

    def reconstruct(self, update: bool = True):
        """
        Reconstruct node data from child nodes.

        Parameters:
        - update: If True (default), replace the node data with the
          reconstructed values

        Returns:
        Reconstructed data array, trimmed to the shape of the original input
        """

    def __getitem__(self, path: str):
        """
        Get node at specified path.

        Nodes that do not exist yet are created by decomposing their parent.

        Parameters:
        - path: Path string ('', 'a', 'd', 'aa', 'ad', 'da', 'dd', etc.)

        Returns:
        Node at specified path

        Raises:
        IndexError if the path is deeper than `maxlevel`
        """

    def __setitem__(self, path: str, data):
        """Set data for the node at the specified path (creating the node if needed)."""

    def get_level(self, level: int, order: str = 'natural', decompose: bool = True):
        """
        Get all nodes at specified level.

        Parameters:
        - level: Decomposition level
        - order: Node ordering ('natural' tree order or 'freq' frequency order)
        - decompose: Whether to decompose nodes that do not exist yet

        Returns:
        List of nodes at specified level
        """

Usage Examples

import pywt
import numpy as np
import matplotlib.pyplot as plt

# Synthetic 1 s test signal: 8, 32 and 64 Hz tones plus Gaussian noise
np.random.seed(42)
t = np.linspace(0, 1, 1024)
signal = (
    np.sin(2 * np.pi * 8 * t)
    + 0.7 * np.sin(2 * np.pi * 32 * t)
    + 0.5 * np.sin(2 * np.pi * 64 * t)
    + 0.2 * np.random.randn(len(t))
)
print(f"Signal length: {len(signal)}")

# Build the packet tree; deeper nodes are created when first accessed
wp = pywt.WaveletPacket(signal, 'db4', mode='symmetric', maxlevel=6)
print(f"Wavelet packet created with maxlevel: {wp.maxlevel}")
print(f"Root node path: '{wp.path}', level: {wp.level}")

# The root node (path '') holds the input samples
print("\nTree navigation:")
root_shape = wp.data.shape
print(f"Root data shape: {root_shape}")
print(f"Root node: '{wp.path}' -> {root_shape}")

# Level 1: indexing by path decomposes the parent if needed
approx_child, detail_child = wp['a'], wp['d']
for label, child in (("Approximation child", approx_child),
                     ("Detail child", detail_child)):
    print(f"{label}: '{child.path}' -> {child.data.shape}")

# Level 2: each character of the path picks a branch, read left to right
# (e.g. 'da' is the approximation of the detail branch)
level_2_paths = ['aa', 'ad', 'da', 'dd']
aa_node, ad_node, da_node, dd_node = (wp[p] for p in level_2_paths)

print("Level 2 nodes:")
for path in level_2_paths:
    node = wp[path]
    print(f"  '{path}': level {node.level}, shape {node.data.shape}")

# Level 3 listed in natural (tree) order, then in frequency order
level_3_nodes = wp.get_level(3, order='natural')
print(f"\nLevel 3 nodes: {len(level_3_nodes)}")
for node in level_3_nodes:
    print(f"  Path: '{node.path}', shape: {node.data.shape}")

level_3_freq = wp.get_level(3, order='freq')
print(f"Level 3 (frequency order): {[n.path for n in level_3_freq]}")

# Perfect-reconstruction check against the input signal
reconstructed = wp.reconstruct()
reconstruction_error = np.abs(signal - reconstructed).max()
print(f"\nReconstruction error: {reconstruction_error:.2e}")

# Best basis selection (energy-based)
def calculate_node_energy(node):
    """Return the energy of a packet node, ``sum(|x|**2)`` over its coefficients.

    Taking the magnitude first keeps the energy real and non-negative for
    complex-valued coefficients; for real data it is identical to
    ``sum(x**2)``.
    """
    return np.sum(np.abs(node.data) ** 2)

def find_best_basis_energy(wp, max_level):
    """Find a basis by comparing each node's energy with its children's.

    The tree is walked depth-first from the root. Nodes at ``max_level``
    are candidate basis elements; every shallower node replaces its
    descendants in the basis when its own energy is at least the total
    energy of the basis chosen beneath it.

    Parameters
    ----------
    wp : pywt.WaveletPacket
        Root of the 1D packet tree.
    max_level : int
        Deepest level considered. Values above ``wp.maxlevel`` are clamped,
        because nodes at ``wp.maxlevel`` cannot be decomposed further.

    Returns
    -------
    list of str
        Paths of the selected basis nodes.

    Notes
    -----
    For orthogonal wavelets such as 'db4' a split (approximately, up to
    boundary effects of the extension mode) preserves energy, so this
    criterion discriminates only weakly; an additive cost such as the
    Shannon entropy is the usual best-basis criterion.
    """
    max_level = min(max_level, wp.maxlevel)
    best_basis = []

    def traverse_node(node):
        if node.level >= max_level:
            best_basis.append(node.path)
            return calculate_node_energy(node)

        # Decompose and calculate energy for children
        node.decompose()
        approx_child = wp[node.path + 'a']
        detail_child = wp[node.path + 'd']

        child_energy = traverse_node(approx_child) + traverse_node(detail_child)
        current_energy = calculate_node_energy(node)

        if current_energy >= child_energy:
            # Current node wins: drop all of its descendants from the basis.
            # Mutate the shared list in place -- rebinding ``best_basis`` here
            # would make it local to traverse_node and raise UnboundLocalError.
            prefixes = (node.path + 'a', node.path + 'd')
            best_basis[:] = [p for p in best_basis if not p.startswith(prefixes)]
            best_basis.append(node.path)
            return current_energy
        return child_energy

    traverse_node(wp)
    return best_basis

best_basis = find_best_basis_energy(wp, 4)
print(f"\nBest basis (energy): {best_basis}")

# Top panel: the original signal; the three panels below show the first
# three level-2 packets returned by get_level
fig, axes = plt.subplots(4, 1, figsize=(15, 12))
signal_ax, *packet_axes = axes

signal_ax.plot(t, signal)
signal_ax.set_title('Original Signal')
signal_ax.set_ylabel('Amplitude')
signal_ax.grid(True)

level_2_nodes = wp.get_level(2)
for ax, node in zip(packet_axes, level_2_nodes):
    ax.plot(node.data)
    ax.set_title(f"Node '{node.path}' (Level {node.level})")
    ax.set_ylabel('Amplitude')
    ax.grid(True)

plt.tight_layout()
plt.show()

# Packet-based denoising
def packet_denoise(wp, threshold_factor=0.1, keep_approximation=True):
    """Denoise a signal by soft-thresholding its deepest wavelet packets.

    Every node at ``wp.maxlevel`` is soft-thresholded with its own
    threshold ``threshold_factor * std(node.data)``. The signal is then
    reconstructed from the modified leaves.

    Parameters
    ----------
    wp : pywt.WaveletPacket
        Packet tree of the noisy signal. It is modified in place: leaf data
        is overwritten with the thresholded values.
    threshold_factor : float, optional
        Multiplier applied to each node's standard deviation.
    keep_approximation : bool, optional
        If True (default), the pure low-pass leaf (path ``'a' * maxlevel``)
        is left untouched. This mirrors the usual DWT denoising recipe,
        which keeps the approximation coefficients and thresholds only the
        details. Set to False to threshold every leaf.

    Returns
    -------
    numpy.ndarray
        The reconstructed (denoised) signal.
    """
    approx_path = 'a' * wp.maxlevel
    for node in wp.get_level(wp.maxlevel):
        if node.path == '':
            continue  # maxlevel == 0: the only "leaf" is the root itself
        if keep_approximation and node.path == approx_path:
            continue  # low-pass band carries most of the signal, not noise
        threshold = threshold_factor * np.std(node.data)
        node.data = pywt.threshold(node.data, threshold, mode='soft')

    return wp.reconstruct()

# Add extra noise, then denoise it with wavelet packets ...
noisy_signal = signal + 0.4 * np.random.randn(len(signal))
wp_noisy = pywt.WaveletPacket(noisy_signal, 'db4', maxlevel=5)
denoised_signal = packet_denoise(wp_noisy, threshold_factor=0.2)

# ... and with a plain 5-level DWT: soft-threshold each detail band,
# keep the approximation band unchanged
regular_coeffs = pywt.wavedec(noisy_signal, 'db4', level=5)
approx_band, *detail_bands = regular_coeffs
regular_thresh = [approx_band] + [
    pywt.threshold(band, 0.2 * np.std(band), mode='soft')
    for band in detail_bands
]
regular_denoised = pywt.waverec(regular_thresh, 'db4')

print("\nDenoising comparison:")
signal_power = np.var(signal)
packet_snr = 10 * np.log10(signal_power / np.var(signal - denoised_signal))
regular_snr = 10 * np.log10(signal_power / np.var(signal - regular_denoised))
print(f"Wavelet packet denoising SNR: {packet_snr:.2f} dB")
print(f"Regular wavelet denoising SNR: {regular_snr:.2f} dB")

2D Wavelet Packet Transform

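Full quadtree decomposition for 2D data such as images. Every node splits into four children: 'a' (approximation), 'h' (horizontal detail), 'v' (vertical detail) and 'd' (diagonal detail).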

class WaveletPacket2D:
    def __init__(self, data, wavelet, mode: str = 'smooth', maxlevel=None, axes=(-2, -1)):
        """
        2D wavelet packet tree; the object itself is the root of a quadtree.

        Parameters:
        - data: Input 2D array (for an nD array the transform runs over `axes`)
        - wavelet: Wavelet specification
        - mode: Signal extension mode (PyWavelets defaults to 'smooth' here,
          unlike the 1D WaveletPacket, whose default is 'symmetric')
        - maxlevel: Maximum decomposition level (default: derived from the
          data shape and the wavelet)
        - axes: Pair of axes over which the 2D transform is applied
        """

    # Same node interface as WaveletPacket (data, path, level, decompose,
    # reconstruct, get_level, ...), but every node splits into FOUR children,
    # so each decomposition level appends one of these characters to the path:
    #   'a' - approximation
    #   'h' - horizontal detail
    #   'v' - vertical detail
    #   'd' - diagonal detail
    # e.g. 'a', 'h', 'v', 'd' at level 1;
    #      'aa', 'ah', 'av', 'ad', 'ha', ..., 'dd' (4**2 = 16 nodes) at level 2

Usage Examples

import pywt
import numpy as np
import matplotlib.pyplot as plt

# Create test image
size = 256
x, y = np.mgrid[0:size, 0:size]
image = np.zeros((size, size))

# Add geometric patterns
image[64:192, 64:192] = 1.0
image[96:160, 96:160] = 0.5
# Add some texture
image += 0.1 * np.sin(2 * np.pi * x / 32) * np.cos(2 * np.pi * y / 32)

print(f"Image shape: {image.shape}")

# Create 2D wavelet packet
wp2d = pywt.WaveletPacket2D(image, 'db2', maxlevel=4)
print(f"2D Wavelet packet maxlevel: {wp2d.maxlevel}")

# Access the level-1 nodes. A 2D packet node has four children:
# 'a' (approximation), 'h' (horizontal detail), 'v' (vertical detail)
# and 'd' (diagonal detail). Two-character paths such as 'aa' or 'ad'
# address level-2 nodes, not level-1 ones.
print("\n2D packet navigation:")
level_1_paths = ['a', 'h', 'v', 'd']
a_node, h_node, v_node, d_node = (wp2d[path] for path in level_1_paths)

print("Level 1 2D nodes:")
for path in level_1_paths:
    node = wp2d[path]
    print(f"  '{path}': shape {node.data.shape}")

# Perfect reconstruction
reconstructed_2d = wp2d.reconstruct()
print(f"2D reconstruction error: {np.max(np.abs(image - reconstructed_2d)):.2e}")

# Visualize the image, its reconstruction and the four level-1 subbands
# (detail bands shown as magnitudes)
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
panels = [
    (image, 'Original Image'),
    (reconstructed_2d, 'Reconstructed'),
    (a_node.data, "'a' - Approximation"),
    (np.abs(h_node.data), "'h' - Horizontal detail"),
    (np.abs(v_node.data), "'v' - Vertical detail"),
    (np.abs(d_node.data), "'d' - Diagonal detail"),
]
for ax, (img, title) in zip(axes.ravel(), panels):
    ax.imshow(img, cmap='gray')
    ax.set_title(title)
    ax.axis('off')

plt.tight_layout()
plt.show()

# 2D packet-based image analysis: level 2 holds 4**2 = 16 subbands
level_2_nodes = wp2d.get_level(2)
print(f"\nLevel 2 has {len(level_2_nodes)} nodes")

# Analyze energy distribution in frequency subbands
energy_analysis = {node.path: np.sum(node.data**2) for node in level_2_nodes}

total_energy = sum(energy_analysis.values())
print("\nEnergy distribution at level 2:")
for path, energy in sorted(energy_analysis.items()):
    percentage = energy / total_energy * 100
    print(f"  '{path}': {percentage:.2f}%")

nD Wavelet Packet Transform

Wavelet packet transform for n-dimensional data.

class WaveletPacketND:
    def __init__(self, data, wavelet, mode: str = 'symmetric', maxlevel=None, axes=None):
        """
        Wavelet packet tree for arrays with any number of dimensions.

        The object itself is the root node of the tree.

        Parameters:
        - data: The n-dimensional input array
        - wavelet: Which wavelet to use (specification)
        - mode: How the signal is extended at its boundaries
        - maxlevel: Deepest decomposition level allowed
        - axes: The axes the transform is performed along
        """

Utility Functions

Helper functions for working with wavelet packet paths and ordering.

def get_graycode_order(level: int, x: str = 'a', y: str = 'd') -> list:
    """
    List every wavelet packet path of a given level in Gray-code order.

    In Gray-code order, neighbouring paths differ in exactly one position.
    For example, level 2 gives ['aa', 'ad', 'dd', 'da'].

    Parameters:
    - level: Decomposition level (the length of every returned path)
    - x: Character used for the first branch; 'a' (approximation) by default
    - y: Character used for the second branch; 'd' (detail) by default

    Returns:
    The path strings, Gray-code ordered
    """

Base Node Classes

class BaseNode:
    """Base class for wavelet packet nodes (shared by Node, Node2D and NodeND)."""

    def __init__(self, parent, path, data):
        """
        Initialize base node.

        Parameters:
        - parent: Parent node (presumably None for a root node -- confirm)
        - path: Node path string
        - data: Node data

        NOTE(review): the PyWavelets source appears to declare this constructor
        as BaseNode(parent, data, node_name); confirm before relying on the
        parameter order shown here.
        """

class Node(BaseNode):
    """1D wavelet packet node; its children are named 'a' (approximation) and 'd' (detail)."""

    def __init__(self, parent, path, data):
        """Initialize 1D node."""

class Node2D(BaseNode):
    """2D wavelet packet node.

    Its four children are named 'a' (approximation), 'h' (horizontal
    detail), 'v' (vertical detail) and 'd' (diagonal detail).
    """

    def __init__(self, parent, path, data):
        """Initialize 2D node."""

class NodeND(BaseNode):
    """nD wavelet packet node."""

    def __init__(self, parent, path, data):
        """Initialize nD node."""

Advanced Usage Examples

import pywt
import numpy as np

# Advanced wavelet packet analysis
def analyze_packet_tree(wp, max_level=None):
    """Summarise every level of a wavelet packet tree.

    Parameters
    ----------
    wp : pywt.WaveletPacket
        Packet tree to analyse; ``get_level`` decomposes nodes as needed.
    max_level : int, optional
        Deepest level to include. Defaults to ``wp.maxlevel``. Larger values
        are clamped to ``wp.maxlevel``, since the tree cannot be decomposed
        past it.

    Returns
    -------
    dict
        Maps each level ``0..max_level`` to a dict with the keys
        ``'num_nodes'``, ``'paths'``, ``'shapes'`` and ``'energies'``
        (sum of squared coefficient magnitudes of each node).
    """
    if max_level is None:
        max_level = wp.maxlevel
    else:
        max_level = min(max_level, wp.maxlevel)

    analysis = {}
    for level in range(max_level + 1):
        nodes = wp.get_level(level)
        analysis[level] = {
            'num_nodes': len(nodes),
            'paths': [node.path for node in nodes],
            'shapes': [node.data.shape for node in nodes],
            # |x|**2 keeps energies real for complex data; same as x**2 for real data
            'energies': [np.sum(np.abs(node.data) ** 2) for node in nodes],
        }

    return analysis

# Piecewise test signal: each quarter of the 2048 samples carries one tone
complex_signal = np.zeros(2048)
tone_segments = [(5, 0, 512), (20, 512, 1024), (50, 1024, 1536), (100, 1536, 2048)]
for freq, start, end in tone_segments:
    t_segment = np.linspace(0, 1, end - start)
    complex_signal[start:end] = np.sin(2 * np.pi * freq * t_segment)

wp_complex = pywt.WaveletPacket(complex_signal, 'db8', maxlevel=6)
tree_analysis = analyze_packet_tree(wp_complex, max_level=4)

print("Wavelet Packet Tree Analysis:")
for level, info in tree_analysis.items():
    total_energy = sum(info['energies'])
    print(f"\nLevel {level}:")
    print(f"  Number of nodes: {info['num_nodes']}")
    print(f"  Paths: {info['paths']}")
    print(f"  Total energy: {total_energy:.2f}")

    if not total_energy > 0:
        continue
    # List only the subbands holding more than 5% of this level's energy
    for path, energy in zip(info['paths'], info['energies']):
        share = energy / total_energy * 100
        if share > 5:
            print(f"    '{path}': {share:.1f}%")

# Best basis selection using cost function
def entropy_cost(data):
    """Additive Shannon entropy cost for best-basis selection.

    Computes ``-sum(e * log2(e))`` over the coefficient energies
    ``e = |x|**2``; zero-energy terms contribute 0. This is the
    unnormalised Shannon entropy of Coifman-Wickerhauser best-basis search.

    An entropy of per-node *normalised* energies is not additive, which
    makes a parent-versus-children comparison meaningless. This cost is
    additive: ``cost(concatenate([a, d])) == cost(a) + cost(d)``. Lower
    values mean more concentrated energy.

    Parameters
    ----------
    data : array_like
        Coefficients of one packet node (real or complex).

    Returns
    -------
    float
        The cost. It may be negative when coefficient magnitudes exceed 1.
    """
    energy = np.abs(np.asarray(data)) ** 2
    energy = energy[energy > 0]  # treat 0 * log(0) as 0
    return float(-np.sum(energy * np.log2(energy)))

def find_best_basis_entropy(wp, max_level):
    """Find the best basis of a 1D packet tree under the entropy cost.

    Children are evaluated before their parent (post-order). A node above
    ``max_level`` then replaces its descendants in the basis when its own
    cost does not exceed the total cost of the basis chosen beneath it.

    Parameters
    ----------
    wp : pywt.WaveletPacket
        Root of the packet tree.
    max_level : int
        Deepest level considered. Values above ``wp.maxlevel`` are clamped,
        because nodes at ``wp.maxlevel`` cannot be decomposed further.

    Returns
    -------
    list of str
        Paths of the nodes forming the selected basis.
    """
    max_level = min(max_level, wp.maxlevel)
    best_basis = []

    def select_basis(node):
        current_cost = entropy_cost(node.data)
        if node.level >= max_level:
            best_basis.append(node.path)
            return current_cost

        # Cost of the best basis found beneath each child
        node.decompose()
        child_paths = (node.path + 'a', node.path + 'd')
        child_cost = sum(select_basis(wp[path]) for path in child_paths)

        if current_cost <= child_cost:
            # Replace all descendants already in the basis by this node.
            # Slice assignment mutates the shared list instead of rebinding it.
            best_basis[:] = [p for p in best_basis if not p.startswith(child_paths)]
            best_basis.append(node.path)
            return current_cost
        return child_cost

    select_basis(wp)
    return best_basis

entropy_basis = find_best_basis_entropy(wp_complex, 5)
# One print call producing the same two lines: the basis paths, then their count
print(f"\nBest basis (entropy): {entropy_basis}\n"
      f"Number of basis functions: {len(entropy_basis)}")

# Adaptive compression using packet transform
def adaptive_compress(wp, compression_ratio=0.1):
    """Compress a signal by keeping only its largest packet coefficients.

    Only the nodes at ``wp.maxlevel`` (the leaves) feed the reconstruction,
    so the magnitude threshold is chosen from those coefficients alone.
    About ``compression_ratio`` of them survive (ties at the threshold are
    all kept) and the rest are zeroed.

    Parameters
    ----------
    wp : pywt.WaveletPacket
        Source packet tree. It is decomposed down to ``wp.maxlevel``; this
        function does not overwrite its node data.
    compression_ratio : float, optional
        Fraction of leaf coefficients to keep, in [0, 1].

    Returns
    -------
    numpy.ndarray
        Signal reconstructed from the thresholded leaves.
    """
    leaf_nodes = wp.get_level(wp.maxlevel)
    leaf_coeffs = np.concatenate([node.data.ravel() for node in leaf_nodes])

    num_keep = min(int(leaf_coeffs.size * compression_ratio), leaf_coeffs.size)
    if num_keep <= 0:
        # Keep nothing. (Indexing the sorted array with [-0] would pick the
        # smallest magnitude instead and keep every coefficient.)
        threshold = np.inf
    else:
        threshold = np.sort(np.abs(leaf_coeffs))[-num_keep]

    # Fresh tree over the same data; only its leaves are replaced, because
    # reconstruction of nodes with children is driven by the children.
    compressed_wp = pywt.WaveletPacket(wp.data, wp.wavelet, wp.mode, wp.maxlevel)
    for node in leaf_nodes:
        compressed_node = compressed_wp[node.path]
        compressed_node.data = np.where(np.abs(node.data) >= threshold, node.data, 0)

    return compressed_wp.reconstruct()

compressed_signal = adaptive_compress(wp_complex, compression_ratio=0.05)
residual = complex_signal - compressed_signal
compression_error = np.sqrt(np.mean(residual**2))  # RMSE of the compressed signal
print("\nCompression (5% coefficients):")
print(f"RMSE: {compression_error:.4f}")
snr_db = 20 * np.log10(np.std(complex_signal) / compression_error)
print(f"SNR: {snr_db:.2f} dB")

Types

from typing import Literal, Union

# Wavelet packet path specification
PacketPath = str  # Path string like '', 'a', 'd', 'aa', 'ad', 'da', 'dd', etc. (2D trees use 'a', 'h', 'v', 'd')

# Node ordering for level extraction
NodeOrder = Literal['natural', 'freq']

# Wavelet packet node types (tree roots are nodes as well)
WaveletPacketNode = Union[WaveletPacket, WaveletPacket2D, WaveletPacketND, Node, Node2D, NodeND]

Install with Tessl CLI

npx tessl i tessl/pypi-pywavelets

docs

coefficient-utils.md

continuous-dwt.md

index.md

multi-level-dwt.md

multiresolution-analysis.md

single-level-dwt.md

stationary-dwt.md

thresholding.md

wavelet-packets.md

wavelets.md

tile.json