/* Coroutines for OctaSTD. * * This file is part of OctaSTD. See COPYING.md for further information. */ #ifndef OSTD_COROUTINE_HH #define OSTD_COROUTINE_HH #include #include #include #include #include #include "ostd/types.hh" #include "ostd/range.hh" #include "ostd/internal/context.hh" namespace ostd { struct coroutine_error: std::runtime_error { using std::runtime_error::runtime_error; }; template struct coroutine; namespace detail { /* like reference_wrapper but for any value */ template struct arg_wrapper { arg_wrapper() = default; arg_wrapper(T arg): p_arg(std::move(arg)) {} void operator=(T arg) { p_arg = std::move(arg); } operator T &&() { return std::move(p_arg); } void swap(arg_wrapper &other) { using std::swap; swap(p_arg, other.p_arg); } private: T p_arg = T{}; }; template struct arg_wrapper { arg_wrapper() = default; arg_wrapper(T &&arg): p_arg(&arg) {} void operator=(T &&arg) { p_arg = &arg; } operator T &&() { return *p_arg; } void swap(arg_wrapper &other) { using std::swap; swap(p_arg, other.p_arg); } private: T *p_arg = nullptr; }; template struct arg_wrapper { arg_wrapper() = default; arg_wrapper(T &arg): p_arg(&arg) {} void operator=(T &arg) { p_arg = &arg; } operator T &() { return *p_arg; } void swap(arg_wrapper &other) { using std::swap; swap(p_arg, other.p_arg); } private: T *p_arg = nullptr; }; template inline void swap(arg_wrapper &a, arg_wrapper &b) { a.swap(b); } template struct coro_types { using yield_type = std::tuple; }; template struct coro_types { using yield_type = A; }; template struct coro_types { using yield_type = std::pair; }; template<> struct coro_types<> { using yield_type = void; }; template using coro_args = typename coro_types::yield_type; template struct coro_yielder; /* default case, yield returns args and takes a value */ template struct coro_base: coroutine_context { protected: friend struct coro_yielder; template coro_args get_args(std::index_sequence) { if constexpr(sizeof...(A) == 0) { return; } else if constexpr(sizeof...(A) == 1) { return std::forward(std::get<0>(p_args)); } else if constexpr(sizeof...(A) == 2) { return std::make_pair(std::forward(std::get(p_args))...); } else { return std::move(p_args); } } template void call_helper(Y &&yielder, F &func, std::index_sequence) { p_result = std::forward(func( std::forward(yielder), std::forward(std::get(p_args))... )); } R call(A ...args) { p_args = std::make_tuple(arg_wrapper(std::forward(args))...); coroutine_context::call(); return std::forward(p_result); } void swap(coro_base &other) { using std::swap; swap(p_args, other.p_args); swap(p_result, other.p_result); coroutine_context::swap(other); } std::tuple...> p_args; arg_wrapper p_result; }; /* yield takes a value but doesn't return any args */ template struct coro_base: coroutine_context { protected: friend struct coro_yielder; template void get_args(std::index_sequence) {} template void call_helper(Y &&yielder, F &func, std::index_sequence) { p_result = std::forward(func(std::forward(yielder))); } R call() { coroutine_context::call(); return std::forward(this->p_result); } void swap(coro_base &other) { using std::swap; swap(p_result, other.p_result); coroutine_context::swap(other); } arg_wrapper p_result; }; /* yield doesn't take a value and returns args */ template struct coro_base: coroutine_context { protected: friend struct coro_yielder; template coro_args get_args(std::index_sequence) { if constexpr(sizeof...(A) == 0) { return; } else if constexpr(sizeof...(A) == 1) { return std::forward(std::get<0>(p_args)); } else if constexpr(sizeof...(A) == 2) { return std::make_pair(std::forward(std::get(p_args))...); } else { return std::move(p_args); } } template void call_helper(Y &&yielder, F &func, std::index_sequence) { func( std::forward(yielder), std::forward(std::get(p_args))... ); } void call(A ...args) { p_args = std::make_tuple(arg_wrapper(std::forward(args))...); coroutine_context::call(); } void swap(coro_base &other) { using std::swap; swap(p_args, other.p_args); coroutine_context::swap(other); } std::tuple...> p_args; }; /* yield doesn't take a value or return any args */ template<> struct coro_base: coroutine_context { protected: friend struct coro_yielder; template void get_args(std::index_sequence) {} template void call_helper(Y &&yielder, F &func, std::index_sequence) { func(std::forward(yielder)); } void call() { coroutine_context::call(); } void swap(coro_base &other) { coroutine_context::swap(other); } }; template struct coro_yielder { coro_yielder(coro_base &coro): p_coro(coro) {} coro_args operator()(R &&ret) { p_coro.p_result = std::forward(ret); p_coro.yield_jump(); return p_coro.get_args(std::make_index_sequence{}); } coro_args operator()(R const &ret) { p_coro.p_result = ret; p_coro.yield_jump(); return p_coro.get_args(std::make_index_sequence{}); } private: coro_base &p_coro; }; template struct coro_yielder { coro_yielder(coro_base &coro): p_coro(coro) {} coro_args operator()(R &ret) { p_coro.p_result = ret; p_coro.yield_jump(); return p_coro.get_args(std::make_index_sequence{}); } private: coro_base &p_coro; }; template struct coro_yielder { coro_yielder(coro_base &coro): p_coro(coro) {} coro_args operator()(R &&ret) { p_coro.p_result = std::forward(ret); p_coro.yield_jump(); return p_coro.get_args(std::make_index_sequence{}); } private: coro_base &p_coro; }; template struct coro_yielder { coro_yielder(coro_base &coro): p_coro(coro) {} coro_args operator()() { p_coro.yield_jump(); return p_coro.get_args(std::make_index_sequence{}); } private: coro_base &p_coro; }; } /* namespace detail */ template struct coroutine: detail::coro_base { private: using base_t = detail::coro_base; public: using yield_type = detail::coro_yielder; /* we have no way to assign a function anyway... */ coroutine() = delete; /* 0 means default size decided by the stack allocator */ template coroutine(F func, SA sa = SA{0}): base_t(), p_func(std::move(func)) { /* that way there is no context creation/stack allocation */ if (!p_func) { return; } this->make_context(sa, &context_call); } coroutine(coroutine const &) = delete; coroutine(coroutine &&c): detail::coro_base(std::move(c)), p_func(std::move(c.p_func)) { c.p_func = nullptr; } coroutine &operator=(coroutine const &) = delete; coroutine &operator=(coroutine &&c) { base_t::operator=(std::move(c)); p_func = std::move(c.p_func); c.p_func = nullptr; } ~coroutine() { if (this->p_state == detail::coroutine_context::state::TERM) { /* the stack has already unwound by a normal return */ return; } this->unwind(); } explicit operator bool() const { return (this->p_state != detail::coroutine_context::state::TERM); } R resume(A ...args) { if (this->p_state == detail::coroutine_context::state::TERM) { throw coroutine_error{"dead coroutine"}; } return this->call(std::forward(args)...); } R operator()(A ...args) { return resume(std::forward(args)...); } void swap(coroutine &other) { std::swap(p_func, other.p_func); base_t::swap(other); } private: /* the main entry point of the coroutine */ template static void context_call(detail::transfer_t t) { auto &self = *(static_cast(t.data)); self.p_orig = t.ctx; if (self.p_state == detail::coroutine_context::state::INIT) { /* we never got to execute properly */ goto release; } try { self.call_helper( yield_type{self}, self.p_func, std::index_sequence_for{} ); } catch (detail::coroutine_context::forced_unwind v) { /* forced_unwind is unique */ self.p_orig = v.ctx; } catch (...) { /* some other exception, will be rethrown later */ self.p_except = std::current_exception(); } /* switch back, release stack */ release: self.p_state = detail::coroutine_context::state::TERM; self.template finish(); } std::function p_func; }; template inline void swap(coroutine &a, coroutine &b) { a.swap(b); } template struct generator_range; template struct generator: detail::coroutine_context { private: struct yielder { yielder(generator &g): p_gen(g) {} void operator()(T &&ret) { p_gen.p_result = &ret; p_gen.yield_jump(); } void operator()(T &ret) { p_gen.p_result = &ret; p_gen.yield_jump(); } private: generator &p_gen; }; public: using range = generator_range; using yield_type = yielder; generator() = delete; template generator(F func, SA sa = SA{0}): p_func(std::move(func)) { /* that way there is no context creation/stack allocation */ if (!p_func) { return; } this->make_context(sa, &context_call); } generator(generator const &) = delete; generator(generator &&c): p_func(std::move(c.p_func)), p_result(c.p_result) { c.p_func = nullptr; c.p_result = nullptr; } generator &operator=(generator const &) = delete; generator &operator=(generator &&c) { p_func = std::move(c.p_func); p_result = c.p_result; c.p_func = nullptr; c.p_result = nullptr; } ~generator() { if (this->p_state == detail::coroutine_context::state::TERM) { return; } this->unwind(); } explicit operator bool() const { return (this->p_state != detail::coroutine_context::state::TERM); } void resume() { if (this->p_state == detail::coroutine_context::state::TERM) { throw coroutine_error{"dead generator"}; } detail::coroutine_context::call(); } T &value() { if (!p_result) { throw coroutine_error{"no value"}; } return *p_result; } T const &value() const { if (!p_result) { throw coroutine_error{"no value"}; } return *p_result; } bool empty() const { return (!p_result || this->p_state == detail::coroutine_context::state::TERM); } generator_range iter(); void swap(generator &other) { using std::swap; swap(p_func, other.p_func); swap(p_result, other.p_result); detail::coroutine_context::swap(other); } private: /* the main entry point of the generator */ template static void context_call(detail::transfer_t t) { auto &self = *(static_cast(t.data)); self.p_orig = t.ctx; if (self.p_state == detail::coroutine_context::state::INIT) { goto release; } try { self.p_func(yield_type{self}); } catch (detail::coroutine_context::forced_unwind v) { self.p_orig = v.ctx; } catch (...) { self.p_except = std::current_exception(); } release: self.p_state = detail::coroutine_context::state::TERM; self.p_result = nullptr; self.template finish(); } std::function p_func; /* we can use a pointer because even stack values are alive * as long as the coroutine is alive (and it is on every yield) */ T *p_result = nullptr; }; template inline void swap(generator &a, generator &b) { a.swap(b); } template struct generator_range: input_range> { using range_category = input_range_tag; using value_type = T; using reference = T &; using size_type = size_t; using difference_type = stream_off_t; generator_range() = delete; generator_range(generator &g): p_gen(&g) { pop_front(); } generator_range(generator_range const &r): p_gen(r.p_gen) {} bool empty() const { return p_gen->empty(); } void pop_front() { p_gen->resume(); } reference front() const { return p_gen->value(); } bool equals_front(generator_range const &g) const { return p_gen == g.p_gen; } private: generator *p_gen; }; template generator_range generator::iter() { return generator_range{*this}; } } /* namespace ostd */ #endif