-
-
Notifications
You must be signed in to change notification settings - Fork 494
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
Closes gh-676
- Loading branch information
Showing
5 changed files
with
233 additions
and
3 deletions.
There are no files selected for viewing
60 changes: 60 additions & 0 deletions
60
doc/source/pyplots/cwt_wavelet_frequency_bandwidth_demo.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
import pywt | ||
|
||
# plot complex morlet wavelets with different center frequencies and bandwidths | ||
wavelets = [f"cmor{x:.1f}-{y:.1f}" for x in [0.5, 1.5, 2.5] for y in [0.5, 1.0, 1.5]] | ||
fig, axs = plt.subplots(3, 3, figsize=(10, 10), sharex=True, sharey=True) | ||
for ax, wavelet in zip(axs.flatten(), wavelets): | ||
[psi, x] = pywt.ContinuousWavelet(wavelet).wavefun(10) | ||
ax.plot(x, np.real(psi), label="real") | ||
ax.plot(x, np.imag(psi), label="imag") | ||
ax.set_title(wavelet) | ||
ax.set_xlim([-5, 5]) | ||
ax.set_ylim([-0.8, 1]) | ||
ax.legend() | ||
plt.suptitle("Complex Morlet Wavelets with different center frequencies and bandwidths") | ||
plt.show() | ||
|
||
|
||
def gaussian(x, x0, sigma): | ||
return np.exp(-np.power((x - x0) / sigma, 2.0) / 2.0) | ||
|
||
|
||
def make_chirp(t, t0, a): | ||
frequency = (a * (t + t0)) ** 2 | ||
chirp = np.sin(2 * np.pi * frequency * t) | ||
return chirp, frequency | ||
|
||
|
||
def plot_wavelet(time, data, wavelet, title, ax): | ||
widths = np.geomspace(1, 1024, num=75) | ||
cwtmatr, freqs = pywt.cwt( | ||
data, widths, wavelet, sampling_period=np.diff(time).mean() | ||
) | ||
cwtmatr = np.abs(cwtmatr[:-1, :-1]) | ||
pcm = ax.pcolormesh(time, freqs, cwtmatr) | ||
ax.set_yscale("log") | ||
ax.set_xlabel("Time (s)") | ||
ax.set_ylabel("Frequency (Hz)") | ||
ax.set_title(title) | ||
plt.colorbar(pcm, ax=ax) | ||
return ax | ||
|
||
|
||
# generate signal | ||
time = np.linspace(0, 1, 1000) | ||
chirp1, frequency1 = make_chirp(time, 0.2, 9) | ||
chirp2, frequency2 = make_chirp(time, 0.1, 5) | ||
chirp = chirp1 + 0.6 * chirp2 | ||
chirp *= gaussian(time, 0.5, 0.2) | ||
|
||
# perform CWT with different wavelets on same signal and plot results | ||
wavelets = [f"cmor{x:.1f}-{y:.1f}" for x in [0.5, 1.5, 2.5] for y in [0.5, 1.0, 1.5]] | ||
fig, axs = plt.subplots(3, 3, figsize=(10, 10), sharex=True) | ||
for ax, wavelet in zip(axs.flatten(), wavelets): | ||
plot_wavelet(time, chirp, wavelet, wavelet, ax) | ||
plt.tight_layout(rect=[0, 0.03, 1, 0.95]) | ||
plt.suptitle("Scaleograms of the same signal with different wavelets") | ||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
import pywt | ||
|
||
|
||
def gaussian(x, x0, sigma): | ||
return np.exp(-np.power((x - x0) / sigma, 2.0) / 2.0) | ||
|
||
|
||
def make_chirp(t, t0, a): | ||
frequency = (a * (t + t0)) ** 2 | ||
chirp = np.sin(2 * np.pi * frequency * t) | ||
return chirp, frequency | ||
|
||
|
||
# generate signal | ||
time = np.linspace(0, 1, 2000) | ||
chirp1, frequency1 = make_chirp(time, 0.2, 9) | ||
chirp2, frequency2 = make_chirp(time, 0.1, 5) | ||
chirp = chirp1 + 0.6 * chirp2 | ||
chirp *= gaussian(time, 0.5, 0.2) | ||
|
||
# plot signal | ||
fig, axs = plt.subplots(2, 1, sharex=True) | ||
axs[0].plot(time, chirp) | ||
axs[1].plot(time, frequency1) | ||
axs[1].plot(time, frequency2) | ||
axs[1].set_yscale("log") | ||
axs[1].set_xlabel("Time (s)") | ||
axs[0].set_ylabel("Signal") | ||
axs[1].set_ylabel("True frequency (Hz)") | ||
plt.suptitle("Input signal") | ||
plt.show() | ||
|
||
# perform CWT | ||
wavelet = "cmor1.5-1.0" | ||
# logarithmic scale for scales, as suggested by Torrence & Compo: | ||
widths = np.geomspace(1, 1024, num=100) | ||
sampling_period = np.diff(time).mean() | ||
cwtmatr, freqs = pywt.cwt(chirp, widths, wavelet, sampling_period=sampling_period) | ||
# absolute take absolute value of complex result | ||
cwtmatr = np.abs(cwtmatr[:-1, :-1]) | ||
|
||
# plot result using matplotlib's pcolormesh (image with annoted axes) | ||
fig, axs = plt.subplots(2, 1) | ||
pcm = axs[0].pcolormesh(time, freqs, cwtmatr) | ||
axs[0].set_yscale("log") | ||
axs[0].set_xlabel("Time (s)") | ||
axs[0].set_ylabel("Frequency (Hz)") | ||
axs[0].set_title("Continuous Wavelet Transform (Scaleogram)") | ||
fig.colorbar(pcm, ax=axs[0]) | ||
|
||
# plot fourier transform for comparison | ||
from numpy.fft import rfft, rfftfreq | ||
|
||
yf = rfft(chirp) | ||
xf = rfftfreq(len(chirp), sampling_period) | ||
plt.semilogx(xf, np.abs(yf)) | ||
axs[1].set_xlabel("Frequency (Hz)") | ||
axs[1].set_title("Fourier Transform") | ||
plt.tight_layout() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
import pywt | ||
|
||
wavlist = pywt.wavelist(kind="continuous") | ||
cols = 3 | ||
rows = (len(wavlist) + cols - 1) // cols | ||
fig, axs = plt.subplots(rows, cols, figsize=(10, 10), | ||
sharex=True, sharey=True) | ||
for ax, wavelet in zip(axs.flatten(), wavlist): | ||
# A few wavelet families require parameters in the string name | ||
if wavelet in ['cmor', 'shan']: | ||
wavelet += '1-1' | ||
elif wavelet == 'fbsp': | ||
wavelet += '1-1.5-1.0' | ||
|
||
[psi, x] = pywt.ContinuousWavelet(wavelet).wavefun(10) | ||
ax.plot(x, np.real(psi), label="real") | ||
ax.plot(x, np.imag(psi), label="imag") | ||
ax.set_title(wavelet) | ||
ax.set_xlim([-5, 5]) | ||
ax.set_ylim([-0.8, 1]) | ||
|
||
ax.legend(loc="upper right") | ||
plt.suptitle("Available wavelets for CWT") | ||
plt.tight_layout() | ||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters