diff --git a/ostd/concurrency.hh b/ostd/concurrency.hh index 7b36237..9b8eef4 100644 --- a/ostd/concurrency.hh +++ b/ostd/concurrency.hh @@ -128,20 +128,23 @@ public: basic_simple_coroutine_scheduler( size_t ss = TR::default_size(), size_t cs = basic_stack_pool::DEFAULT_CHUNK_SIZE - ): p_coros(), p_stacks(ss, cs) {} + ): + p_stacks(ss, cs), + p_dispatcher([this](auto yield_main) { + this->dispatch(yield_main); + }, p_stacks.get_allocator()), + p_coros() + {} template auto start(F &&func, A &&...args) -> std::result_of_t { using R = std::result_of_t; if constexpr(std::is_same_v) { - spawn(std::forward(func), std::forward(args)...); - dispatch(); + func(std::forward(args)...); + finish(); } else { - R ret; - spawn([lfunc = std::forward(func), &ret](A &&...args) { - ret = std::move(lfunc(std::forward(args)...)); - }, std::forward(args)...); - dispatch(); + auto ret = func(std::forward(args)...); + finish(); return ret; } } @@ -163,12 +166,18 @@ public: void yield() { auto ctx = coroutine_context::current(); + if (!ctx) { + /* yield from main means go to dispatcher and call first task */ + p_idx = p_coros.begin(); + p_dispatcher(); + return; + } coro *c = dynamic_cast(ctx); if (c) { typename coro::yield_type{*c}(); return; } - throw std::runtime_error{"no task to yield"}; + throw std::runtime_error{"attempt to yield outside coroutine"}; } template @@ -182,10 +191,15 @@ private: using coroutine::coroutine; }; - void dispatch() { + void dispatch(typename coro::yield_type &yield_main) { while (!p_coros.empty()) { if (p_idx == p_coros.end()) { - p_idx = p_coros.begin(); + /* we're at the end; it's time to return to main and + * continue there (potential yield from main results + * in continuing from this point with the first task) + */ + yield_main(); + continue; } (*p_idx)(); if (!*p_idx) { @@ -196,8 +210,19 @@ private: } } - std::list p_coros; + void finish() { + /* main has finished, but there might be either suspended or never + * started tasks in the queue; dispatch until there are none left + */ + while (!p_coros.empty()) { + p_idx = p_coros.begin(); + p_dispatcher(); + } + } + basic_stack_pool p_stacks; + coro p_dispatcher; + std::list p_coros; typename std::list::iterator p_idx = p_coros.end(); };