diff --git a/ostd/coroutine.hh b/ostd/coroutine.hh index 46734e5..e8a8f60 100644 --- a/ostd/coroutine.hh +++ b/ostd/coroutine.hh @@ -310,7 +310,7 @@ public: } ~coroutine() { - if (!p_func) { + if (this->p_state == detail::coroutine_context::state::TERM) { /* the stack has already unwound by a normal return */ return; } @@ -318,11 +318,11 @@ public: } explicit operator bool() const { - return bool(p_func); + return (this->p_state != detail::coroutine_context::state::TERM); } R resume(A ...args) { - if (!p_func) { + if (this->p_state == detail::coroutine_context::state::TERM) { throw coroutine_error{"dead coroutine"}; } return this->call(std::forward(args)...); @@ -342,6 +342,10 @@ private: static void context_call(detail::transfer_t t) { auto &self = *(static_cast(t.data)); self.p_orig = t.ctx; + if (self.p_state == detail::coroutine_context::state::INIT) { + /* we never got to execute properly */ + goto release; + } try { self.call_helper(self.p_func, std::index_sequence_for{}); } catch (detail::coroutine_context::forced_unwind v) { @@ -351,12 +355,10 @@ private: /* some other exception, will be rethrown later */ self.p_except = std::current_exception(); } - /* the func has fully finished here, so mark dead, stack - * will be freed by the coroutine's destructor later - */ - self.p_func = nullptr; - /* perform a last switch back to original context */ - self.yield_jump(); + /* switch back, release stack */ +release: + self.p_state = detail::coroutine_context::state::TERM; + self.finish(); } std::function p_func; diff --git a/ostd/internal/context.hh b/ostd/internal/context.hh index dba94dd..b1141b8 100644 --- a/ostd/internal/context.hh +++ b/ostd/internal/context.hh @@ -71,19 +71,20 @@ protected: forced_unwind(fcontext_t c): ctx(c) {} }; - coroutine_context() {} + enum class state { + INIT = 0, HOLD, EXEC, TERM + }; - ~coroutine_context() { - context_stack_free(p_stack); - } + coroutine_context() {} coroutine_context(coroutine_context const &) = delete; coroutine_context(coroutine_context &&c): p_stack(std::move(c.p_stack)), p_coro(c.p_coro), p_orig(c.p_orig), - p_except(std::move(c.p_except)) + p_except(std::move(c.p_except)), p_state(std::move(c.p_state)) { c.p_coro = c.p_orig = nullptr; c.p_stack = { nullptr, 0 }; + c.p_state = state::TERM; } coroutine_context &operator=(coroutine_context const &) = delete; @@ -93,6 +94,7 @@ protected: } void call() { + p_state = state::EXEC; coro_jump(); if (p_except) { std::rethrow_exception(std::move(p_except)); @@ -100,6 +102,14 @@ protected: } void unwind() { + if (p_state == state::INIT) { + /* this coroutine never got to live :( + * let it call the entry point at least this once... + * this will kill the stack so we don't leak memory + */ + coro_jump(); + return; + } ostd_ontop_fcontext( std::exchange(p_coro, nullptr), nullptr, [](transfer_t t) -> transfer_t { @@ -108,11 +118,20 @@ protected: ); } + void finish() { + ostd_ontop_fcontext(p_orig, this, [](transfer_t t) -> transfer_t { + auto &self = *(static_cast(t.data)); + context_stack_free(self.p_stack); + return { nullptr, nullptr }; + }); + } + void coro_jump() { p_coro = ostd_jump_fcontext(p_coro, this).ctx; } void yield_jump() { + p_state = state::HOLD; p_orig = ostd_jump_fcontext(p_orig, nullptr).ctx; } @@ -134,6 +153,7 @@ protected: fcontext_t p_coro; fcontext_t p_orig; std::exception_ptr p_except; + state p_state = state::INIT; }; /* stack allocator */