make channels copyable (referring to a shared state)
parent
9ad7fe76c4
commit
9a9466e943
|
@ -11,15 +11,15 @@ void foo(S &sched) {
|
||||||
auto h2 = arr.slice(arr.size() / 2, arr.size());
|
auto h2 = arr.slice(arr.size() / 2, arr.size());
|
||||||
|
|
||||||
auto c = sched.template make_channel<int>();
|
auto c = sched.template make_channel<int>();
|
||||||
auto f = [c](auto half) {
|
auto f = [](auto c, auto half) {
|
||||||
c->put(foldl(half, 0));
|
c.put(foldl(half, 0));
|
||||||
};
|
};
|
||||||
sched.spawn(f, h1);
|
sched.spawn(f, c, h1);
|
||||||
sched.spawn(f, h2);
|
sched.spawn(f, c, h2);
|
||||||
|
|
||||||
int a, b;
|
int a, b;
|
||||||
c->get(a);
|
c.get(a);
|
||||||
c->get(b);
|
c.get(b);
|
||||||
writefln("first half: %s", a);
|
writefln("first half: %s", a);
|
||||||
writefln("second half: %s", b);
|
writefln("second half: %s", b);
|
||||||
writefln("total: %s", a + b);
|
writefln("total: %s", a + b);
|
||||||
|
|
|
@ -10,6 +10,7 @@
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
#include <condition_variable>
|
#include <condition_variable>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
namespace ostd {
|
namespace ostd {
|
||||||
|
|
||||||
|
@ -22,72 +23,96 @@ struct channel {
|
||||||
using condition_variable_type = C;
|
using condition_variable_type = C;
|
||||||
|
|
||||||
/* default ctor works for default C */
|
/* default ctor works for default C */
|
||||||
channel() {}
|
channel(): p_state(new impl) {}
|
||||||
|
|
||||||
/* constructing using a function object, keep in mind that condvars are
|
/* constructing using a function object, keep in mind that condvars are
|
||||||
* not copy or move constructible, so the func has to work in a way that
|
* not copy or move constructible, so the func has to work in a way that
|
||||||
* elides copying and moving (by directly returning the type ctor call)
|
* elides copying and moving (by directly returning the type ctor call)
|
||||||
*/
|
*/
|
||||||
template<typename F>
|
template<typename F>
|
||||||
channel(F func): p_lock(), p_cond(func(p_lock)) {}
|
channel(F func): p_state(new impl{func}) {}
|
||||||
|
|
||||||
|
channel(channel const &) = default;
|
||||||
|
channel(channel &&) = default;
|
||||||
|
|
||||||
|
channel &operator=(channel const &) = default;
|
||||||
|
channel &operator=(channel &&) = default;
|
||||||
|
|
||||||
void put(T const &val) {
|
void put(T const &val) {
|
||||||
put_impl(val);
|
p_state->put(val);
|
||||||
}
|
}
|
||||||
|
|
||||||
void put(T &&val) {
|
void put(T &&val) {
|
||||||
put_impl(std::move(val));
|
p_state->put(std::move(val));
|
||||||
}
|
}
|
||||||
|
|
||||||
bool get(T &val) {
|
bool get(T &val) {
|
||||||
return get_impl(val, true);
|
return p_state->get(val, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool try_get(T &val) {
|
bool try_get(T &val) {
|
||||||
return get_impl(val, false);
|
return p_state->get(val, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_closed() const {
|
bool is_closed() const {
|
||||||
std::unique_lock<std::mutex> l{p_lock};
|
return p_state->is_closed();
|
||||||
return p_closed;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void close() {
|
void close() {
|
||||||
std::unique_lock<std::mutex> l{p_lock};
|
p_state->close();
|
||||||
p_closed = true;
|
|
||||||
p_cond.notify_all();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
template<typename U>
|
struct impl {
|
||||||
void put_impl(U &&val) {
|
impl() {}
|
||||||
std::unique_lock<std::mutex> l{p_lock};
|
|
||||||
if (p_closed) {
|
|
||||||
throw channel_error{"put in a closed channel"};
|
|
||||||
}
|
|
||||||
p_messages.push_back(std::forward<U>(val));
|
|
||||||
p_cond.notify_one();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool get_impl(T &val, bool w) {
|
template<typename F>
|
||||||
std::unique_lock<std::mutex> l{p_lock};
|
impl(F &func): p_lock(), p_cond(func(p_lock)) {}
|
||||||
if (w) {
|
|
||||||
while (!p_closed && p_messages.empty()) {
|
template<typename U>
|
||||||
p_cond.wait(l);
|
void put(U &&val) {
|
||||||
|
std::unique_lock<std::mutex> l{p_lock};
|
||||||
|
if (p_closed) {
|
||||||
|
throw channel_error{"put in a closed channel"};
|
||||||
}
|
}
|
||||||
|
p_messages.push_back(std::forward<U>(val));
|
||||||
|
p_cond.notify_one();
|
||||||
}
|
}
|
||||||
if (p_messages.empty()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
val = p_messages.front();
|
|
||||||
p_messages.pop_front();
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::list<T> p_messages;
|
bool get(T &val, bool w) {
|
||||||
mutable std::mutex p_lock;
|
std::unique_lock<std::mutex> l{p_lock};
|
||||||
C p_cond;
|
if (w) {
|
||||||
bool p_closed = false;
|
while (!p_closed && p_messages.empty()) {
|
||||||
|
p_cond.wait(l);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (p_messages.empty()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
val = p_messages.front();
|
||||||
|
p_messages.pop_front();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_closed() const {
|
||||||
|
std::unique_lock<std::mutex> l{p_lock};
|
||||||
|
return p_closed;
|
||||||
|
}
|
||||||
|
|
||||||
|
void close() {
|
||||||
|
std::unique_lock<std::mutex> l{p_lock};
|
||||||
|
p_closed = true;
|
||||||
|
p_cond.notify_all();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::list<T> p_messages;
|
||||||
|
mutable std::mutex p_lock;
|
||||||
|
C p_cond;
|
||||||
|
bool p_closed = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
/* basic and inefficient, deal with it better later */
|
||||||
|
std::shared_ptr<impl> p_state;
|
||||||
};
|
};
|
||||||
|
|
||||||
} /* namespace ostd */
|
} /* namespace ostd */
|
||||||
|
|
|
@ -44,8 +44,8 @@ struct thread_scheduler {
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
std::shared_ptr<channel<T>> make_channel() {
|
channel<T> make_channel() {
|
||||||
return std::shared_ptr<channel<T>>{new channel<T>{}};
|
return channel<T>{};
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -143,7 +143,7 @@ public:
|
||||||
} else {
|
} else {
|
||||||
p_coros.emplace_back([lfunc = std::bind(
|
p_coros.emplace_back([lfunc = std::bind(
|
||||||
std::forward<F>(func), std::forward<A>(args)...
|
std::forward<F>(func), std::forward<A>(args)...
|
||||||
)](auto) {
|
)](auto) mutable {
|
||||||
lfunc();
|
lfunc();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -160,12 +160,10 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
std::shared_ptr<channel<T, coro_cond>> make_channel() {
|
channel<T, coro_cond> make_channel() {
|
||||||
return std::shared_ptr<channel<T, coro_cond>>{
|
return channel<T, coro_cond>{[this](std::mutex &mtx) {
|
||||||
new channel<T, coro_cond>{[this](std::mutex &mtx) {
|
return coro_cond{*this, mtx};
|
||||||
return coro_cond{*this, mtx};
|
}};
|
||||||
}}
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
private:
|
private:
|
||||||
struct coro: coroutine<void()> {
|
struct coro: coroutine<void()> {
|
||||||
|
|
Loading…
Reference in New Issue