From 9a9466e9438969758450de96d273f392b0179d20 Mon Sep 17 00:00:00 2001 From: q66 Date: Sun, 19 Mar 2017 16:23:00 +0100 Subject: [PATCH] make channels copyable (referring to a shared state) --- examples/concurrency.cc | 12 ++--- ostd/channel.hh | 97 ++++++++++++++++++++++++++--------------- ostd/concurrency.hh | 16 +++---- 3 files changed, 74 insertions(+), 51 deletions(-) diff --git a/examples/concurrency.cc b/examples/concurrency.cc index 840b6fd..e4290a5 100644 --- a/examples/concurrency.cc +++ b/examples/concurrency.cc @@ -11,15 +11,15 @@ void foo(S &sched) { auto h2 = arr.slice(arr.size() / 2, arr.size()); auto c = sched.template make_channel(); - auto f = [c](auto half) { - c->put(foldl(half, 0)); + auto f = [](auto c, auto half) { + c.put(foldl(half, 0)); }; - sched.spawn(f, h1); - sched.spawn(f, h2); + sched.spawn(f, c, h1); + sched.spawn(f, c, h2); int a, b; - c->get(a); - c->get(b); + c.get(a); + c.get(b); writefln("first half: %s", a); writefln("second half: %s", b); writefln("total: %s", a + b); diff --git a/ostd/channel.hh b/ostd/channel.hh index cb74cf8..11fe045 100644 --- a/ostd/channel.hh +++ b/ostd/channel.hh @@ -10,6 +10,7 @@ #include #include #include +#include namespace ostd { @@ -22,72 +23,96 @@ struct channel { using condition_variable_type = C; /* default ctor works for default C */ - channel() {} + channel(): p_state(new impl) {} /* constructing using a function object, keep in mind that condvars are * not copy or move constructible, so the func has to work in a way that * elides copying and moving (by directly returning the type ctor call) */ template - channel(F func): p_lock(), p_cond(func(p_lock)) {} + channel(F func): p_state(new impl{func}) {} + + channel(channel const &) = default; + channel(channel &&) = default; + + channel &operator=(channel const &) = default; + channel &operator=(channel &&) = default; void put(T const &val) { - put_impl(val); + p_state->put(val); } void put(T &&val) { - put_impl(std::move(val)); + p_state->put(std::move(val)); } bool get(T &val) { - return get_impl(val, true); + return p_state->get(val, true); } bool try_get(T &val) { - return get_impl(val, false); + return p_state->get(val, false); } bool is_closed() const { - std::unique_lock l{p_lock}; - return p_closed; + return p_state->is_closed(); } void close() { - std::unique_lock l{p_lock}; - p_closed = true; - p_cond.notify_all(); + p_state->close(); } private: - template - void put_impl(U &&val) { - std::unique_lock l{p_lock}; - if (p_closed) { - throw channel_error{"put in a closed channel"}; - } - p_messages.push_back(std::forward(val)); - p_cond.notify_one(); - } + struct impl { + impl() {} - bool get_impl(T &val, bool w) { - std::unique_lock l{p_lock}; - if (w) { - while (!p_closed && p_messages.empty()) { - p_cond.wait(l); + template + impl(F &func): p_lock(), p_cond(func(p_lock)) {} + + template + void put(U &&val) { + std::unique_lock l{p_lock}; + if (p_closed) { + throw channel_error{"put in a closed channel"}; } + p_messages.push_back(std::forward(val)); + p_cond.notify_one(); } - if (p_messages.empty()) { - return false; - } - val = p_messages.front(); - p_messages.pop_front(); - return true; - } - std::list p_messages; - mutable std::mutex p_lock; - C p_cond; - bool p_closed = false; + bool get(T &val, bool w) { + std::unique_lock l{p_lock}; + if (w) { + while (!p_closed && p_messages.empty()) { + p_cond.wait(l); + } + } + if (p_messages.empty()) { + return false; + } + val = p_messages.front(); + p_messages.pop_front(); + return true; + } + + bool is_closed() const { + std::unique_lock l{p_lock}; + return p_closed; + } + + void close() { + std::unique_lock l{p_lock}; + p_closed = true; + p_cond.notify_all(); + } + + std::list p_messages; + mutable std::mutex p_lock; + C p_cond; + bool p_closed = false; + }; + + /* basic and inefficient, deal with it better later */ + std::shared_ptr p_state; }; } /* namespace ostd */ diff --git a/ostd/concurrency.hh b/ostd/concurrency.hh index 7ce7297..6269a2f 100644 --- a/ostd/concurrency.hh +++ b/ostd/concurrency.hh @@ -44,8 +44,8 @@ struct thread_scheduler { } template - std::shared_ptr> make_channel() { - return std::shared_ptr>{new channel{}}; + channel make_channel() { + return channel{}; } private: @@ -143,7 +143,7 @@ public: } else { p_coros.emplace_back([lfunc = std::bind( std::forward(func), std::forward(args)... - )](auto) { + )](auto) mutable { lfunc(); }); } @@ -160,12 +160,10 @@ public: } template - std::shared_ptr> make_channel() { - return std::shared_ptr>{ - new channel{[this](std::mutex &mtx) { - return coro_cond{*this, mtx}; - }} - }; + channel make_channel() { + return channel{[this](std::mutex &mtx) { + return coro_cond{*this, mtx}; + }}; } private: struct coro: coroutine {