diff --git a/include/dspbb/Filtering/WindowFunctions.hpp b/include/dspbb/Filtering/WindowFunctions.hpp index 1580295..10fff7d 100644 --- a/include/dspbb/Filtering/WindowFunctions.hpp +++ b/include/dspbb/Filtering/WindowFunctions.hpp @@ -49,7 +49,7 @@ void HammingWindow(SignalView out) { std::iota(out.begin(), out.end(), U(0.0)); out *= preSize; - Cos(out); + Cos(out, out); out *= U(-0.46); out += U(0.54); } diff --git a/include/dspbb/Math/Functions.hpp b/include/dspbb/Math/Functions.hpp index f1598b7..1455049 100644 --- a/include/dspbb/Math/Functions.hpp +++ b/include/dspbb/Math/Functions.hpp @@ -1,165 +1,135 @@ #pragma once -#include "../Primitives/Signal.hpp" -#include "../Primitives/SignalView.hpp" -#include "../Utility/Algorithm.hpp" - #include -#include +#include +#include +#include +#include +#include #include +#include namespace dspbb { -//------------------------------------------------------------------------------ -// Complex number functions -//------------------------------------------------------------------------------ +#define DSPBB_IMPL_FUNCTION_2_PARAM(NAME, FUNC) \ + template && is_same_domain_v, std::decay_t>, int> = 0> \ + auto NAME(SignalT&& out, const SignalU& in) { \ + return UnaryOperationVectorized(out.Data(), in.Data(), out.Length(), [](auto v) { return math_functions::FUNC(v); }); \ + } +#define DSPBB_IMPL_FUNCTION_1_PARAM(NAME) \ + template >, int> = 0> \ + auto NAME(const SignalT& signal) { \ + SignalT r(signal.Size()); \ + NAME(r, signal); \ + return r; \ + } -template -auto Abs(SignalT&& signal) { - return Apply(std::forward(signal), [](typename std::decay_t::value_type v) { return std::abs(v); }); -} +#define DSPBB_IMPL_FUNCTION(NAME, FUNC) \ + DSPBB_IMPL_FUNCTION_2_PARAM(NAME, FUNC) \ + DSPBB_IMPL_FUNCTION_1_PARAM(NAME) -template ::value_type>, int> = 0> -auto Arg(SignalT&& signal) { - return Apply(std::forward(signal), [](typename std::decay_t::value_type v) { return std::arg(v); }); -} - -template ::value_type>, int> = 0> -auto Real(SignalT&& signal) { - return Apply(std::forward(signal), [](typename std::decay_t::value_type v) { return v; }); -} -template ::value_type>, int> = 0> -auto Real(SignalT&& signal) { - return Apply(std::forward(signal), [](typename std::decay_t::value_type v) { return std::real(v); }); -} - -template ::value_type>, int> = 0> -auto Imag(SignalT&& signal) { - return Apply(std::forward(signal), [](typename std::decay_t::value_type v) { return std::imag(v); }); -} +//------------------------------------------------------------------------------ +// Complex number functions +//------------------------------------------------------------------------------ +#define DSPBB_IMPL_COMPLEX_FUNCTION_2_PARAM(NAME, VECOP, OP, FUNC) \ + template && is_same_domain_v, std::decay_t>, int> = 0> \ + auto NAME(SignalT&& out, const SignalU& in, int, std::complex) { \ + \ + return UnaryOperationVectorized(out.Data(), \ + in.Data(), \ + out.Length(), \ + complex_functions::VECOP::stride, \ + complex_functions::VECOP{}, \ + complex_functions::OP{}); \ + } \ + \ + template && is_same_domain_v, std::decay_t>, int> = 0> \ + auto NAME(SignalT&& out, const SignalU& in, int, ...) { \ + \ + return UnaryOperationVectorized(out.Data(), in.Data(), out.Length(), [](auto v) { return math_functions::FUNC(v); }); \ + } \ + \ + template && is_same_domain_v, std::decay_t>, int> = 0> \ + auto NAME(SignalT&& out, const SignalU& in) { \ + return NAME(std::forward(out), in, 0, typename signal_traits::type{}); \ + } + +#define DSPBB_IMPL_COMPLEX_FUNCTION_1_PARAM(NAME, FUNC) \ + template >, int> = 0> \ + auto NAME(const SignalT& signal) { \ + using R = decltype(std::FUNC(std::declval::type>())); \ + Signal::domain> r(signal.Size()); \ + NAME(r, signal); \ + return r; \ + } + +#define DSPBB_IMPL_COMPLEX_FUNCTION(NAME, VECOP, OP, FUNC) \ + DSPBB_IMPL_COMPLEX_FUNCTION_2_PARAM(NAME, VECOP, OP, FUNC) \ + DSPBB_IMPL_COMPLEX_FUNCTION_1_PARAM(NAME, FUNC) + +DSPBB_IMPL_COMPLEX_FUNCTION(Abs, AbsVec, Abs, abs) +DSPBB_IMPL_COMPLEX_FUNCTION(Arg, ArgVec, Arg, arg) +DSPBB_IMPL_COMPLEX_FUNCTION(Real, RealVec, Real, real) +DSPBB_IMPL_COMPLEX_FUNCTION(Imag, ImagVec, Imag, imag) //------------------------------------------------------------------------------ // Exponential functions //------------------------------------------------------------------------------ -template -auto Log(SignalT&& signal) { - return Apply(std::forward(signal), [](typename std::decay_t::value_type v) { return std::log(v); }); -} - -template -auto Log2(SignalT&& signal) { - return Apply(std::forward(signal), [](typename std::decay_t::value_type v) { return std::log2(v); }); -} - -template -auto Log10(SignalT&& signal) { - return Apply(std::forward(signal), [](typename std::decay_t::value_type v) { return std::log10(v); }); -} - -template -auto Exp(SignalT&& signal) { - return Apply(std::forward(signal), [](typename std::decay_t::value_type v) { return std::exp(v); }); -} - +DSPBB_IMPL_FUNCTION(Log, log) +DSPBB_IMPL_FUNCTION(Log2, log2) +DSPBB_IMPL_FUNCTION(Log10, log10) +DSPBB_IMPL_FUNCTION(Exp, exp) //------------------------------------------------------------------------------ // Polynomial functions //------------------------------------------------------------------------------ -template -auto Pow(SignalT&& signal, typename std::decay_t::value_type power) { - return Apply( - std::forward(signal), - [](typename std::decay_t::value_type v, typename std::decay_t::value_type power) { return std::pow(v, power); }, power); -} - -template -auto Sqrt(SignalT&& signal) { - return Apply(std::forward(signal), [](typename std::decay_t::value_type v) { return std::sqrt(v); }); +template && is_same_domain_v, std::decay_t>, int> = 0> +auto Pow(SignalT&& out, const SignalU& in, std::remove_const_t::value_type> power) { + return UnaryOperationVectorized(out.Data(), in.Data(), out.Length(), [power](auto v) { return math_functions::pow(v, power); }); } - -template -auto Cbrt(SignalT&& signal) { - return Apply(std::forward(signal), [](typename std::decay_t::value_type v) { return std::cbrt(v); }); +template >, int> = 0> +auto Pow(const SignalT& signal, std::remove_const_t::value_type> power) { + SignalT r(signal.Size()); + Pow(r, signal, power); + return r; } +DSPBB_IMPL_FUNCTION(Sqrt, sqrt) +DSPBB_IMPL_FUNCTION(Cbrt, cbrt) //------------------------------------------------------------------------------ // Trigonometric functions //------------------------------------------------------------------------------ -template -auto Sin(SignalT&& signal) { - return Apply(std::forward(signal), [](typename std::decay_t::value_type v) { return std::sin(v); }); -} - -template -auto Cos(SignalT&& signal) { - return Apply(std::forward(signal), [](typename std::decay_t::value_type v) { return std::cos(v); }); -} - -template -auto Tan(SignalT&& signal) { - return Apply(std::forward(signal), [](typename std::decay_t::value_type v) { return std::tan(v); }); -} - -template -auto Asin(SignalT&& signal) { - return Apply(std::forward(signal), [](typename std::decay_t::value_type v) { return std::asin(v); }); -} - -template -auto Acos(SignalT&& signal) { - return Apply(std::forward(signal), [](typename std::decay_t::value_type v) { return std::acos(v); }); -} - -template -auto Atan(SignalT&& signal) { - return Apply(std::forward(signal), [](typename std::decay_t::value_type v) { return std::atan(v); }); -} +DSPBB_IMPL_FUNCTION(Sin, sin) +DSPBB_IMPL_FUNCTION(Cos, cos) +DSPBB_IMPL_FUNCTION(Tan, tan) +DSPBB_IMPL_FUNCTION(Asin, asin) +DSPBB_IMPL_FUNCTION(Acos, acos) +DSPBB_IMPL_FUNCTION(Atan, atan) //------------------------------------------------------------------------------ // Hyperbolic functions //------------------------------------------------------------------------------ -template -auto Sinh(SignalT&& signal) { - return Apply(std::forward(signal), [](typename std::decay_t::value_type v) { return std::sinh(v); }); -} - -template -auto Cosh(SignalT&& signal) { - return Apply(std::forward(signal), [](typename std::decay_t::value_type v) { return std::cosh(v); }); -} - - -template -auto Tanh(SignalT&& signal) { - return Apply(std::forward(signal), [](typename std::decay_t::value_type v) { return std::tanh(v); }); -} - -template -auto Asinh(SignalT&& signal) { - return Apply(std::forward(signal), [](typename std::decay_t::value_type v) { return std::asinh(v); }); -} - -template -auto Acosh(SignalT&& signal) { - return Apply(std::forward(signal), [](typename std::decay_t::value_type v) { return std::acosh(v); }); -} - -template -auto Atanh(SignalT&& signal) { - return Apply(std::forward(signal), [](typename std::decay_t::value_type v) { return std::atanh(v); }); -} +DSPBB_IMPL_FUNCTION(Sinh, sinh) +DSPBB_IMPL_FUNCTION(Cosh, cosh) +DSPBB_IMPL_FUNCTION(Tanh, tanh) +DSPBB_IMPL_FUNCTION(Asinh, asinh) +DSPBB_IMPL_FUNCTION(Acosh, acosh) +DSPBB_IMPL_FUNCTION(Atanh, atanh) } // namespace dspbb \ No newline at end of file diff --git a/include/dspbb/Vectorization/ComplexFunctions.hpp b/include/dspbb/Vectorization/ComplexFunctions.hpp new file mode 100644 index 0000000..498ca8e --- /dev/null +++ b/include/dspbb/Vectorization/ComplexFunctions.hpp @@ -0,0 +1,91 @@ +#pragma once + +#include + +namespace dspbb { +namespace complex_functions { + + template + struct AbsVec { + static constexpr size_t stride = xsimd::simd_traits>::size; + using complex_vector = xsimd::batch, stride>; + using real_vector = xsimd::batch; + void operator()(T* out, const std::complex* in) { + complex_vector vin; + vin.load_unaligned(in); + const real_vector vout = xsimd::abs(vin); + vout.store_unaligned(out); + } + }; + + template + struct Abs { + void operator()(T* out, const std::complex* in) { + *out = std::abs(*in); + } + }; + + + template + struct ArgVec { + static constexpr size_t stride = xsimd::simd_traits>::size; + using complex_vector = xsimd::batch, stride>; + using real_vector = xsimd::batch; + void operator()(T* out, const std::complex* in) { + complex_vector vin; + vin.load_unaligned(in); + const real_vector vout = xsimd::arg(vin); + vout.store_unaligned(out); + } + }; + + template + struct Arg { + void operator()(T* out, const std::complex* in) { + *out = std::arg(*in); + } + }; + + + template + struct RealVec { + static constexpr size_t stride = xsimd::simd_traits>::size; + using complex_vector = xsimd::batch, stride>; + using real_vector = xsimd::batch; + void operator()(T* out, const std::complex* in) { + complex_vector vin; + vin.load_unaligned(in); + const real_vector vout = xsimd::real(vin); + vout.store_unaligned(out); + } + }; + + template + struct Real { + void operator()(T* out, const std::complex* in) { + *out = std::real(*in); + } + }; + + template + struct ImagVec { + static constexpr size_t stride = xsimd::simd_traits>::size; + using complex_vector = xsimd::batch, stride>; + using real_vector = xsimd::batch; + void operator()(T* out, const std::complex* in) { + complex_vector vin; + vin.load_unaligned(in); + const real_vector vout = xsimd::imag(vin); + vout.store_unaligned(out); + } + }; + + template + struct Imag { + void operator()(T* out, const std::complex* in) { + *out = std::imag(*in); + } + }; + +} // namespace complex_functions +} // namespace dspbb \ No newline at end of file diff --git a/include/dspbb/Vectorization/Kernels.hpp b/include/dspbb/Vectorization/Kernels.hpp index 2840ff6..91c887f 100644 --- a/include/dspbb/Vectorization/Kernels.hpp +++ b/include/dspbb/Vectorization/Kernels.hpp @@ -148,5 +148,20 @@ void UnaryOperationVectorized(T* out, T* in, size_t length, Op op) { UnaryOperation(out, in, length - vlength, op); } +template +void UnaryOperationVectorized(R* out, T* in, size_t length, size_t stride, VecOp vop, Op op) { + const size_t vlength = (length / stride) * stride; + + const R* vlast = out + vlength; + const R* last = out + length; + + for (; out < vlast; out += stride, in += stride) { + vop(out, in); + } + for (; out < last; out += 1, in += 1) { + op(out, in); + } +} + } // namespace dspbb diff --git a/include/dspbb/Vectorization/MathFunctions.hpp b/include/dspbb/Vectorization/MathFunctions.hpp index 6fe08fa..6d1d2a6 100644 --- a/include/dspbb/Vectorization/MathFunctions.hpp +++ b/include/dspbb/Vectorization/MathFunctions.hpp @@ -4,36 +4,67 @@ #include namespace dspbb { - - - namespace math_functions { - using std::log; - using std::log2; - using std::log10; - using std::exp; - - using xsimd::log; - using xsimd::log2; - using xsimd::log10; - using xsimd::exp; - - - using std::pow; - using std::sqrt; - using std::cbrt; - - using xsimd::pow; - using xsimd::sqrt; - using xsimd::cbrt; - - - } - - - - - - - - -} \ No newline at end of file +namespace math_functions { + // Exponential + using std::exp; + using std::log; + using std::log10; + using std::log2; + + using xsimd::exp; + using xsimd::log; + using xsimd::log10; + using xsimd::log2; + + // Polynomial + using std::cbrt; + using std::pow; + using std::sqrt; + + using xsimd::cbrt; + using xsimd::pow; + using xsimd::sqrt; + + // Trigonometric + using std::acos; + using std::asin; + using std::atan; + using std::cos; + using std::sin; + using std::tan; + + using xsimd::acos; + using xsimd::asin; + using xsimd::atan; + using xsimd::cos; + using xsimd::sin; + using xsimd::tan; + + // Hyperbolic + using xsimd::acosh; + using xsimd::asinh; + using xsimd::atanh; + using xsimd::cosh; + using xsimd::sinh; + using xsimd::tanh; + + using std::acosh; + using std::asinh; + using std::atanh; + using std::cosh; + using std::sinh; + using std::tanh; + + // Complex + using std::abs; + using std::arg; + using std::real; + using std::imag; + + using xsimd::abs; + using xsimd::arg; + using xsimd::real; + using xsimd::imag; + +} // namespace math_functions +} // namespace dspbb \ No newline at end of file diff --git a/test/Test_Functions.cpp b/test/Test_Functions.cpp index bb90e8f..6e144fe 100644 --- a/test/Test_Functions.cpp +++ b/test/Test_Functions.cpp @@ -14,13 +14,13 @@ auto iden(T arg) { } // namespace std -#define TEST_CASE_FUNCTION_REAL(NAME, FUNC, STDFUNC) \ - TEST_CASE(NAME " real", "[Functions]") { \ - const TimeSignal signal = { 1, 8 }; \ - const auto applied = FUNC(signal); \ - for (size_t i = 0; i < signal.Size(); ++i) { \ - REQUIRE(Approx(applied[i]) == std::STDFUNC(signal[i])); \ - } \ +#define TEST_CASE_FUNCTION_REAL(NAME, FUNC, STDFUNC) \ + TEST_CASE(NAME " real", "[Functions]") { \ + const TimeSignal signal = { 1, 8, 2, 5, 3, 6, 3, 6, 4 }; \ + const auto applied = FUNC(signal); \ + for (size_t i = 0; i < signal.Size(); ++i) { \ + REQUIRE(Approx(applied[i]) == std::STDFUNC(signal[i])); \ + } \ }