add a simple coroutine scheduler that multiplexes tasks onto a single thread

master
Daniel Kolesa 2017-03-19 14:11:23 +01:00
parent 9f6d8da9db
commit 9ad7fe76c4
3 changed files with 148 additions and 18 deletions

View File

@ -3,41 +3,53 @@
using namespace ostd;
thread_scheduler sched;
void foo() {
template<typename S>
void foo(S &sched) {
auto arr = ostd::iter({ 150, 38, 76, 25, 67, 18, -15, -38, 25, -10 });
auto h1 = arr.slice(0, arr.size() / 2);
auto h2 = arr.slice(arr.size() / 2, arr.size());
auto c = sched.make_channel<int>();
auto f = [&c](auto half) {
c.put(foldl(half, 0));
auto c = sched.template make_channel<int>();
auto f = [c](auto half) {
c->put(foldl(half, 0));
};
sched.spawn(f, h1);
sched.spawn(f, h2);
int a, b;
c.get(a);
c.get(b);
c->get(a);
c->get(b);
writefln("first half: %s", a);
writefln("second half: %s", b);
writefln("total: %s", a + b);
}
int main() {
sched.start([]() {
writeln("starting...");
foo();
writeln("finishing...");
thread_scheduler tsched;
tsched.start([&tsched]() {
writeln("thread scheduler: starting...");
foo(tsched);
writeln("thread scheduler: finishing...");
});
simple_coroutine_scheduler scsched;
scsched.start([&scsched]() {
writeln("simple coroutine scheduler: starting...");
foo(scsched);
writeln("simple coroutine scheduler: finishing...");
});
}
/*
starting...
thread scheduler: starting...
first half: 356
second half: -20
total: 336
finishing...
thread scheduler: finishing...
simple coroutine scheduler: starting...
first half: 356
second half: -20
total: 336
simple coroutine scheduler: finishing...
*/

View File

@ -29,7 +29,7 @@ struct channel {
* elides copying and moving (by directly returning the type ctor call)
*/
template<typename F>
channel(F func): p_cond(func()) {}
channel(F func): p_lock(), p_cond(func(p_lock)) {}
void put(T const &val) {
put_impl(val);
@ -85,8 +85,8 @@ private:
}
std::list<T> p_messages;
C p_cond;
mutable std::mutex p_lock;
C p_cond;
bool p_closed = false;
};

View File

@ -8,7 +8,9 @@
#include <thread>
#include <utility>
#include <memory>
#include "ostd/coroutine.hh"
#include "ostd/channel.hh"
namespace ostd {
@ -42,8 +44,8 @@ struct thread_scheduler {
}
template<typename T>
channel<T> make_channel() {
return channel<T>{};
std::shared_ptr<channel<T>> make_channel() {
return std::shared_ptr<channel<T>>{new channel<T>{}};
}
private:
@ -72,6 +74,122 @@ private:
std::mutex p_lock;
};
struct simple_coroutine_scheduler {
private:
/* simple one just for channels */
struct coro_cond {
coro_cond() = delete;
coro_cond(coro_cond const &) = delete;
coro_cond(coro_cond &&) = delete;
coro_cond &operator=(coro_cond const &) = delete;
coro_cond &operator=(coro_cond &&) = delete;
coro_cond(simple_coroutine_scheduler &s, std::mutex &mtx):
p_sched(s), p_mtx(mtx)
{}
template<typename L>
void wait(L &l) noexcept {
l.unlock();
while (!p_notified) {
p_sched.yield();
}
p_notified = false;
l.lock();
}
void notify_one() noexcept {
p_notified = true;
p_mtx.unlock();
p_sched.yield();
p_mtx.lock();
}
void notify_all() noexcept {
p_notified = true;
p_mtx.unlock();
p_sched.yield();
p_mtx.lock();
}
private:
simple_coroutine_scheduler &p_sched;
std::mutex &p_mtx;
bool p_notified = false;
};
public:
template<typename F, typename ...A>
auto start(F &&func, A &&...args) -> std::result_of_t<F(A...)> {
using R = std::result_of_t<F(A...)>;
if constexpr(std::is_same_v<R, void>) {
spawn(std::forward<F>(func), std::forward<F>(args)...);
dispatch();
} else {
R ret;
spawn([lfunc = std::forward<F>(func), &ret](A &&...args) {
ret = std::move(lfunc(std::forward<A>(args)...));
}, std::forward<A>(args)...);
dispatch();
return ret;
}
}
template<typename F, typename ...A>
void spawn(F &&func, A &&...args) {
if constexpr(sizeof...(A) == 0) {
p_coros.emplace_back([lfunc = std::forward<F>(func)](auto) {
lfunc();
});
} else {
p_coros.emplace_back([lfunc = std::bind(
std::forward<F>(func), std::forward<A>(args)...
)](auto) {
lfunc();
});
}
}
void yield() {
auto ctx = coroutine_context::current();
coro *c = dynamic_cast<coro *>(ctx);
if (c) {
coro::yield_type{*c}();
return;
}
throw std::runtime_error{"no task to yield"};
}
template<typename T>
std::shared_ptr<channel<T, coro_cond>> make_channel() {
return std::shared_ptr<channel<T, coro_cond>>{
new channel<T, coro_cond>{[this](std::mutex &mtx) {
return coro_cond{*this, mtx};
}}
};
}
private:
struct coro: coroutine<void()> {
using coroutine<void()>::coroutine;
};
void dispatch() {
while (!p_coros.empty()) {
if (p_idx == p_coros.end()) {
p_idx = p_coros.begin();
}
(*p_idx)();
if (!*p_idx) {
p_idx = p_coros.erase(p_idx);
} else {
++p_idx;
}
}
}
std::list<coro> p_coros;
typename std::list<coro>::iterator p_idx = p_coros.end();
};
} /* namespace ostd */
#endif