Skip to content

Commit

Permalink
simplify MPSC to use WaitStrategy (#43)
Browse files Browse the repository at this point in the history
  • Loading branch information
geseq authored Feb 29, 2024
1 parent de77b79 commit 0b965d0
Show file tree
Hide file tree
Showing 6 changed files with 222 additions and 238 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ auto val = c.get();

```cpp
// MPSC
fastchan::MPSC<int, blockingType, chan_size> c;
fastchan::MPSC<int, chan_size> c;
// OR
fastchan::MPSC<int, blockingType, chan_size, fastchan::WaitPause> c;
fastchan::MPSC<int, chan_size, fastchan::PauseWaitStrategy, fastchan::PauseWaitStrategy> c;

c.put(0);
c.put(1);
Expand Down
254 changes: 127 additions & 127 deletions bench/fastchan_bench.cpp

Large diffs are not rendered by default.

14 changes: 0 additions & 14 deletions include/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,6 @@ inline void cpu_pause() {
#endif
}

enum BlockingType {
BlockingPutBlockingGet,
BlockingPutNonBlockingGet,
NonBlockingPutBlockingGet,
NonBlockingPutNonBlockingGet,
};

enum WaitType {
WaitPause,
WaitYield,
WaitCondition,
WaitNoOp,
};

constexpr size_t roundUpNextPowerOfTwo(size_t v) {
v--;
for (size_t i = 1; i < sizeof(v) * CHAR_BIT; i *= 2) {
Expand Down
100 changes: 45 additions & 55 deletions include/mpsc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,115 +6,105 @@
#include <thread>

#include "common.hpp"
#include "wait_strategy.hpp"

namespace fastchan {

template <typename T, BlockingType blocking_type, size_t min_size, WaitType wait_type = WaitYield>
template <typename T, size_t min_size, class PutWaitStrategy = YieldWaitStrategy, class GetWaitStrategy = YieldWaitStrategy>
class MPSC {
public:
using put_t = typename std::conditional<(blocking_type == BlockingPutBlockingGet || blocking_type == BlockingPutNonBlockingGet), void, bool>::type;
using get_t = typename std::conditional<(blocking_type == BlockingPutBlockingGet || blocking_type == NonBlockingPutBlockingGet), T, std::optional<T>>::type;
using put_t = typename std::conditional<!std::is_same<PutWaitStrategy, ReturnImmediateStrategy>::value, void, bool>::type;
using get_t = typename std::conditional<!std::is_same<GetWaitStrategy, ReturnImmediateStrategy>::value, T, std::optional<T>>::type;

MPSC() = default;

put_t put(const T &value) noexcept {
auto write_index = next_free_index_.load(std::memory_order_acquire);
do {
while (write_index > (reader_index_.load(std::memory_order_relaxed) + index_mask_)) {
if constexpr (blocking_type == BlockingPutBlockingGet || blocking_type == BlockingPutNonBlockingGet) {
if constexpr (wait_type == WaitYield) {
std::this_thread::yield();
} else if constexpr (wait_type == WaitCondition) {
std::unique_lock<std::mutex> lock(put_mutex_);
put_cv_.wait(lock, [this, write_index] { return write_index <= (reader_index_.load(std::memory_order_relaxed) + index_mask_); });
} else if constexpr (wait_type == WaitPause) {
cpu_pause();
}
} else {
while (write_index > (consumer_.reader_index_.load(std::memory_order_relaxed) + common_.index_mask_)) {
if constexpr (std::is_same<PutWaitStrategy, ReturnImmediateStrategy>::value) {
return false;
} else {
common_.put_wait_.wait(
[this, write_index] { return write_index <= (consumer_.reader_index_.load(std::memory_order_relaxed) + common_.index_mask_); });
}
}
} while (!next_free_index_.compare_exchange_strong(write_index, write_index + 1, std::memory_order_acq_rel, std::memory_order_acquire));

contents_[write_index & index_mask_] = value;
contents_[write_index & common_.index_mask_] = value;

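// commit in claim order: wait until every earlier slot has been committed before publishing this one,
// so last_committed_index_ never advances past a slot that has not been written yet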
while (last_committed_index_.load(std::memory_order_relaxed) != write_index) {
if constexpr (wait_type == WaitYield) {
std::this_thread::yield();
} else if constexpr (wait_type == WaitPause) {
cpu_pause();
}
// we don't return at this point, even in the case of ReturnImmediateStrategy, as we've already taken the token
common_.put_wait_.wait([this, write_index] { return last_committed_index_.load(std::memory_order_relaxed) == write_index; });
}

last_committed_index_.store(++write_index, std::memory_order_release);

if constexpr (wait_type == WaitCondition) {
std::lock_guard<std::mutex> lock(get_mutex_);
get_cv_.notify_one();
}
common_.get_wait_.notify();
common_.put_wait_.notify();

if constexpr (blocking_type != BlockingPutBlockingGet && blocking_type != BlockingPutNonBlockingGet) {
if constexpr (std::is_same<PutWaitStrategy, ReturnImmediateStrategy>::value) {
return true;
}
}

get_t get() noexcept {
while (reader_index_2_ >= last_committed_index_.load(std::memory_order_relaxed)) {
if constexpr (blocking_type == BlockingPutBlockingGet || blocking_type == NonBlockingPutBlockingGet) {
if constexpr (wait_type == WaitYield) {
std::this_thread::yield();
} else if constexpr (wait_type == WaitCondition) {
std::unique_lock<std::mutex> lock(get_mutex_);
get_cv_.wait(lock, [this] { return reader_index_2_ < last_committed_index_.load(std::memory_order_relaxed); });
} else if constexpr (wait_type == WaitPause) {
cpu_pause();
}
} else {
while (consumer_.reader_index_2_ >= last_committed_index_.load(std::memory_order_relaxed)) {
if constexpr (std::is_same<GetWaitStrategy, ReturnImmediateStrategy>::value) {
return std::nullopt;
} else {
common_.get_wait_.wait([this] { return consumer_.reader_index_2_ < last_committed_index_.load(std::memory_order_relaxed); });
}
}

auto contents = contents_[reader_index_2_ & index_mask_];
reader_index_.store(++reader_index_2_, std::memory_order_release);
auto contents = contents_[consumer_.reader_index_2_ & common_.index_mask_];
consumer_.reader_index_.store(++consumer_.reader_index_2_, std::memory_order_release);

common_.put_wait_.notify();

if constexpr (wait_type == WaitCondition) {
std::lock_guard<std::mutex> lock(put_mutex_);
put_cv_.notify_one();
}
return contents;
}

void empty() noexcept {
reader_index_2_ = 0;
consumer_.reader_index_2_ = 0;
next_free_index_.store(0, std::memory_order_release);
last_committed_index_.store(0, std::memory_order_release);
reader_index_.store(0, std::memory_order_release);
consumer_.reader_index_.store(0, std::memory_order_release);
}

std::size_t size() const noexcept { return last_committed_index_.load(std::memory_order_acquire) - reader_index_.load(std::memory_order_acquire); }
std::size_t size() const noexcept {
return last_committed_index_.load(std::memory_order_acquire) - consumer_.reader_index_.load(std::memory_order_acquire);
}

bool isEmpty() const noexcept { return reader_index_.load(std::memory_order_acquire) >= last_committed_index_.load(std::memory_order_acquire); }
bool isEmpty() const noexcept { return consumer_.reader_index_.load(std::memory_order_acquire) >= last_committed_index_.load(std::memory_order_acquire); }

bool isFull() const noexcept {
// isFull reports whether all writer slots in the buffer have been taken, rather than whether those
// writes have actually been committed
return next_free_index_.load(std::memory_order_acquire) > (reader_index_.load(std::memory_order_acquire) + index_mask_);
return next_free_index_.load(std::memory_order_acquire) > (consumer_.reader_index_.load(std::memory_order_acquire) + common_.index_mask_);
}

private:
const std::size_t index_mask_ = roundUpNextPowerOfTwo(min_size) - 1;
alignas(64) std::size_t reader_index_2_{0};
alignas(64) std::atomic<std::size_t> reader_index_{0};
std::array<T, roundUpNextPowerOfTwo(min_size)> contents_;

alignas(64) std::atomic<std::size_t> next_free_index_{0};
alignas(64) std::atomic<std::size_t> last_committed_index_{0};

alignas(64) std::condition_variable put_cv_;
alignas(64) std::condition_variable get_cv_;
alignas(64) std::mutex put_mutex_;
alignas(64) std::mutex get_mutex_;
struct alignas(64) Common {
GetWaitStrategy get_wait_{};
PutWaitStrategy put_wait_{};
const std::size_t index_mask_ = roundUpNextPowerOfTwo(min_size) - 1;
};

struct alignas(64) Consumer {
std::size_t next_free_index_cache_{0};
std::size_t reader_index_2_{0};
std::atomic<std::size_t> reader_index_{0};
};

alignas(64) std::array<T, roundUpNextPowerOfTwo(min_size)> contents_;
Common common_;
Consumer consumer_;
};

} // namespace fastchan
11 changes: 9 additions & 2 deletions include/wait_strategy.hpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
#include <chrono>
#include <condition_variable>
#include <ratio>
#include <thread>

#include "common.hpp"

#ifndef FASTCHANWAIT_HPP
#define FASTCHANWAIT_HPP

namespace fastchan {

// WaitStrategyInterface is the interface for actual implementation of a wait strategy handler
Expand Down Expand Up @@ -51,14 +56,16 @@ class CVWaitStrategy : public WaitStrategyInterface<PauseWaitStrategy> {
template <class Predicate>
void wait(Predicate p) {
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, p);
cv_.wait_for(lock, std::chrono::nanoseconds(100), p);
}

void notify() { cv_.notify_one(); }
void notify() { cv_.notify_all(); }

private:
std::condition_variable cv_;
std::mutex mutex_;
};

} // namespace fastchan

#endif
77 changes: 39 additions & 38 deletions test/fastchan_mpsc_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@ using namespace std::chrono_literals;

const auto IterationsMultiplier = 100;

template <fastchan::BlockingType blockingType, int iterations, fastchan::WaitType waitType>
template <int iterations, class put_wait_strategy, class get_wait_strategy>
void testMPSCSingleThreaded() {
constexpr std::size_t chan_size = (iterations / 2) + 1;
fastchan::MPSC<int, blockingType, chan_size, waitType> chan;
fastchan::MPSC<int, chan_size, put_wait_strategy, get_wait_strategy> chan;

assert(chan.size() == 0);
assert(chan.isEmpty() == true);
// Test filling up with a single thread
for (int i = 0; i < iterations; ++i) {
if constexpr (blockingType == fastchan::NonBlockingPutNonBlockingGet || blockingType == fastchan::NonBlockingPutBlockingGet) {
if constexpr (std::is_same<put_wait_strategy, fastchan::ReturnImmediateStrategy>::value) {
auto result = false;
do {
result = chan.put(i);
Expand Down Expand Up @@ -49,7 +49,7 @@ void testMPSCSingleThreaded() {

// Test put and get with a single thread
for (int i = 0; i < iterations; ++i) {
if constexpr (blockingType == fastchan::NonBlockingPutNonBlockingGet || blockingType == fastchan::NonBlockingPutBlockingGet) {
if constexpr (std::is_same<put_wait_strategy, fastchan::ReturnImmediateStrategy>::value) {
auto result = false;
do {
result = chan.put(i);
Expand All @@ -62,7 +62,7 @@ void testMPSCSingleThreaded() {
}

for (int i = 0; i < iterations; ++i) {
if constexpr (blockingType == fastchan::NonBlockingPutNonBlockingGet || blockingType == fastchan::BlockingPutNonBlockingGet) {
if constexpr (std::is_same<get_wait_strategy, fastchan::ReturnImmediateStrategy>::value) {
auto&& val = chan.get();
while (!val) {
val = chan.get();
Expand All @@ -82,16 +82,16 @@ void testMPSCSingleThreaded() {
assert(chan.isEmpty());
}

template <fastchan::BlockingType blockingType, int iterations, fastchan::WaitType waitType>
template <int iterations, class put_wait_strategy, class get_wait_strategy>
void testMPSCMultiThreadedSingleProducer() {
constexpr std::size_t chan_size = (iterations / 2) + 1;
fastchan::MPSC<int, blockingType, chan_size, waitType> chan;
fastchan::MPSC<int, chan_size, put_wait_strategy, get_wait_strategy> chan;

auto total_iterations = IterationsMultiplier * iterations;
// Test put and get with multiple threads
std::thread producer([&] {
for (int i = 1; i <= total_iterations; ++i) {
if constexpr (blockingType == fastchan::NonBlockingPutNonBlockingGet || blockingType == fastchan::NonBlockingPutBlockingGet) {
if constexpr (std::is_same<put_wait_strategy, fastchan::ReturnImmediateStrategy>::value) {
auto result = false;
do {
result = chan.put(i);
Expand All @@ -104,7 +104,7 @@ void testMPSCMultiThreadedSingleProducer() {

std::thread consumer([&] {
for (int i = 1; i <= total_iterations;) {
if constexpr (blockingType == fastchan::NonBlockingPutNonBlockingGet || blockingType == fastchan::BlockingPutNonBlockingGet) {
if constexpr (std::is_same<get_wait_strategy, fastchan::ReturnImmediateStrategy>::value) {
auto&& val = chan.get();
while (!val) {
val = chan.get();
Expand All @@ -126,10 +126,10 @@ void testMPSCMultiThreadedSingleProducer() {
assert(chan.size() == 0);
}

template <fastchan::BlockingType blockingType, int iterations, int num_threads, fastchan::WaitType waitType>
template <int iterations, int num_threads, class put_wait_strategy, class get_wait_strategy>
void testMPSCMultiThreadedMultiProducer() {
constexpr std::size_t chan_size = (iterations / 2) + 1;
fastchan::MPSC<int, blockingType, chan_size, waitType> chan;
fastchan::MPSC<int, chan_size, put_wait_strategy, get_wait_strategy> chan;

size_t total_iterations = IterationsMultiplier * iterations;
size_t total = num_threads * (total_iterations * (total_iterations + 1) / 2);
Expand All @@ -140,7 +140,7 @@ void testMPSCMultiThreadedMultiProducer() {
// Test put and get with multiple threads
producers[i] = std::thread([&] {
for (int i = 1; i <= total_iterations; ++i) {
if constexpr (blockingType == fastchan::NonBlockingPutNonBlockingGet || blockingType == fastchan::NonBlockingPutBlockingGet) {
if constexpr (std::is_same<put_wait_strategy, fastchan::ReturnImmediateStrategy>::value) {
auto result = false;
do {
result = chan.put(i);
Expand All @@ -154,7 +154,7 @@ void testMPSCMultiThreadedMultiProducer() {

std::thread consumer([&] {
for (int i = 1; i <= total_iterations * num_threads;) {
if constexpr (blockingType == fastchan::NonBlockingPutNonBlockingGet || blockingType == fastchan::BlockingPutNonBlockingGet) {
if constexpr (std::is_same<get_wait_strategy, fastchan::ReturnImmediateStrategy>::value) {
auto&& val = chan.get();
while (!val) {
val = chan.get();
Expand All @@ -179,41 +179,42 @@ void testMPSCMultiThreadedMultiProducer() {
assert(chan.size() == 0);
}

template <fastchan::BlockingType blockingType, fastchan::WaitType waitType>
template <class put_wait_type, class get_wait_type>
void testMPSC() {
testMPSCSingleThreaded<blockingType, 4, waitType>();
testMPSCMultiThreadedSingleProducer<blockingType, 4, waitType>();
testMPSCSingleThreaded<4, put_wait_type, get_wait_type>();
testMPSCMultiThreadedSingleProducer<4, put_wait_type, get_wait_type>();
if (std::thread::hardware_concurrency() > 5) {
testMPSCMultiThreadedMultiProducer<blockingType, 4, 3, waitType>();
testMPSCMultiThreadedMultiProducer<blockingType, 4, 5, waitType>();
testMPSCMultiThreadedMultiProducer<4, 3, put_wait_type, get_wait_type>();
testMPSCMultiThreadedMultiProducer<4, 5, put_wait_type, get_wait_type>();
} else {
testMPSCMultiThreadedMultiProducer<blockingType, 4, 2, waitType>();
testMPSCMultiThreadedMultiProducer<4, 2, put_wait_type, get_wait_type>();
}

testMPSCSingleThreaded<blockingType, 4096, waitType>();
testMPSCMultiThreadedSingleProducer<blockingType, 4096, waitType>();
testMPSCSingleThreaded<4096, put_wait_type, get_wait_type>();
testMPSCMultiThreadedSingleProducer<4096, put_wait_type, get_wait_type>();
if (std::thread::hardware_concurrency() > 5) {
testMPSCMultiThreadedMultiProducer<blockingType, 4096, 3, waitType>();
testMPSCMultiThreadedMultiProducer<blockingType, 4096, 5, waitType>();
testMPSCMultiThreadedMultiProducer<4096, 3, put_wait_type, get_wait_type>();
testMPSCMultiThreadedMultiProducer<4096, 5, put_wait_type, get_wait_type>();
} else {
testMPSCMultiThreadedMultiProducer<blockingType, 4096, 2, waitType>();
testMPSCMultiThreadedMultiProducer<4096, 2, put_wait_type, get_wait_type>();
}
}

int main() {
testMPSC<fastchan::BlockingPutBlockingGet, fastchan::WaitPause>();
testMPSC<fastchan::BlockingPutNonBlockingGet, fastchan::WaitPause>();
testMPSC<fastchan::NonBlockingPutBlockingGet, fastchan::WaitPause>();
testMPSC<fastchan::NonBlockingPutNonBlockingGet, fastchan::WaitPause>();

testMPSC<fastchan::BlockingPutBlockingGet, fastchan::WaitYield>();
testMPSC<fastchan::BlockingPutNonBlockingGet, fastchan::WaitYield>();
testMPSC<fastchan::NonBlockingPutBlockingGet, fastchan::WaitYield>();
testMPSC<fastchan::NonBlockingPutNonBlockingGet, fastchan::WaitYield>();

testMPSC<fastchan::BlockingPutBlockingGet, fastchan::WaitCondition>();
testMPSC<fastchan::BlockingPutNonBlockingGet, fastchan::WaitCondition>();
testMPSC<fastchan::NonBlockingPutBlockingGet, fastchan::WaitCondition>();
testMPSC<fastchan::NonBlockingPutNonBlockingGet, fastchan::WaitCondition>();
testMPSC<fastchan::PauseWaitStrategy, fastchan::PauseWaitStrategy>();
testMPSC<fastchan::PauseWaitStrategy, fastchan::ReturnImmediateStrategy>();
testMPSC<fastchan::ReturnImmediateStrategy, fastchan::PauseWaitStrategy>();
testMPSC<fastchan::ReturnImmediateStrategy, fastchan::ReturnImmediateStrategy>();

testMPSC<fastchan::YieldWaitStrategy, fastchan::YieldWaitStrategy>();
testMPSC<fastchan::YieldWaitStrategy, fastchan::ReturnImmediateStrategy>();
testMPSC<fastchan::ReturnImmediateStrategy, fastchan::YieldWaitStrategy>();
testMPSC<fastchan::ReturnImmediateStrategy, fastchan::ReturnImmediateStrategy>();

testMPSC<fastchan::CVWaitStrategy, fastchan::CVWaitStrategy>();
testMPSC<fastchan::CVWaitStrategy, fastchan::ReturnImmediateStrategy>();
testMPSC<fastchan::ReturnImmediateStrategy, fastchan::CVWaitStrategy>();
testMPSC<fastchan::ReturnImmediateStrategy, fastchan::ReturnImmediateStrategy>();

return 0;
}

0 comments on commit 0b965d0

Please sign in to comment.