/* Concurrency module with custom scheduler support. * * This file is part of OctaSTD. See COPYING.md for futher information. */ #ifndef OSTD_CONCURRENCY_HH #define OSTD_CONCURRENCY_HH #include #include #include #include "ostd/coroutine.hh" #include "ostd/channel.hh" namespace ostd { struct thread_scheduler { template using channel_type = channel; ~thread_scheduler() { join_all(); } template auto start(F &&func, A &&...args) -> std::result_of_t { return func(std::forward(args)...); } template void spawn(F &&func, A &&...args) { std::unique_lock l{p_lock}; p_threads.emplace_front(); auto it = p_threads.begin(); *it = std::thread{ [this, it](auto func, auto ...args) { func(std::move(args)...); this->remove_thread(it); }, std::forward(func), std::forward(args)... }; } void yield() { std::this_thread::yield(); } template channel make_channel() { return channel{}; } private: void remove_thread(typename std::list::iterator it) { std::unique_lock l{p_lock}; std::thread t{std::exchange(p_dead, std::move(*it))}; if (t.joinable()) { t.join(); } p_threads.erase(it); } void join_all() { /* wait for all threads to finish */ std::unique_lock l{p_lock}; if (p_dead.joinable()) { p_dead.join(); } for (auto &t: p_threads) { t.join(); } } std::list p_threads; std::thread p_dead; std::mutex p_lock; }; struct simple_coroutine_scheduler { private: /* simple one just for channels */ struct coro_cond { coro_cond() = delete; coro_cond(coro_cond const &) = delete; coro_cond(coro_cond &&) = delete; coro_cond &operator=(coro_cond const &) = delete; coro_cond &operator=(coro_cond &&) = delete; coro_cond(simple_coroutine_scheduler &s, std::mutex &mtx): p_sched(s), p_mtx(mtx) {} template void wait(L &l) noexcept { l.unlock(); while (!p_notified) { p_sched.yield(); } p_notified = false; l.lock(); } void notify_one() noexcept { p_notified = true; p_mtx.unlock(); p_sched.yield(); p_mtx.lock(); } void notify_all() noexcept { p_notified = true; p_mtx.unlock(); p_sched.yield(); p_mtx.lock(); } private: simple_coroutine_scheduler &p_sched; std::mutex &p_mtx; bool p_notified = false; }; public: template using channel_type = channel; template auto start(F &&func, A &&...args) -> std::result_of_t { using R = std::result_of_t; if constexpr(std::is_same_v) { spawn(std::forward(func), std::forward(args)...); dispatch(); } else { R ret; spawn([lfunc = std::forward(func), &ret](A &&...args) { ret = std::move(lfunc(std::forward(args)...)); }, std::forward(args)...); dispatch(); return ret; } } template void spawn(F &&func, A &&...args) { if constexpr(sizeof...(A) == 0) { p_coros.emplace_back([lfunc = std::forward(func)](auto) { lfunc(); }); } else { p_coros.emplace_back([lfunc = std::bind( std::forward(func), std::forward(args)... )](auto) mutable { lfunc(); }); } } void yield() { auto ctx = coroutine_context::current(); coro *c = dynamic_cast(ctx); if (c) { coro::yield_type{*c}(); return; } throw std::runtime_error{"no task to yield"}; } template channel make_channel() { return channel{[this](std::mutex &mtx) { return coro_cond{*this, mtx}; }}; } private: struct coro: coroutine { using coroutine::coroutine; }; void dispatch() { while (!p_coros.empty()) { if (p_idx == p_coros.end()) { p_idx = p_coros.begin(); } (*p_idx)(); if (!*p_idx) { p_idx = p_coros.erase(p_idx); } else { ++p_idx; } } } std::list p_coros; typename std::list::iterator p_idx = p_coros.end(); }; template inline void spawn(S &sched, F &&func, A &&...args) { sched.spawn(std::forward(func), std::forward(args)...); } template inline void yield(S &sched) { sched.yield(); } template inline auto make_channel(S &sched) -> typename S::template channel_type { return sched.template make_channel(); } } /* namespace ostd */ #endif