diff --git a/ostd/coroutine.hh b/ostd/coroutine.hh index 2428947..36625e7 100644 --- a/ostd/coroutine.hh +++ b/ostd/coroutine.hh @@ -43,6 +43,12 @@ struct coroutine_error: std::runtime_error { using std::runtime_error::runtime_error; }; +template +struct coroutine; + +template +struct coroutine_range; + namespace detail { /* from boost.fcontext */ using fcontext_t = void *; @@ -73,19 +79,15 @@ namespace detail { }; struct coroutine_context { - coroutine_context(size_t ss, void (*callp)(void *)): - p_stack(new byte[ss]), p_callp(callp) - { - p_coro = ostd_make_fcontext(p_stack.get() + ss, ss, context_call); - } + protected: + coroutine_context() {} coroutine_context(coroutine_context const &) = delete; coroutine_context(coroutine_context &&c): p_stack(std::move(c.p_stack)), p_coro(c.p_coro), p_orig(c.p_orig), - p_except(std::move(c.p_except)), p_callp(c.p_callp) + p_except(std::move(c.p_except)) { c.p_coro = c.p_orig = nullptr; - c.p_callp = nullptr; } coroutine_context &operator=(coroutine_context const &) = delete; @@ -123,21 +125,11 @@ namespace detail { std::swap(p_coro, other.p_coro); std::swap(p_orig, other.p_orig); std::swap(p_except, other.p_except); - std::swap(p_callp, other.p_callp); } - private: - static void context_call(transfer_t t) { - auto &self = *(static_cast(t.data)); - self.p_orig = t.ctx; - try { - self.p_callp(t.data); - } catch (forced_unwind v) { - self.p_orig = v.ctx; - } catch (...) { - self.p_except = std::current_exception(); - } - self.yield_jump(); + void make_context(size_t ss, void (*callp)(transfer_t)) { + p_stack = std::unique_ptr{new byte[ss]}; + p_coro = ostd_make_fcontext(p_stack.get() + ss, ss, callp); } /* TODO: new'ing the stack is sub-optimal */ @@ -145,17 +137,8 @@ namespace detail { fcontext_t p_coro; fcontext_t p_orig; std::exception_ptr p_except; - void (*p_callp)(void *); }; -} -template -struct coroutine; - -template -struct coroutine_range; - -namespace detail { /* like reference_wrapper but for any value */ template struct arg_wrapper { @@ -255,11 +238,9 @@ namespace detail { /* default case, yield returns args and takes a value */ template - struct coro_base: detail::coroutine_context { + struct coro_base: coroutine_context { protected: - coro_base(void (*callp)(void *), size_t ss): - detail::coroutine_context(ss, callp) - {} + coro_base() {} coro_base(coro_base const &) = delete; coro_base(coro_base &&c) = default; @@ -290,14 +271,14 @@ namespace detail { R call(A ...args) { p_args = std::make_tuple(arg_wrapper(std::forward(args))...); - detail::coroutine_context::call(); + coroutine_context::call(); return std::forward(p_result); } void swap(coro_base &other) { std::swap(p_args, other.p_args); std::swap(p_result, other.p_result); - detail::coroutine_context::swap(other); + coroutine_context::swap(other); } std::tuple...> p_args; @@ -306,13 +287,11 @@ namespace detail { /* yield takes a value but doesn't return any args */ template - struct coro_base: detail::coroutine_context { + struct coro_base: coroutine_context { coroutine_range iter(); protected: - coro_base(void (*callp)(void *), size_t ss): - detail::coroutine_context(ss, callp) - {} + coro_base() {} coro_base(coro_base const &) = delete; coro_base(coro_base &&c) = default; @@ -337,13 +316,13 @@ namespace detail { } R call() { - detail::coroutine_context::call(); + coroutine_context::call(); return std::forward(this->p_result); } void swap(coro_base &other) { std::swap(p_result, other.p_result); - detail::coroutine_context::swap(other); + coroutine_context::swap(other); } arg_wrapper p_result; @@ -351,11 +330,9 @@ namespace detail { /* yield doesn't take a value and returns args */ template - struct coro_base: detail::coroutine_context { + struct coro_base: coroutine_context { protected: - coro_base(void (*callp)(void *), size_t ss): - detail::coroutine_context(ss, callp) - {} + coro_base() {} coro_base(coro_base const &) = delete; coro_base(coro_base &&c) = default; @@ -383,12 +360,12 @@ namespace detail { void call(A ...args) { p_args = std::make_tuple(arg_wrapper(std::forward(args))...); - detail::coroutine_context::call(); + coroutine_context::call(); } void swap(coro_base &other) { std::swap(p_args, other.p_args); - detail::coroutine_context::swap(other); + coroutine_context::swap(other); } std::tuple...> p_args; @@ -396,11 +373,9 @@ namespace detail { /* yield doesn't take a value or return any args */ template<> - struct coro_base: detail::coroutine_context { + struct coro_base: coroutine_context { protected: - coro_base(void (*callp)(void *), size_t ss): - detail::coroutine_context(ss, callp) - {} + coro_base() {} coro_base(coro_base const &) = delete; coro_base(coro_base &&c) = default; @@ -424,11 +399,11 @@ namespace detail { } void call() { - detail::coroutine_context::call(); + coroutine_context::call(); } void swap(coro_base &other) { - detail::coroutine_context::swap(other); + coroutine_context::swap(other); } }; } /* namespace detail */ @@ -443,8 +418,13 @@ public: template coroutine(F func, size_t ss = COROUTINE_DEFAULT_STACK_SIZE): - detail::coro_base(&context_call, ss), p_func(std::move(func)) - {} + p_func(std::move(func)) + { + if (!p_func) { + return; + } + this->make_context(ss, &context_call); + } coroutine(coroutine const &) = delete; coroutine(coroutine &&c): @@ -489,10 +469,18 @@ public: } private: - static void context_call(void *data) { - coroutine &self = *(static_cast(data)); - self.call_helper(self.p_func, std::index_sequence_for{}); + static void context_call(detail::transfer_t t) { + auto &self = *(static_cast(t.data)); + self.p_orig = t.ctx; + try { + self.call_helper(self.p_func, std::index_sequence_for{}); + } catch (detail::forced_unwind v) { + self.p_orig = v.ctx; + } catch (...) { + self.p_except = std::current_exception(); + } self.p_func = nullptr; + self.yield_jump(); } std::function p_func;