Discrete and continuous wavelet transforms for signal and image processing with comprehensive 1D, 2D, and nD transform support.
—
Quality
Pending
Does it follow best practices?
Impact
Pending
No eval scenarios have been run
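Stationary (undecimated) wavelet transforms provide translation-invariant analysis. No downsampling is performed, so every decomposition level keeps coefficient arrays the same length as the input. As a result, circularly shifting the input simply shifts the coefficients.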
Translation-invariant decomposition and reconstruction for one-dimensional signals.
def swt(data, wavelet, level: "int | None" = None, start_level: int = 0,
        axis: int = -1, trim_approx: bool = False, norm: bool = False):
    """
    1D stationary (undecimated) wavelet transform.

    Parameters:
    - data: Input array; its length along ``axis`` must be divisible by
      2**level.
    - wavelet: Wavelet specification (wavelet name or Wavelet object).
    - level: Number of decomposition levels. If None (the default), the
      maximum level for the input length is used (see swt_max_level).
    - start_level: Number of initial decomposition steps to skip
      (default: 0).
    - axis: Axis along which to perform the SWT (default: -1).
    - trim_approx: If True, keep only the final (coarsest) approximation and
      discard every intermediate approximation array (see Returns).
    - norm: If True, scale the transform so that the energy of the
      coefficients equals the energy of ``data`` (energy preservation holds
      for orthogonal wavelets with trim_approx=True).

    Returns:
    Coefficients ordered from the coarsest level to the finest; every array
    has the same length as ``data`` along ``axis``:
    - trim_approx=False: [(cAn, cDn), ..., (cA2, cD2), (cA1, cD1)]
    - trim_approx=True:  [cAn, cDn, ..., cD2, cD1]
    where n is ``level``. Note that coeffs[0] is the coarsest level, not
    level 1.
    """
def iswt(coeffs, wavelet, norm: bool = False, axis: int = -1):
    """
    1D inverse stationary wavelet transform.

    Parameters:
    - coeffs: Output of swt, in either of its formats (coarsest level first):
      [(cAn, cDn), ..., (cA1, cD1)] when trim_approx=False, or
      [cAn, cDn, ..., cD1] when trim_approx=True.
    - wavelet: Wavelet specification matching the forward transform.
    - norm: Must match the ``norm`` flag used in the forward transform.
    - axis: Axis along which to perform the ISWT (default: -1).

    Returns:
    Reconstructed signal, the same length as each coefficient array.
    """


import pywt
import numpy as np
import matplotlib.pyplot as plt

# Create a test signal: its length must be divisible by 2**level.
n = 1024  # 2**10
t = np.linspace(0, 1, n)
signal = (np.sin(2 * np.pi * 5 * t) +          # Low frequency
          0.5 * np.sin(2 * np.pi * 20 * t) +   # Medium frequency
          0.3 * np.cos(2 * np.pi * 100 * t))   # High frequency
rng = np.random.default_rng(0)  # Seeded so the printed numbers are reproducible.
noise = 0.2 * rng.standard_normal(n)
noisy_signal = signal + noise

# Stationary wavelet transform.
level = 6
coeffs = pywt.swt(noisy_signal, 'db4', level=level)
print(f"Number of decomposition levels: {len(coeffs)}")
print(f"Each coefficient array length: {len(coeffs[0][0])}")  # Same as input length

# Access coefficients. pywt.swt returns the coarsest level first:
# coeffs[0] is level ``level`` and coeffs[-1] is level 1.
for i, (cA, cD) in enumerate(coeffs):
    print(f"Level {level - i}: Approximation shape {cA.shape}, Detail shape {cD.shape}")

# Perfect reconstruction.
reconstructed = pywt.iswt(coeffs, 'db4')
print(f"Reconstruction error: {np.max(np.abs(noisy_signal - reconstructed))}")

# SWT preserves the signal length at all levels, unlike the decimated DWT.
dwt_coeffs = pywt.wavedec(noisy_signal, 'db4', level=level)
print(f"DWT coefficient lengths: {[len(c) for c in dwt_coeffs]}")
print(f"SWT coefficient lengths: {[len(cA) for cA, _ in coeffs]}")

# Translation invariance: circularly shifting the input circularly shifts the
# SWT coefficients by the same amount.
shift = 100
shifted_signal = np.roll(noisy_signal, shift)
coeffs_shifted = pywt.swt(shifted_signal, 'db4', level=level)

# Compare the level-3 detail coefficients. The index is ``level - 3``
# because the list runs coarsest-first.
detail_orig = coeffs[level - 3][1]
detail_shifted = coeffs_shifted[level - 3][1]

# Circular cross-correlation via the FFT peaks at the lag by which
# detail_shifted is delayed relative to detail_orig. A linear
# np.correlate(detail_orig, detail_shifted) would peak at -shift instead.
correlation = np.fft.ifft(np.fft.fft(detail_shifted) *
                          np.conj(np.fft.fft(detail_orig))).real
shift_detected = int(np.argmax(correlation))
print(f"Detected shift in coefficients: {shift_detected}")
shift_mismatch = np.max(np.abs(np.roll(detail_orig, shift) - detail_shifted))
print(f"Max |roll(detail) - detail_of_shifted|: {shift_mismatch:.3g}")

# Visualization
fig, axes = plt.subplots(3, 2, figsize=(15, 12))
axes[0, 0].plot(t, signal, 'b-', label='Clean')
axes[0, 0].plot(t, noisy_signal, 'r-', alpha=0.7, label='Noisy')
axes[0, 0].set_title('Original vs Noisy Signal')
axes[0, 0].legend()
axes[0, 1].plot(t, reconstructed, 'g-', label='Reconstructed')
axes[0, 1].plot(t, noisy_signal, 'r--', alpha=0.7, label='Original')
axes[0, 1].set_title('SWT Reconstruction')
axes[0, 1].legend()

# Plot the detail coefficients of (up to) the four coarsest levels,
# labelled with their true decomposition level.
for i, (_, cD) in enumerate(coeffs[:4]):
    ax = axes[1 + i // 2, i % 2]
    ax.plot(t, cD)
    ax.set_title(f'Detail Level {level - i}')
plt.tight_layout()
plt.show()

# Translation-invariant decomposition and reconstruction for two-dimensional data.
def swt2(data, wavelet, level: int, start_level: int = 0, axes=(-2, -1),
         trim_approx: bool = False, norm: bool = False):
    """
    2D stationary (undecimated) wavelet transform.

    Parameters:
    - data: Input array with at least two dimensions; its size along both
      transformed axes must be divisible by 2**level.
    - wavelet: Wavelet specification (wavelet name or Wavelet object).
    - level: Number of decomposition levels (required; there is no default).
    - start_level: Number of initial decomposition steps to skip
      (default: 0).
    - axes: Pair of axes for the 2D transform (default: last two axes).
    - trim_approx: If True, keep only the final (coarsest) approximation and
      discard every intermediate approximation array (see Returns).
    - norm: If True, scale the transform so that the energy of the
      coefficients equals the energy of ``data`` (energy preservation holds
      for orthogonal wavelets with trim_approx=True).

    Returns:
    Coefficients ordered from the coarsest level to the finest; every array
    has the same shape as ``data``:
    - trim_approx=False: [(cAn, (cHn, cVn, cDn)), ..., (cA1, (cH1, cV1, cD1))]
    - trim_approx=True:  [cAn, (cHn, cVn, cDn), ..., (cH1, cV1, cD1)]
    where n is ``level``.
    """
def iswt2(coeffs, wavelet, norm: bool = False, axes=(-2, -1)):
    """
    2D inverse stationary wavelet transform.

    Parameters:
    - coeffs: Output of swt2, in either of its formats (coarsest level first):
      [(cAn, (cHn, cVn, cDn)), ..., (cA1, (cH1, cV1, cD1))] when
      trim_approx=False, or [cAn, (cHn, cVn, cDn), ..., (cH1, cV1, cD1)]
      when trim_approx=True.
    - wavelet: Wavelet specification matching the forward transform.
    - norm: Must match the ``norm`` flag used in the forward transform.
    - axes: Pair of axes for the 2D transform (should match the forward
      transform).

    Returns:
    Reconstructed array with the same shape as each coefficient array.
    """


import pywt
import numpy as np
import matplotlib.pyplot as plt

# Create a test image: both dimensions must be divisible by 2**level.
size = 256  # 2**8
x, y = np.mgrid[0:size, 0:size]
image = (np.sin(2 * np.pi * x / 64) * np.cos(2 * np.pi * y / 64) +
         0.5 * np.sin(2 * np.pi * x / 16) * np.cos(2 * np.pi * y / 16))
rng = np.random.default_rng(0)  # Seeded so the printed numbers are reproducible.
noise = 0.3 * rng.standard_normal((size, size))
noisy_image = image + noise
print(f"Image shape: {noisy_image.shape}")

# 2D stationary wavelet transform.
level = 4
coeffs = pywt.swt2(noisy_image, 'db2', level=level)
print(f"Number of decomposition levels: {len(coeffs)}")

# Every level keeps the original image size. coeffs[0] is the coarsest
# level (``level``) and coeffs[-1] is level 1.
for i, (cA, _details) in enumerate(coeffs):
    print(f"Level {level - i}: All coefficients shape {cA.shape}")

# Perfect reconstruction.
reconstructed = pywt.iswt2(coeffs, 'db2')
print(f"2D SWT reconstruction error: {np.max(np.abs(noisy_image - reconstructed))}")

# Translation invariance in 2D: circularly shifting the image circularly
# shifts every coefficient array by the same offsets.
shift_rows, shift_cols = 50, 30
shifted_image = np.roll(noisy_image, (shift_rows, shift_cols), axis=(0, 1))
coeffs_shifted = pywt.swt2(shifted_image, 'db2', level=level)
cH_finest = coeffs[-1][1][0]
cH_finest_shifted = coeffs_shifted[-1][1][0]
shift_mismatch = np.max(np.abs(
    np.roll(cH_finest, (shift_rows, shift_cols), axis=(0, 1)) - cH_finest_shifted))
print(f"Level-1 cH mismatch after shifting: {shift_mismatch:.3g}")
def swt_denoise(image, wavelet, level, threshold, *, mode='soft'):
    """Denoise an image by thresholding its stationary wavelet detail bands.

    Parameters:
    - image: 2D array whose dimensions are divisible by 2**level.
    - wavelet: Wavelet specification passed to pywt.swt2 / pywt.iswt2.
    - level: Number of SWT decomposition levels.
    - threshold: Threshold applied to the cH, cV and cD arrays at every level.
    - mode: Thresholding mode forwarded to pywt.threshold (default 'soft',
      matching the previous fixed behavior).

    Returns:
    The reconstructed (denoised) image.
    """
    coeffs = pywt.swt2(image, wavelet, level=level)
    # Approximations are kept unchanged; only the detail bands are shrunk.
    coeffs_thresh = [
        (cA, tuple(pywt.threshold(band, threshold, mode=mode) for band in details))
        for cA, details in coeffs
    ]
    return pywt.iswt2(coeffs_thresh, wavelet)
# Apply denoising.
denoised_image = swt_denoise(noisy_image, 'db4', level=3, threshold=0.1)

# Visualization: the top row shows the images, the bottom row the finest
# detail bands.
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
top_panels = (
    (image, 'Original Clean Image'),
    (noisy_image, 'Noisy Image'),
    (denoised_image, 'SWT Denoised'),
)
for ax, (panel, title) in zip(axes[0], top_panels):
    ax.imshow(panel, cmap='gray')
    ax.set_title(title)
    ax.axis('off')

# coeffs[-1] holds level 1, the finest scale.
cA_1, (cH_1, cV_1, cD_1) = coeffs[-1]
bottom_panels = (
    (cH_1, 'Horizontal Details (Level 1)'),
    (cV_1, 'Vertical Details (Level 1)'),
    (cD_1, 'Diagonal Details (Level 1)'),
)
for ax, (band, title) in zip(axes[1], bottom_panels):
    ax.imshow(np.abs(band), cmap='gray')
    ax.set_title(title)
    ax.axis('off')

plt.tight_layout()
plt.show()

# Compare with a regular (decimated) DWT for edge preservation.
regular_dwt = pywt.wavedec2(noisy_image, 'db4', level=3)
regular_dwt_thresh = [regular_dwt[0]]  # The approximation is kept as-is.
regular_dwt_thresh.extend(
    tuple(pywt.threshold(band, 0.1, mode='soft') for band in detail_level)
    for detail_level in regular_dwt[1:]
)
regular_denoised = pywt.waverec2(regular_dwt_thresh, 'db4')
# Edge detection to compare preservation
def edge_strength(img):
    """Return the mean forward-difference gradient magnitude of a 2D image.

    The horizontal and vertical difference maps are cropped to their common
    (rows - 1, cols - 1) region before being combined.
    """
    grad_x = np.diff(img, axis=1)[:-1, :]
    grad_y = np.diff(img, axis=0)[:, :-1]
    magnitude = np.sqrt(grad_x ** 2 + grad_y ** 2)
    return magnitude.mean()
# Report the edge strength of the clean, SWT-denoised and DWT-denoised images.
for label, img in (("Original", image),
                   ("SWT denoised", denoised_image),
                   ("DWT denoised", regular_denoised)):
    print(f"Edge strength - {label}: {edge_strength(img):.4f}")

# Translation-invariant decomposition and reconstruction for n-dimensional data.
def swtn(data, wavelet, level: int, start_level: int = 0, axes=None,
         trim_approx: bool = False, norm: bool = False):
    """
    nD stationary (undecimated) wavelet transform.

    Parameters:
    - data: Input nD array; its size along every transformed axis must be
      divisible by 2**level.
    - wavelet: Wavelet specification (wavelet name or Wavelet object).
    - level: Number of decomposition levels (required; there is no default).
    - start_level: Number of initial decomposition steps to skip
      (default: 0).
    - axes: Axes along which to perform the transform (default: all axes).
    - trim_approx: If True, keep only the final (coarsest) approximation and
      discard every intermediate approximation array (see Returns).
    - norm: If True, scale the transform so that the energy of the
      coefficients equals the energy of ``data`` (energy preservation holds
      for orthogonal wavelets with trim_approx=True).

    Returns:
    A list ordered from the coarsest level to the finest. Coefficient dicts
    map a key made of 'a'/'d' characters, one per transformed axis (e.g.
    'aa', 'ad', 'da', 'dd' for 2D), to an array with the same shape as
    ``data``:
    - trim_approx=False: one dict per level, each including the approximation
      key ('aa...a').
    - trim_approx=True: [cAn, {details n}, ..., {details 1}]; only the
      coarsest approximation is kept, as the first element, and the detail
      dicts omit the approximation key.
    """
def iswtn(coeffs, wavelet, axes=None, norm: bool = False):
    """
    nD inverse stationary wavelet transform.

    Parameters:
    - coeffs: Output of swtn: either a list of per-level coefficient dicts
      (coarsest level first) or the trim_approx=True format
      [cAn, {details n}, ..., {details 1}].
    - wavelet: Wavelet specification matching the forward transform.
    - axes: Axes along which to perform the transform (should match the
      forward transform).
    - norm: Must match the ``norm`` flag used in the forward transform.

    Returns:
    Reconstructed nD array.
    """


def swt_max_level(input_len: int) -> int:
    """
    Compute the maximum SWT decomposition level for a given input length.

    Parameters:
    - input_len: Length of the input signal along the transformed axis.

    Returns:
    The largest level such that input_len is divisible by 2**level, i.e. how
    many times input_len can be halved exactly (e.g. 64 -> 6, 1000 -> 3).
    """


import pywt
import numpy as np

# 3D volume processing with SWT. Each dimension must be divisible by
# 2**level; powers of two are not required.
volume = np.random.default_rng(0).standard_normal((64, 64, 64))
print(f"3D volume shape: {volume.shape}")

# Check the maximum level.
max_level = pywt.swt_max_level(64)  # 64 = 2**6, so the max level is 6
print(f"Maximum SWT level for size 64: {max_level}")

# 3D SWT
level = 3
coeffs_3d = pywt.swtn(volume, 'haar', level=level)
print(f"Number of 3D SWT levels: {len(coeffs_3d)}")

# Every level has the same shape as the input. coeffs_3d[0] is the coarsest
# level (``level``) and coeffs_3d[-1] is level 1.
for i, coeff_dict in enumerate(coeffs_3d):
    print(f"Level {level - i} coefficients:")
    for key, coeff in coeff_dict.items():
        print(f" '{key}': {coeff.shape}")

# Perfect reconstruction.
reconstructed_3d = pywt.iswtn(coeffs_3d, 'haar')
print(f"3D SWT reconstruction error: {np.max(np.abs(volume - reconstructed_3d))}")

# Example: maximum SWT level for various 1D signal lengths.
for length in [128, 256, 512, 1000]:
    max_level = pywt.swt_max_level(length)
    print(f"Length {length}: max SWT level = {max_level}")

# SWT with trim_approx=True returns [cA4, cD4, cD3, cD2, cD1]: only the
# coarsest approximation is kept (as the first element), followed by the
# detail arrays. It is NOT a details-only list.
signal_1d = np.random.default_rng(1).standard_normal(512)
trimmed = pywt.swt(signal_1d, 'db2', level=4, trim_approx=True)
cA_coarsest, *details = trimmed
print(f"Trimmed output - number of arrays: {len(trimmed)} "
      f"(1 approximation + {len(details)} details)")
print(f"Each array shape: {cA_coarsest.shape}")

# iswt accepts the trimmed format directly.
reconstructed_trimmed = pywt.iswt(trimmed, 'db2')
print(f"Reconstruction error from trimmed coefficients: "
      f"{np.max(np.abs(signal_1d - reconstructed_trimmed))}")

# Replacing the coarsest approximation with zeros reconstructs only the
# detail content. By linearity, this equals the full reconstruction minus the
# reconstruction from cA4 alone.
details_part = pywt.iswt([np.zeros_like(cA_coarsest), *details], 'db2')
print(f"Reconstruction from details only shape: {details_part.shape}")
# SWT coefficient formats
from typing import Dict, List, Tuple, Union

import numpy as np

# Untrimmed formats (trim_approx=False), ordered coarsest level first.
SWTCoeffs1D = List[Tuple[np.ndarray, np.ndarray]]  # [(cAn, cDn), ..., (cA1, cD1)]
SWTCoeffs2D = List[Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray, np.ndarray]]]  # [(cAn, (cHn, cVn, cDn)), ...]
SWTCoeffsND = List[Dict[str, np.ndarray]]  # One dict per level, keyed by 'a'/'d' strings
# Trimmed formats (trim_approx=True): the coarsest approximation cAn comes
# first, followed by each level's details; intermediate approximations are
# dropped.
SWTDetails1D = List[np.ndarray]  # [cAn, cDn, ..., cD1]
SWTDetails2D = List[Union[np.ndarray, Tuple[np.ndarray, np.ndarray, np.ndarray]]]  # [cAn, (cHn, cVn, cDn), ..., (cH1, cV1, cD1)]
SWTDetailsND = List[Union[np.ndarray, Dict[str, np.ndarray]]]  # [cAn, {details n}, ..., {details 1}] (dicts omit the approximation key)
# Install with Tessl CLI
npx tessl i tessl/pypi-pywavelets