make channels copyable (referring to a shared state)

master
Daniel Kolesa 2017-03-19 16:23:00 +01:00
parent 9ad7fe76c4
commit 9a9466e943
3 changed files with 74 additions and 51 deletions

View File

@ -11,15 +11,15 @@ void foo(S &sched) {
auto h2 = arr.slice(arr.size() / 2, arr.size()); auto h2 = arr.slice(arr.size() / 2, arr.size());
auto c = sched.template make_channel<int>(); auto c = sched.template make_channel<int>();
auto f = [c](auto half) { auto f = [](auto c, auto half) {
c->put(foldl(half, 0)); c.put(foldl(half, 0));
}; };
sched.spawn(f, h1); sched.spawn(f, c, h1);
sched.spawn(f, h2); sched.spawn(f, c, h2);
int a, b; int a, b;
c->get(a); c.get(a);
c->get(b); c.get(b);
writefln("first half: %s", a); writefln("first half: %s", a);
writefln("second half: %s", b); writefln("second half: %s", b);
writefln("total: %s", a + b); writefln("total: %s", a + b);

View File

@ -10,6 +10,7 @@
#include <mutex> #include <mutex>
#include <condition_variable> #include <condition_variable>
#include <stdexcept> #include <stdexcept>
#include <memory>
namespace ostd { namespace ostd {
@ -22,72 +23,96 @@ struct channel {
using condition_variable_type = C; using condition_variable_type = C;
/* default ctor works for default C */ /* default ctor works for default C */
channel() {} channel(): p_state(new impl) {}
/* constructing using a function object, keep in mind that condvars are /* constructing using a function object, keep in mind that condvars are
* not copy or move constructible, so the func has to work in a way that * not copy or move constructible, so the func has to work in a way that
* elides copying and moving (by directly returning the type ctor call) * elides copying and moving (by directly returning the type ctor call)
*/ */
template<typename F> template<typename F>
channel(F func): p_lock(), p_cond(func(p_lock)) {} channel(F func): p_state(new impl{func}) {}
channel(channel const &) = default;
channel(channel &&) = default;
channel &operator=(channel const &) = default;
channel &operator=(channel &&) = default;
void put(T const &val) { void put(T const &val) {
put_impl(val); p_state->put(val);
} }
void put(T &&val) { void put(T &&val) {
put_impl(std::move(val)); p_state->put(std::move(val));
} }
bool get(T &val) { bool get(T &val) {
return get_impl(val, true); return p_state->get(val, true);
} }
bool try_get(T &val) { bool try_get(T &val) {
return get_impl(val, false); return p_state->get(val, false);
} }
bool is_closed() const { bool is_closed() const {
std::unique_lock<std::mutex> l{p_lock}; return p_state->is_closed();
return p_closed;
} }
void close() { void close() {
std::unique_lock<std::mutex> l{p_lock}; p_state->close();
p_closed = true;
p_cond.notify_all();
} }
private: private:
template<typename U> struct impl {
void put_impl(U &&val) { impl() {}
std::unique_lock<std::mutex> l{p_lock};
if (p_closed) {
throw channel_error{"put in a closed channel"};
}
p_messages.push_back(std::forward<U>(val));
p_cond.notify_one();
}
bool get_impl(T &val, bool w) { template<typename F>
std::unique_lock<std::mutex> l{p_lock}; impl(F &func): p_lock(), p_cond(func(p_lock)) {}
if (w) {
while (!p_closed && p_messages.empty()) { template<typename U>
p_cond.wait(l); void put(U &&val) {
std::unique_lock<std::mutex> l{p_lock};
if (p_closed) {
throw channel_error{"put in a closed channel"};
} }
p_messages.push_back(std::forward<U>(val));
p_cond.notify_one();
} }
if (p_messages.empty()) {
return false;
}
val = p_messages.front();
p_messages.pop_front();
return true;
}
std::list<T> p_messages; bool get(T &val, bool w) {
mutable std::mutex p_lock; std::unique_lock<std::mutex> l{p_lock};
C p_cond; if (w) {
bool p_closed = false; while (!p_closed && p_messages.empty()) {
p_cond.wait(l);
}
}
if (p_messages.empty()) {
return false;
}
val = p_messages.front();
p_messages.pop_front();
return true;
}
bool is_closed() const {
std::unique_lock<std::mutex> l{p_lock};
return p_closed;
}
void close() {
std::unique_lock<std::mutex> l{p_lock};
p_closed = true;
p_cond.notify_all();
}
std::list<T> p_messages;
mutable std::mutex p_lock;
C p_cond;
bool p_closed = false;
};
/* basic and inefficient, deal with it better later */
std::shared_ptr<impl> p_state;
}; };
} /* namespace ostd */ } /* namespace ostd */

View File

@ -44,8 +44,8 @@ struct thread_scheduler {
} }
template<typename T> template<typename T>
std::shared_ptr<channel<T>> make_channel() { channel<T> make_channel() {
return std::shared_ptr<channel<T>>{new channel<T>{}}; return channel<T>{};
} }
private: private:
@ -143,7 +143,7 @@ public:
} else { } else {
p_coros.emplace_back([lfunc = std::bind( p_coros.emplace_back([lfunc = std::bind(
std::forward<F>(func), std::forward<A>(args)... std::forward<F>(func), std::forward<A>(args)...
)](auto) { )](auto) mutable {
lfunc(); lfunc();
}); });
} }
@ -160,12 +160,10 @@ public:
} }
template<typename T> template<typename T>
std::shared_ptr<channel<T, coro_cond>> make_channel() { channel<T, coro_cond> make_channel() {
return std::shared_ptr<channel<T, coro_cond>>{ return channel<T, coro_cond>{[this](std::mutex &mtx) {
new channel<T, coro_cond>{[this](std::mutex &mtx) { return coro_cond{*this, mtx};
return coro_cond{*this, mtx}; }};
}}
};
} }
private: private:
struct coro: coroutine<void()> { struct coro: coroutine<void()> {