diff --git a/examples/concurrency.cc b/examples/concurrency.cc index e540fcf..840b6fd 100644 --- a/examples/concurrency.cc +++ b/examples/concurrency.cc @@ -3,41 +3,53 @@ using namespace ostd; -thread_scheduler sched; - -void foo() { +template +void foo(S &sched) { auto arr = ostd::iter({ 150, 38, 76, 25, 67, 18, -15, -38, 25, -10 }); auto h1 = arr.slice(0, arr.size() / 2); auto h2 = arr.slice(arr.size() / 2, arr.size()); - auto c = sched.make_channel(); - auto f = [&c](auto half) { - c.put(foldl(half, 0)); + auto c = sched.template make_channel(); + auto f = [c](auto half) { + c->put(foldl(half, 0)); }; sched.spawn(f, h1); sched.spawn(f, h2); int a, b; - c.get(a); - c.get(b); + c->get(a); + c->get(b); writefln("first half: %s", a); writefln("second half: %s", b); writefln("total: %s", a + b); } int main() { - sched.start([]() { - writeln("starting..."); - foo(); - writeln("finishing..."); + thread_scheduler tsched; + tsched.start([&tsched]() { + writeln("thread scheduler: starting..."); + foo(tsched); + writeln("thread scheduler: finishing..."); + }); + + simple_coroutine_scheduler scsched; + scsched.start([&scsched]() { + writeln("simple coroutine scheduler: starting..."); + foo(scsched); + writeln("simple coroutine scheduler: finishing..."); }); } /* -starting... +thread scheduler: starting... first half: 356 second half: -20 total: 336 -finishing... +thread scheduler: finishing... +simple coroutine scheduler: starting... +first half: 356 +second half: -20 +total: 336 +simple coroutine scheduler: finishing... */ diff --git a/ostd/channel.hh b/ostd/channel.hh index 3e0f383..cb74cf8 100644 --- a/ostd/channel.hh +++ b/ostd/channel.hh @@ -29,7 +29,7 @@ struct channel { * elides copying and moving (by directly returning the type ctor call) */ template - channel(F func): p_cond(func()) {} + channel(F func): p_lock(), p_cond(func(p_lock)) {} void put(T const &val) { put_impl(val); @@ -85,8 +85,8 @@ private: } std::list p_messages; - C p_cond; mutable std::mutex p_lock; + C p_cond; bool p_closed = false; }; diff --git a/ostd/concurrency.hh b/ostd/concurrency.hh index 0f6d881..7ce7297 100644 --- a/ostd/concurrency.hh +++ b/ostd/concurrency.hh @@ -8,7 +8,9 @@ #include #include +#include +#include "ostd/coroutine.hh" #include "ostd/channel.hh" namespace ostd { @@ -42,8 +44,8 @@ struct thread_scheduler { } template - channel make_channel() { - return channel{}; + std::shared_ptr> make_channel() { + return std::shared_ptr>{new channel{}}; } private: @@ -72,6 +74,122 @@ private: std::mutex p_lock; }; +struct simple_coroutine_scheduler { +private: + /* simple one just for channels */ + struct coro_cond { + coro_cond() = delete; + coro_cond(coro_cond const &) = delete; + coro_cond(coro_cond &&) = delete; + coro_cond &operator=(coro_cond const &) = delete; + coro_cond &operator=(coro_cond &&) = delete; + + coro_cond(simple_coroutine_scheduler &s, std::mutex &mtx): + p_sched(s), p_mtx(mtx) + {} + + template + void wait(L &l) noexcept { + l.unlock(); + while (!p_notified) { + p_sched.yield(); + } + p_notified = false; + l.lock(); + } + + void notify_one() noexcept { + p_notified = true; + p_mtx.unlock(); + p_sched.yield(); + p_mtx.lock(); + } + + void notify_all() noexcept { + p_notified = true; + p_mtx.unlock(); + p_sched.yield(); + p_mtx.lock(); + } + private: + simple_coroutine_scheduler &p_sched; + std::mutex &p_mtx; + bool p_notified = false; + }; + +public: + template + auto start(F &&func, A &&...args) -> std::result_of_t { + using R = std::result_of_t; + if constexpr(std::is_same_v) { + spawn(std::forward(func), std::forward(args)...); + dispatch(); + } else { + R ret; + spawn([lfunc = std::forward(func), &ret](A &&...args) { + ret = std::move(lfunc(std::forward(args)...)); + }, std::forward(args)...); + dispatch(); + return ret; + } + } + + template + void spawn(F &&func, A &&...args) { + if constexpr(sizeof...(A) == 0) { + p_coros.emplace_back([lfunc = std::forward(func)](auto) { + lfunc(); + }); + } else { + p_coros.emplace_back([lfunc = std::bind( + std::forward(func), std::forward(args)... + )](auto) { + lfunc(); + }); + } + } + + void yield() { + auto ctx = coroutine_context::current(); + coro *c = dynamic_cast(ctx); + if (c) { + coro::yield_type{*c}(); + return; + } + throw std::runtime_error{"no task to yield"}; + } + + template + std::shared_ptr> make_channel() { + return std::shared_ptr>{ + new channel{[this](std::mutex &mtx) { + return coro_cond{*this, mtx}; + }} + }; + } +private: + struct coro: coroutine { + using coroutine::coroutine; + }; + + void dispatch() { + while (!p_coros.empty()) { + if (p_idx == p_coros.end()) { + p_idx = p_coros.begin(); + } + (*p_idx)(); + if (!*p_idx) { + p_idx = p_coros.erase(p_idx); + } else { + ++p_idx; + } + } + } + + std::list p_coros; + typename std::list::iterator p_idx = p_coros.end(); +}; + } /* namespace ostd */ #endif