Discrete and continuous wavelet transforms for signal and image processing with comprehensive 1D, 2D, and nD transform support.
—
Quality
Pending
Does it follow best practices?
Impact
Pending
No eval scenarios have been run
Continuous wavelet transform (CWT) for detailed time-frequency analysis of non-stationary signals, with adjustable time-frequency resolution.
Time-frequency decomposition using continuous wavelets with arbitrary scales.
def cwt(data, scales, wavelet, sampling_period: float = 1.0,
        method: str = 'conv', axis: int = -1, *, precision: int = 12):
    """
    Continuous wavelet transform of ``data`` along one axis.

    Parameters:
    - data: Input signal (array_like). A 1D signal is the common case; for
      nD input the transform is computed along ``axis``.
    - scales: Array of scales for the transform
    - wavelet: Continuous wavelet name (e.g. 'morl', 'cmor1.5-1.0') or
      ContinuousWavelet object
    - sampling_period: Sampling period of ``data``. It only affects the
      returned frequencies, which are expressed in cycles per unit of this
      period (Hz when it is given in seconds). Default: 1.0, i.e. cycles
      per sample.
    - method: Computation method: 'conv' (direct convolution, the default)
      or 'fft' (FFT-based convolution)
    - axis: Axis of ``data`` along which to perform the CWT (default: -1)
    - precision: Keyword-only. Precision of the sampled wavelet used in the
      convolution; per the PyWavelets documentation the wavelet is sampled
      on 2**precision points (default: 12)

    Returns:
    (coefficients, frequencies) tuple where:
    - coefficients: Array whose first axis corresponds to ``scales`` and
      whose remaining axes match ``data``; for 1D ``data`` the shape is
      (len(scales), len(data)). Complex-valued for complex wavelets
      (e.g. 'cmor', 'cgau'), real-valued otherwise.
    - frequencies: 1D array of the frequencies corresponding to ``scales``
    """


import pywt
import numpy as np
import matplotlib.pyplot as plt

# Create test signal with time-varying frequency
dt = 0.01
t = np.arange(0, 10, dt)
freq1, freq2 = 2, 10
# Linear chirp whose instantaneous frequency rises from freq1 at t = 0 to
# freq2 at t[-1]. The instantaneous frequency is the derivative of the phase
# divided by 2*pi, so the phase must be quadratic in t; writing
# sin(2*pi*f(t)*t) with a linear f(t) would instead sweep up to
# 2*freq2 - freq1.
chirp_rate = (freq2 - freq1) / t[-1]  # Hz per second
signal = np.sin(2 * np.pi * (freq1 * t + 0.5 * chirp_rate * t**2))

# Add a transient high-frequency component. It must stay below the Nyquist
# frequency 1 / (2 * dt) = 50 Hz: a 50 Hz sine sampled every 0.01 s equals
# sin(pi * n), i.e. zero at every sample, and would not appear at all.
transient_freq = 20  # Hz
transient_start, transient_end = 300, 400  # Sample indices
signal[transient_start:transient_end] += 2 * np.sin(
    2 * np.pi * transient_freq * t[transient_start:transient_end])

# Add noise
noise_level = 0.3
noisy_signal = signal + noise_level * np.random.randn(len(signal))
print(f"Signal length: {len(noisy_signal)}")
print(f"Sampling period: {dt}")
print(f"Total duration: {t[-1]:.2f} seconds")

# Define scales for CWT. This uses linear spacing; logarithmic spacing
# (the commented-out alternative) is often a better choice.
scales = np.arange(1, 128)
# scales = np.logspace(0, 2, 64)  # Alternative: logarithmic spacing

# Perform CWT using different wavelets
wavelets = ['mexh', 'morl', 'cgau8']
cwt_results = {}
for wavelet_name in wavelets:
    coefficients, frequencies = pywt.cwt(noisy_signal, scales, wavelet_name, sampling_period=dt)
    cwt_results[wavelet_name] = (coefficients, frequencies)
    print(f"{wavelet_name}: CWT shape {coefficients.shape}, frequency range {frequencies.min():.3f}-{frequencies.max():.3f} Hz")

# Visualization
fig, axes = plt.subplots(len(wavelets) + 1, 1, figsize=(15, 12))

# Original signal
axes[0].plot(t, signal, 'b-', label='Clean signal', alpha=0.7)
axes[0].plot(t, noisy_signal, 'r-', label='Noisy signal', alpha=0.8)
axes[0].set_title('Original Signal')
axes[0].set_xlabel('Time (s)')
axes[0].set_ylabel('Amplitude')
axes[0].legend()
axes[0].grid(True)

# CWT plots for each wavelet
for i, wavelet_name in enumerate(wavelets):
    coefficients, frequencies = cwt_results[wavelet_name]
    # Create time-frequency plot
    T, F = np.meshgrid(t, frequencies)
    im = axes[i + 1].contourf(T, F, np.abs(coefficients), levels=50, cmap='jet')
    axes[i + 1].set_title(f'CWT Scalogram - {wavelet_name.upper()}')
    axes[i + 1].set_xlabel('Time (s)')
    axes[i + 1].set_ylabel('Frequency (Hz)')
    plt.colorbar(im, ax=axes[i + 1], label='|CWT Coefficients|')

plt.tight_layout()
plt.show()
# Advanced analysis: Ridge detection for instantaneous frequency
def detect_ridges(coefficients, frequencies, threshold_factor=0.5):
    """Estimate the instantaneous frequency from the ridge of a CWT scalogram.

    For every time column, rows whose magnitude is strictly greater than both
    neighbouring rows and than ``threshold_factor * max(|coefficients|)`` are
    candidate ridge points; the strongest one is kept. Among equally strong
    candidates the lowest row index wins. The first and last rows are never
    candidates because they lack a neighbour on one side.

    Parameters:
    - coefficients: 2D CWT coefficients (real or complex); rows correspond to
      entries of ``frequencies``, columns to time samples
    - frequencies: 1D array with one frequency per row of ``coefficients``
    - threshold_factor: Fraction of the global maximum magnitude that a peak
      must exceed

    Returns:
        1D float array with one entry per column: the frequency of the
        strongest peak, or NaN where the column has no qualifying peak.
    """
    abs_coeffs = np.abs(np.asarray(coefficients))
    frequencies = np.asarray(frequencies)
    threshold = threshold_factor * np.max(abs_coeffs)
    n_rows, n_cols = abs_coeffs.shape
    if n_rows < 3:
        # No row has two neighbours, so no interior local maximum can exist.
        return np.full(n_cols, np.nan)

    # Compare every interior row with the rows directly above and below it.
    interior = abs_coeffs[1:-1]
    is_peak = ((interior > abs_coeffs[:-2])
               & (interior > abs_coeffs[2:])
               & (interior > threshold))

    # Strongest qualifying peak per column. argmax returns the first maximum,
    # i.e. the lowest row among equal peaks; +1 maps back to full-array rows.
    best_row = np.argmax(np.where(is_peak, interior, -np.inf), axis=0) + 1
    return np.where(is_peak.any(axis=0), frequencies[best_row], np.nan)
# Estimate the instantaneous frequency from the ridge of the Morlet scalogram
coefficients, frequencies = cwt_results['morl']
inst_freq = detect_ridges(coefficients, frequencies)
# True instantaneous frequency for the chirp (reference curve)
true_inst_freq = freq1 + (freq2 - freq1) * t / max(t)

ridge_fig, (signal_ax, freq_ax) = plt.subplots(2, 1, figsize=(12, 6))

# Top panel: the analysed signal
signal_ax.plot(t, noisy_signal)
signal_ax.set_title('Signal')
signal_ax.set_ylabel('Amplitude')

# Bottom panel: estimated vs. true instantaneous frequency
freq_ax.plot(t, inst_freq, 'r-', linewidth=2, label='Estimated Instantaneous Frequency')
freq_ax.plot(t, true_inst_freq, 'b--', linewidth=2, label='True Instantaneous Frequency')
freq_ax.set_xlabel('Time (s)')
freq_ax.set_ylabel('Frequency (Hz)')
freq_ax.legend()
freq_ax.set_title('Instantaneous Frequency Estimation')
freq_ax.grid(True)

plt.tight_layout()
plt.show()
# CWT-based denoising using thresholding
def cwt_denoise(signal, scales, wavelet, threshold_factor=0.1):
    """Denoise a signal by hard-thresholding its CWT coefficients.

    Coefficients whose magnitude does not exceed
    ``threshold_factor * max(|coefficients|)`` are zeroed. The signal is then
    rebuilt as the average over scales of ``Re(coefficient) / scale``.

    Note: this is a simple scale-weighted sum, not an exact inverse CWT (an
    exact inverse needs the wavelet's admissibility constant), so the
    amplitude of the result is only approximate.

    Parameters:
    - signal: Input signal, transformed along its last axis; integer dtypes
      are accepted
    - scales: 1D sequence of CWT scales
    - wavelet: Continuous wavelet name or object accepted by ``pywt.cwt``
    - threshold_factor: Fraction of the largest coefficient magnitude at or
      below which coefficients are discarded

    Returns:
        Floating-point array with the shape of ``signal``.
    """
    coefficients, _ = pywt.cwt(signal, scales, wavelet)

    # Hard thresholding: keep only coefficients strictly above the threshold.
    magnitude = np.abs(coefficients)
    threshold = threshold_factor * magnitude.max()
    kept = np.where(magnitude > threshold, coefficients, 0)

    # Broadcast 1/scale over every non-scale axis, then average over scales.
    # Computing the result in floating point (instead of accumulating into
    # np.zeros_like(signal)) avoids a casting error for integer input.
    scale_column = np.asarray(scales, dtype=float).reshape((-1,) + (1,) * (kept.ndim - 1))
    return np.mean(np.real(kept) / scale_column, axis=0)
# Denoise with the 32 finest scales and a 20% hard threshold
denoised_cwt = cwt_denoise(noisy_signal, scales[:32], 'morl', threshold_factor=0.2)

denoise_fig, denoise_ax = plt.subplots(figsize=(12, 6))
denoise_ax.plot(t, signal, 'b-', label='Original clean signal', linewidth=2)
denoise_ax.plot(t, noisy_signal, 'r-', alpha=0.7, label='Noisy signal')
denoise_ax.plot(t, denoised_cwt, 'g-', label='CWT denoised', linewidth=2)
denoise_ax.set(xlabel='Time (s)', ylabel='Amplitude', title='CWT-based Denoising')
denoise_ax.legend()
denoise_ax.grid(True)
plt.show()
# SNR calculation
def calculate_snr(clean, noisy):
    """Return the signal-to-noise ratio of ``noisy`` relative to ``clean``, in dB.

    SNR = 10 * log10(mean(clean**2) / mean((noisy - clean)**2))

    Parameters:
    - clean: Reference signal (array_like)
    - noisy: Signal to evaluate; must have the same shape as ``clean``

    Returns:
        SNR in decibels; ``inf`` when ``noisy`` equals ``clean`` exactly
        (zero noise power) and ``nan`` when both powers are zero.

    Raises:
        ValueError: if ``clean`` and ``noisy`` have different shapes.
    """
    # Work in float64: squaring integer samples (e.g. int16 audio) in their
    # own dtype overflows silently, and plain lists do not support ``**``.
    clean = np.asarray(clean, dtype=np.float64)
    noisy = np.asarray(noisy, dtype=np.float64)
    if clean.shape != noisy.shape:
        # Broadcasting would otherwise pair up unrelated samples silently.
        raise ValueError(
            f"clean and noisy must have the same shape, got {clean.shape} and {noisy.shape}")
    signal_power = np.mean(clean ** 2)
    noise_power = np.mean((noisy - clean) ** 2)
    # Zero noise power is a legitimate outcome (perfect reconstruction):
    # yield inf (or nan for 0/0) without a divide-by-zero RuntimeWarning.
    with np.errstate(divide='ignore', invalid='ignore'):
        return 10 * np.log10(signal_power / noise_power)
# Compare SNR before and after denoising
original_snr = calculate_snr(signal, noisy_signal)
denoised_snr = calculate_snr(signal, denoised_cwt)
snr_gain = denoised_snr - original_snr
print(f"Original SNR: {original_snr:.2f} dB")
print(f"Denoised SNR: {denoised_snr:.2f} dB")
print(f"SNR improvement: {snr_gain:.2f} dB")

# Utility functions for converting between CWT scales and frequencies.
def scale2frequency(wavelet, scale: float, precision: int = 8) -> float:
    """
    Map a CWT scale to the normalized frequency it corresponds to.

    Parameters:
    - wavelet: The continuous wavelet (name or object) the scale refers to
    - scale: CWT scale to convert
    - precision: Precision used when approximating the wavelet function

    Returns:
    Frequency in cycles per sample (i.e. for a sampling frequency of 1.0);
    divide it by the sampling period to express it in Hz.
    """
def frequency2scale(wavelet, freq: float, precision: int = 8) -> float:
    """
    Map a normalized frequency back to the corresponding CWT scale.

    This is the inverse of scale2frequency for the same wavelet.

    Parameters:
    - wavelet: The continuous wavelet (name or object) to use
    - freq: Frequency in cycles per sample (a frequency in Hz multiplied by
      the sampling period)
    - precision: Precision used when approximating the wavelet function

    Returns:
    The CWT scale corresponding to ``freq``
    """
def next_fast_len(n: int) -> int:
    """
    Return the smallest power of two that is at least ``n``.

    Used to choose FFT lengths: padding to a power of two lets the FFT run
    efficiently.

    Parameters:
    - n: Requested size

    Returns:
    Power of two >= n
    """


import pywt
import numpy as np
import matplotlib.pyplot as plt

# Scale-frequency relationship analysis: normalized frequency versus scale
wavelets = ['mexh', 'morl', 'cgau8', 'cmor1.5-1.0']
scales = np.arange(1, 101)

plt.figure(figsize=(12, 8))
for wavelet_name in wavelets:
    try:
        frequencies = [pywt.scale2frequency(wavelet_name, scale) for scale in scales]
        plt.loglog(scales, frequencies, 'o-', label=wavelet_name, markersize=3)

        # Round trip: scale -> frequency, then frequency -> scale
        print(f"{wavelet_name}: Scale 10 -> Frequency {pywt.scale2frequency(wavelet_name, 10):.4f}")
        freq_01 = 0.1
        scale_back = pywt.frequency2scale(wavelet_name, freq_01)
        print(f"{wavelet_name}: Frequency 0.1 -> Scale {scale_back:.4f}")
    except Exception as e:
        print(f"Error with {wavelet_name}: {e}")

plt.xlabel('Scale')
plt.ylabel('Normalized Frequency')
plt.title('Scale-Frequency Relationship for Different Wavelets')
plt.legend()
plt.grid(True, which="both", ls="-", alpha=0.3)
plt.show()
# Practical example: design a CWT analysis for a specific frequency range
target_freq_range = [1, 50]  # Hz
sampling_freq = 1000  # Hz
sampling_period = 1.0 / sampling_freq

# Normalized frequency (cycles per sample) = frequency in Hz * sampling period
norm_freq_range = [freq_hz * sampling_period for freq_hz in target_freq_range]
print(f"Target frequency range: {target_freq_range} Hz")
print(f"Normalized frequency range: {norm_freq_range}")

# Scales matching the edges of the range for the Morlet wavelet
wavelet = 'morl'
scales_for_range = [pywt.frequency2scale(wavelet, norm_freq) for norm_freq in norm_freq_range]
print(f"Corresponding scales for {wavelet}: {scales_for_range}")

# Logarithmically spaced scales spanning that interval
num_scales = 64
scale_min = min(scales_for_range)
scale_max = max(scales_for_range)
scales_log = np.logspace(np.log10(scale_min), np.log10(scale_max), num_scales)

# Check which frequencies (in Hz) the chosen scales actually cover
freq_coverage = [pywt.scale2frequency(wavelet, s) / sampling_period for s in scales_log]
print(f"Frequency coverage: {min(freq_coverage):.2f} - {max(freq_coverage):.2f} Hz")

# Synthetic two-tone test signal inside the target range: 5 Hz + 25 Hz
t = np.arange(0, 2, sampling_period)
test_signal = np.sin(2 * np.pi * 5 * t) + np.sin(2 * np.pi * 25 * t)
coefficients, frequencies = pywt.cwt(test_signal, scales_log, wavelet, sampling_period=sampling_period)

# Plot the scalogram restricted to the target band
plt.figure(figsize=(12, 6))
T, F = np.meshgrid(t, frequencies)
plt.contourf(T, F, np.abs(coefficients), levels=50, cmap='jet')
plt.colorbar(label='|CWT Coefficients|')
plt.xlabel('Time (s)')
plt.ylabel('Frequency (Hz)')
plt.title('CWT Analysis in Target Frequency Range')
plt.ylim(target_freq_range)
plt.show()

# Working with complex-valued continuous wavelets for phase and amplitude analysis.
import pywt
import numpy as np
import matplotlib.pyplot as plt

# Create test signal with amplitude and frequency modulation
t = np.linspace(0, 4, 1000)
dt = t[1] - t[0]

# Amplitude modulated signal
carrier_freq = 10  # Hz
mod_freq = 1  # Hz
am_signal = (1 + 0.5 * np.sin(2 * np.pi * mod_freq * t)) * np.sin(2 * np.pi * carrier_freq * t)

# Frequency modulated signal with instantaneous frequency
# carrier_freq + freq_deviation * sin(2*pi*mod_freq*t). The modulation has to
# be applied to the phase (the integral of the frequency); multiplying a
# modulated frequency by t would make the deviation grow without bound in t.
freq_deviation = 5  # Hz
fm_signal = np.sin(2 * np.pi * carrier_freq * t
                   - (freq_deviation / mod_freq) * np.cos(2 * np.pi * mod_freq * t))

# Combine signals in different time windows
signal = np.concatenate([am_signal[:250], fm_signal[250:500],
                         am_signal[500:750], fm_signal[750:]])
print(f"Signal length: {len(signal)}")

# CWT with a complex Morlet wavelet, 'cmor<bandwidth>-<center_frequency>'.
# 'morl' is the real-valued Morlet in PyWavelets: its coefficients are real,
# so their phase carries no usable information.
cwt_wavelet = 'cmor1.5-1.0'
scales = np.arange(1, 64)
coefficients, frequencies = pywt.cwt(signal, scales, cwt_wavelet, sampling_period=dt)
print(f"CWT coefficients are complex: {np.iscomplexobj(coefficients)}")
print(f"CWT shape: {coefficients.shape}")

# Extract amplitude and phase
amplitude = np.abs(coefficients)
phase = np.angle(coefficients)
instantaneous_phase = np.unwrap(phase, axis=1)

# Plot complex CWT analysis
fig, axes = plt.subplots(4, 1, figsize=(15, 12))

# Original signal
axes[0].plot(t, signal)
axes[0].set_title('Test Signal (AM + FM)')
axes[0].set_ylabel('Amplitude')
axes[0].grid(True)

# Amplitude scalogram
T, F = np.meshgrid(t, frequencies)
im1 = axes[1].contourf(T, F, amplitude, levels=50, cmap='jet')
axes[1].set_title('CWT Amplitude Scalogram')
axes[1].set_ylabel('Frequency (Hz)')
plt.colorbar(im1, ax=axes[1], label='Amplitude')

# Phase scalogram
im2 = axes[2].contourf(T, F, phase, levels=50, cmap='hsv')
axes[2].set_title('CWT Phase Scalogram')
axes[2].set_ylabel('Frequency (Hz)')
plt.colorbar(im2, ax=axes[2], label='Phase (rad)')

# Instantaneous frequency: time derivative of the unwrapped phase / (2*pi),
# taken on the scale whose frequency is closest to the carrier.
carrier_idx = np.argmin(np.abs(frequencies - carrier_freq))
inst_freq_approx = np.diff(instantaneous_phase[carrier_idx, :]) / (2 * np.pi * dt)
axes[3].plot(t[:-1], inst_freq_approx, 'r-', linewidth=2, label='Estimated Inst. Freq.')
axes[3].axhline(y=carrier_freq, color='b', linestyle='--', label=f'Carrier ({carrier_freq} Hz)')
axes[3].set_title('Instantaneous Frequency Estimation')
axes[3].set_xlabel('Time (s)')
axes[3].set_ylabel('Frequency (Hz)')
axes[3].legend()
axes[3].grid(True)

plt.tight_layout()
plt.show()
# Amplitude and phase tracking at a specific frequency (the carrier)
target_freq = carrier_freq
freq_idx = np.abs(frequencies - target_freq).argmin()
amplitude_track = amplitude[freq_idx]
phase_track = phase[freq_idx]

track_fig, (sig_ax, amp_ax, phase_ax) = plt.subplots(3, 1, figsize=(12, 8))

sig_ax.plot(t, signal)
sig_ax.set_title('Original Signal')
sig_ax.set_ylabel('Amplitude')
sig_ax.grid(True)

amp_ax.plot(t, amplitude_track, 'r-', linewidth=2)
amp_ax.set_title(f'Amplitude Tracking at {target_freq} Hz')
amp_ax.set_ylabel('CWT Amplitude')
amp_ax.grid(True)

phase_ax.plot(t, phase_track, 'g-', linewidth=2)
phase_ax.set_title(f'Phase Tracking at {target_freq} Hz')
phase_ax.set_xlabel('Time (s)')
phase_ax.set_ylabel('Phase (rad)')
phase_ax.grid(True)

plt.tight_layout()
plt.show()
# Inspect the properties of several continuous wavelets
wavelets_to_test = ['mexh', 'morl', 'cgau8', 'cmor1.5-1.0']


def _format_optional(value, spec='.4f'):
    """Format ``value`` with ``spec``, or return 'n/a' when it is None."""
    return 'n/a' if value is None else format(value, spec)


for wavelet_name in wavelets_to_test:
    try:
        wavelet_obj = pywt.ContinuousWavelet(wavelet_name)
        print(f"\n{wavelet_name.upper()}:")
        print(f" Complex CWT: {wavelet_obj.complex_cwt}")
        # center_frequency / bandwidth_frequency are parameters of the
        # parametrised families (e.g. cmor) and are None for wavelets such as
        # mexh or morl; formatting None with ':.4f' would raise TypeError.
        print(f" Center frequency: {_format_optional(wavelet_obj.center_frequency)}")
        print(f" Bandwidth frequency: {_format_optional(wavelet_obj.bandwidth_frequency)}")
        print(f" Support: [{wavelet_obj.lower_bound:.2f}, {wavelet_obj.upper_bound:.2f}]")
        # Show the value range of the sampled wavelet function
        psi, x = wavelet_obj.wavefun()
        if wavelet_obj.complex_cwt:
            print(f" Real part range: [{np.min(np.real(psi)):.4f}, {np.max(np.real(psi)):.4f}]")
            print(f" Imag part range: [{np.min(np.imag(psi)):.4f}, {np.max(np.imag(psi)):.4f}]")
        else:
            print(f" Value range: [{np.min(psi):.4f}, {np.max(psi):.4f}]")
    except Exception as e:
        # Best effort: report and continue with the remaining wavelets.
        print(f"Error with {wavelet_name}: {e}")

# CWT result format
# Typing helpers for the aliases below
from typing import Literal, Tuple, Union

import numpy as np
import numpy.typing as npt

# (coefficients, frequencies) as returned by cwt
CWTResult = Tuple[np.ndarray, np.ndarray]

# CWT coefficients are 2D for 1D input: (len(scales), len(data)).
# They are real or complex depending on the wavelet. ndarray's first type
# parameter is the shape, so the dtype is expressed through numpy.typing.NDArray.
CWTCoefficients = Union[npt.NDArray[np.float64], npt.NDArray[np.complex128]]

# Scales specification
Scales = Union[np.ndarray, list, range]

# CWT methods
CWTMethod = Literal['conv', 'fft']

# Common continuous wavelets. The parametrised families are normally given
# with their parameters, e.g. 'cmor1.5-1.0' (cmor<bandwidth>-<center_frequency>);
# such names are plain str and are not covered by this Literal.
ContinuousWaveletName = Literal[
    'mexh',  # Mexican hat
    'morl',  # Morlet (real-valued)
    'cgau1', 'cgau2', 'cgau3', 'cgau4', 'cgau5', 'cgau6', 'cgau7', 'cgau8',  # Complex Gaussian
    'shan',  # Shannon
    'fbsp',  # Frequency B-Spline
    'cmor',  # Complex Morlet (cmor<bandwidth>-<center_frequency>)
    'gaus1', 'gaus2', 'gaus3', 'gaus4', 'gaus5', 'gaus6', 'gaus7', 'gaus8',  # Gaussian derivatives
]

# Install with Tessl CLI
npx tessl i tessl/pypi-pywavelets