diff --git a/test/Filtering/Test_FirFilter.cpp b/test/Filtering/Test_FirFilter.cpp index a6aa828..bb37936 100644 --- a/test/Filtering/Test_FirFilter.cpp +++ b/test/Filtering/Test_FirFilter.cpp @@ -24,14 +24,42 @@ template bool IsSymmetric(const SignalT& signal) { auto beg = signal.begin(); auto end = signal.rbegin(); - while (beg != end.base()) { - if (*beg != *end) { + while (beg <= end.base()) { + if (*beg != Approx(*end).margin(1e-7f)) { return false; } + ++beg; + ++end; } return true; } +template +bool IsAntiSymmetric(const SignalT& signal) { + auto beg = signal.begin(); + auto end = signal.rbegin(); + while (beg <= end.base()) { + if (*beg != Approx(-*end).margin(1e-7f)) { + return false; + } + ++beg; + ++end; + } + return true; +} + +template +auto MeasureResponse(size_t sampleRate, float frequency, const SignalT& filter) { + const float period = 1.0f / frequency; + const float length = 25.f * period; + auto testSignal = GenTestSignal(sampleRate, frequency, length); + testSignal *= BlackmanWindow(testSignal.Size()); + const auto filteredSignal = Convolution(testSignal, filter, convolution::full); + const auto rmsTest = std::sqrt(SumSquare(testSignal)); + const auto rmsFiltered = std::sqrt(SumSquare(filteredSignal)); + return rmsFiltered / rmsTest; +} + //------------------------------------------------------------------------------ // Tests @@ -46,29 +74,20 @@ TEST_CASE("Windowed Lowpass", "[FirFilter]") { static constexpr float cutoff = 3800.f; const auto normalizedCutoff = NormalizedFrequency(cutoff, sampleRate); - const auto impulse1 = FirFilter(numTaps, Lowpass(normalizedCutoff), Windowed(windows::hamming)); const auto impulse2 = FirFilter(numTaps, Lowpass(normalizedCutoff), Windowed(windows::hamming.operator()(numTaps))); + REQUIRE(IsSymmetric(impulse1)); REQUIRE(Sum(impulse1) == Approx(1)); REQUIRE(impulse1.Size() == numTaps); REQUIRE(impulse2.Size() == numTaps); REQUIRE(Max(Abs(impulse1 - impulse2)) < 1e-4f); - // Generate two signals just above and just below the cutoff and see their attenuation. - const auto passSignal = GenTestSignal(sampleRate, cutoff * 0.85f); - const auto rejectSignal = GenTestSignal(sampleRate, cutoff * 1.15f); - - const auto filteredPassSignal = Convolution(passSignal, impulse1, convolution::full); - const auto filteredRejectSignal = Convolution(rejectSignal, impulse1, convolution::full); - - const float energyPass = SumSquare(passSignal); - const float energyReject = SumSquare(rejectSignal); - const float energyFilteredPass = SumSquare(filteredPassSignal); - const float energyFilteredReject = SumSquare(filteredRejectSignal); - - REQUIRE(energyFilteredPass / energyPass > 0.95f); - REQUIRE(energyFilteredPass / energyPass < 1.05f); - REQUIRE(energyFilteredReject / energyReject < 0.05f); + const float passResponse = MeasureResponse(sampleRate, cutoff * 0.85f, impulse1); + const float stopResponse = MeasureResponse(sampleRate, cutoff * 1.15f, impulse1); + + REQUIRE(passResponse > 0.95f); + REQUIRE(passResponse < 1.05f); + REQUIRE(stopResponse < 0.05f); } @@ -92,16 +111,14 @@ TEST_CASE("Windowed arbitrary filter", "[FirFilter]") { const auto impulse1 = FirFilter(numTaps, Arbitrary(response), Windowed(windows::hamming)); const auto impulse2 = FirFilter(numTaps, Arbitrary(response), Windowed(windows::hamming.operator()(numTaps))); + REQUIRE(IsSymmetric(impulse1)); REQUIRE(impulse1.Size() == numTaps); REQUIRE(impulse2.Size() == numTaps); REQUIRE(Max(Abs(impulse1 - impulse2)) < 1e-4f); for (size_t i = 0; i < amplitudes.size(); ++i) { - const auto signal = GenTestSignal(sampleRate, frequencies[i] * sampleRate / 2.0f); - const auto filtered = Convolution(signal, impulse1, convolution::full); - const float energy = std::sqrt(SumSquare(signal)); - const float filteredEnergy = std::sqrt(SumSquare(filtered)); - REQUIRE(filteredEnergy / energy == Approx(amplitudes[i]).margin(0.05f)); + const auto response = MeasureResponse(sampleRate, frequencies[i] * sampleRate / 2.0f, impulse1); + REQUIRE(response == Approx(amplitudes[i]).margin(0.05f)); } } @@ -112,24 +129,16 @@ TEST_CASE("Highpass", "[FirFilter]") { const auto normalizedCutoff = NormalizedFrequency(cutoff, sampleRate); const auto impulse = FirFilter(numTaps, Highpass(normalizedCutoff), Windowed(windows::hamming)); + REQUIRE(IsSymmetric(impulse)); REQUIRE(Sum(impulse) < 1e-4f); REQUIRE(impulse.Size() == numTaps); - // Generate two signals just above and just below the cutoff and see their attenuation. - const auto passSignal = GenTestSignal(sampleRate, cutoff * 1.15f); - const auto rejectSignal = GenTestSignal(sampleRate, cutoff * 0.85f); - - const auto filteredPassSignal = Convolution(passSignal, impulse, convolution::full); - const auto filteredRejectSignal = Convolution(rejectSignal, impulse, convolution::full); - - const float energyPass = SumSquare(passSignal); - const float energyReject = SumSquare(rejectSignal); - const float energyFilteredPass = SumSquare(filteredPassSignal); - const float energyFilteredReject = SumSquare(filteredRejectSignal); + const float passResponse = MeasureResponse(sampleRate, cutoff * 1.15f, impulse); + const float stopResponse = MeasureResponse(sampleRate, cutoff * 0.85f, impulse); - REQUIRE(energyFilteredPass / energyPass > 0.95f); - REQUIRE(energyFilteredPass / energyPass < 1.05f); - REQUIRE(energyFilteredReject / energyReject < 0.05f); + REQUIRE(passResponse > 0.95f); + REQUIRE(passResponse < 1.05f); + REQUIRE(stopResponse < 0.05f); } @@ -141,38 +150,21 @@ TEST_CASE("Bandpass", "[FirFilter]") { const auto normalizedHigh = NormalizedFrequency(bandHigh, sampleRate); const auto impulse = FirFilter(numTaps, Bandpass(normalizedLow, normalizedHigh), Windowed(windows::hamming)); + REQUIRE(IsSymmetric(impulse)); REQUIRE(Sum(impulse) < 1e-3f); REQUIRE(impulse.Size() == numTaps); - auto extended = impulse; - extended.Resize(44100, 0.0f); - const auto spectrum = Abs(FourierTransform(extended)); - - const auto passSignal1 = GenTestSignal(sampleRate, bandLow * 1.1f); - const auto passSignal2 = GenTestSignal(sampleRate, bandHigh * 0.9f); - const auto rejectSignal1 = GenTestSignal(sampleRate, bandLow * 0.9f); - const auto rejectSignal2 = GenTestSignal(sampleRate, bandHigh * 1.1f); - - const auto filteredPassSignal1 = Convolution(passSignal1, impulse, convolution::full); - const auto filteredPassSignal2 = Convolution(passSignal2, impulse, convolution::full); - const auto filteredRejectSignal1 = Convolution(rejectSignal1, impulse, convolution::full); - const auto filteredRejectSignal2 = Convolution(rejectSignal2, impulse, convolution::full); - - const float energyPass1 = SumSquare(passSignal1); - const float energyPass2 = SumSquare(passSignal2); - const float energyReject1 = SumSquare(rejectSignal1); - const float energyReject2 = SumSquare(rejectSignal2); - const float energyFilteredPass1 = SumSquare(filteredPassSignal1); - const float energyFilteredPass2 = SumSquare(filteredPassSignal2); - const float energyFilteredReject1 = SumSquare(filteredRejectSignal1); - const float energyFilteredReject2 = SumSquare(filteredRejectSignal2); - - REQUIRE(energyFilteredPass1 / energyPass1 > 0.95f); - REQUIRE(energyFilteredPass1 / energyPass1 < 1.05f); - REQUIRE(energyFilteredReject1 / energyReject1 < 0.05f); - REQUIRE(energyFilteredPass2 / energyPass2 > 0.95f); - REQUIRE(energyFilteredPass2 / energyPass2 < 1.05f); - REQUIRE(energyFilteredReject2 / energyReject2 < 0.05f); + const float lowStopResponse = MeasureResponse(sampleRate, bandLow * 0.9f, impulse); + const float lowPassResponse = MeasureResponse(sampleRate, bandLow * 1.1f, impulse); + const float highPassResponse = MeasureResponse(sampleRate, bandHigh * 0.9f, impulse); + const float highStopResponse = MeasureResponse(sampleRate, bandHigh * 1.1f, impulse); + + REQUIRE(highPassResponse > 0.95f); + REQUIRE(highPassResponse < 1.05f); + REQUIRE(highStopResponse < 0.05f); + REQUIRE(lowPassResponse > 0.95f); + REQUIRE(lowPassResponse < 1.05f); + REQUIRE(lowStopResponse < 0.05f); } @@ -184,38 +176,21 @@ TEST_CASE("Bandstop", "[FirFilter]") { const auto normalizedHigh = NormalizedFrequency(bandHigh, sampleRate); const auto impulse = FirFilter(numTaps, Bandstop(normalizedLow, normalizedHigh), Windowed(windows::hamming)); + REQUIRE(IsSymmetric(impulse)); REQUIRE(Sum(impulse) == Approx(1).epsilon(0.005f)); REQUIRE(impulse.Size() == numTaps); - auto extended = impulse; - extended.Resize(44100, 0.0f); - const auto spectrum = Abs(FourierTransform(extended)); - - const auto rejectSignal1 = GenTestSignal(sampleRate, bandLow * 1.1f); - const auto rejectSignal2 = GenTestSignal(sampleRate, bandHigh * 0.9f); - const auto passSignal1 = GenTestSignal(sampleRate, bandLow * 0.9f); - const auto passSignal2 = GenTestSignal(sampleRate, bandHigh * 1.1f); - - const auto filteredPassSignal1 = Convolution(passSignal1, impulse, convolution::full); - const auto filteredPassSignal2 = Convolution(passSignal2, impulse, convolution::full); - const auto filteredRejectSignal1 = Convolution(rejectSignal1, impulse, convolution::full); - const auto filteredRejectSignal2 = Convolution(rejectSignal2, impulse, convolution::full); - - const float energyPass1 = SumSquare(passSignal1); - const float energyPass2 = SumSquare(passSignal2); - const float energyReject1 = SumSquare(rejectSignal1); - const float energyReject2 = SumSquare(rejectSignal2); - const float energyFilteredPass1 = SumSquare(filteredPassSignal1); - const float energyFilteredPass2 = SumSquare(filteredPassSignal2); - const float energyFilteredReject1 = SumSquare(filteredRejectSignal1); - const float energyFilteredReject2 = SumSquare(filteredRejectSignal2); - - REQUIRE(energyFilteredPass1 / energyPass1 > 0.95f); - REQUIRE(energyFilteredPass1 / energyPass1 < 1.05f); - REQUIRE(energyFilteredReject1 / energyReject1 < 0.05f); - REQUIRE(energyFilteredPass2 / energyPass2 > 0.95f); - REQUIRE(energyFilteredPass2 / energyPass2 < 1.05f); - REQUIRE(energyFilteredReject2 / energyReject2 < 0.05f); + const float lowPassResponse = MeasureResponse(sampleRate, bandLow * 0.9f, impulse); + const float lowStopResponse = MeasureResponse(sampleRate, bandLow * 1.1f, impulse); + const float highStopResponse = MeasureResponse(sampleRate, bandHigh * 0.9f, impulse); + const float highPassResponse = MeasureResponse(sampleRate, bandHigh * 1.1f, impulse); + + REQUIRE(highPassResponse > 0.95f); + REQUIRE(highPassResponse < 1.05f); + REQUIRE(highStopResponse < 0.05f); + REQUIRE(lowPassResponse > 0.95f); + REQUIRE(lowPassResponse < 1.05f); + REQUIRE(lowStopResponse < 0.05f); } @@ -223,6 +198,7 @@ TEST_CASE("Bandstop", "[FirFilter]") { TEST_CASE("Hilbert odd form", "[Hilbert]") { const auto filter = FirFilter(247, Hilbert(), Windowed(windows::hamming)); REQUIRE(filter.Size() == 247); + REQUIRE(IsAntiSymmetric(filter)); const auto nonZeroSamples = Decimate(filter, 2); const auto zeroSamples = Decimate(AsView(filter).SubSignal(1), 2); REQUIRE(Max(zeroSamples) == 0.0f); @@ -236,6 +212,7 @@ TEST_CASE("Hilbert odd form", "[Hilbert]") { TEST_CASE("Hilbert even form", "[Hilbert]") { const auto filter = FirFilter(246, Hilbert(), Windowed(windows::hamming)); REQUIRE(filter.Size() == 246); + REQUIRE(IsAntiSymmetric(filter)); REQUIRE(Min(Abs(filter)) > 0.0f); const auto firstHalf = AsView(filter).SubSignal(0, filter.Size() / 2); const auto secondHalf = AsView(filter).SubSignal(filter.Size() / 2); @@ -246,6 +223,7 @@ TEST_CASE("Hilbert even form", "[Hilbert]") { TEST_CASE("Hilbert odd small form", "[Hilbert]") { const auto filter = FirFilter(19, Hilbert(), Windowed(windows::hamming)); REQUIRE(filter.Size() == 19); + REQUIRE(IsAntiSymmetric(filter)); const auto nonZeroSamples = Decimate(filter, 2); const auto zeroSamples = Decimate(AsView(filter).SubSignal(1), 2); REQUIRE(Max(zeroSamples) == 0.0f); @@ -259,6 +237,7 @@ TEST_CASE("Hilbert odd small form", "[Hilbert]") { TEST_CASE("Hilbert even small form", "[Hilbert]") { const auto filter = FirFilter(10, Hilbert(), Windowed(windows::hamming)); REQUIRE(filter.Size() == 10); + REQUIRE(IsAntiSymmetric(filter)); REQUIRE(Min(Abs(filter)) > 0.0f); const auto firstHalf = AsView(filter).SubSignal(0, filter.Size() / 2); const auto secondHalf = AsView(filter).SubSignal(filter.Size() / 2);