diff --git a/examples/coroutine1.cc b/examples/coroutine1.cc index b29e8c4..ba4cb5f 100644 --- a/examples/coroutine1.cc +++ b/examples/coroutine1.cc @@ -37,13 +37,7 @@ int main() { int val = 5; for (int i: range(steps)) { writeln(" calling into coroutine..."); - int v; - try { - v = f(val); - } catch (coroutine_error const &e) { - writefln("coroutine error: %s", e.what()); - return 0; - } + auto v = f(val); writefln(" called into coroutine which yielded: %s", v); writefln(" call loop iteration %s done", i + 1); writefln(" coroutine dead: %s", !f); @@ -291,5 +285,6 @@ starting main... call loop iteration 6 done coroutine dead: true calling into coroutine... -coroutine error: dead coroutine +terminating with uncaught exception of type ostd::coroutine_error: dead coroutine +zsh: abort ./coro */ diff --git a/ostd/coroutine.hh b/ostd/coroutine.hh index 7248a21..2a9e456 100644 --- a/ostd/coroutine.hh +++ b/ostd/coroutine.hh @@ -267,14 +267,12 @@ namespace detail { coro_args operator()(R &&ret) { p_coro.p_result = std::forward(ret); - p_coro.set_hold(); p_coro.yield_jump(); return p_coro.get_args(std::make_index_sequence{}); } coro_args operator()(R const &ret) { p_coro.p_result = ret; - p_coro.set_hold(); p_coro.yield_jump(); return p_coro.get_args(std::make_index_sequence{}); } @@ -289,7 +287,6 @@ namespace detail { coro_args operator()(R &ret) { p_coro.p_result = ret; - p_coro.set_hold(); p_coro.yield_jump(); return p_coro.get_args(std::make_index_sequence{}); } @@ -304,7 +301,6 @@ namespace detail { coro_args operator()(R &&ret) { p_coro.p_result = std::forward(ret); - p_coro.set_hold(); p_coro.yield_jump(); return p_coro.get_args(std::make_index_sequence{}); } @@ -318,7 +314,6 @@ namespace detail { coro_yielder(coro_base &coro): p_coro(coro) {} coro_args operator()() { - p_coro.set_hold(); p_coro.yield_jump(); return p_coro.get_args(std::make_index_sequence{}); } @@ -424,8 +419,7 @@ private: } /* switch back, release stack */ release: - self.set_dead(); - self.yield_jump(); + self.template finish(); } std::function p_func; @@ -452,13 +446,11 @@ private: void operator()(T &&ret) { p_gen.p_result = &ret; - p_gen.set_hold(); p_gen.yield_jump(); } void operator()(T &ret) { p_gen.p_result = &ret; - p_gen.set_hold(); p_gen.yield_jump(); } private: @@ -574,8 +566,7 @@ private: } release: self.p_result = nullptr; - self.set_dead(); - self.yield_jump(); + self.template finish(); } std::function p_func; diff --git a/ostd/internal/context.hh b/ostd/internal/context.hh index b43d66e..0a754b3 100644 --- a/ostd/internal/context.hh +++ b/ostd/internal/context.hh @@ -56,12 +56,11 @@ protected: coroutine_context(coroutine_context &&c): p_stack(std::move(c.p_stack)), p_coro(c.p_coro), p_orig(c.p_orig), p_except(std::move(c.p_except)), p_state(std::move(c.p_state)), - p_sa(c.p_sa), p_free(c.p_free) + p_sa(c.p_sa) { c.p_coro = c.p_orig = nullptr; c.p_stack = { nullptr, 0 }; c.p_sa = nullptr; - c.p_free = nullptr; c.set_dead(); } @@ -80,21 +79,37 @@ protected: } void unwind() { - if (is_dead() || !p_orig) { - /* this coroutine is either done or never started */ - free_stack(); + if (is_dead()) { + /* this coroutine was either initialized with a null function or + * it's already terminated and thus its stack has already unwound + */ + return; + } + if (!p_orig) { + /* this coroutine never got to live :( + * let it call the entry point at least this once... + * this will kill the stack so we don't leak memory + */ + coro_jump(); return; } - /* this coroutine is suspended, and we need to make sure all - * the destructors for its locals get called by forcing unwind - */ ostd_ontop_fcontext( std::exchange(p_coro, nullptr), nullptr, [](transfer_t t) -> transfer_t { throw forced_unwind{t.ctx}; } ); - free_stack(); + } + + template + void finish() { + set_dead(); + ostd_ontop_fcontext(p_orig, this, [](transfer_t t) -> transfer_t { + auto &self = *(static_cast(t.data)); + auto &sa = *(static_cast(self.p_sa)); + sa.deallocate(self.p_stack); + return { nullptr, nullptr }; + }); } void coro_jump() { @@ -102,6 +117,7 @@ protected: } void yield_jump() { + p_state = state::HOLD; p_orig = ostd_jump_fcontext(p_orig, nullptr).ctx; } @@ -113,10 +129,6 @@ protected: return (p_state == state::TERM); } - void set_hold() { - p_state = state::HOLD; - } - void set_dead() { p_state = state::TERM; } @@ -129,7 +141,6 @@ protected: swap(p_except, other.p_except); swap(p_state, other.p_state); swap(p_sa, other.p_sa); - swap(p_free, other.p_free); } template @@ -146,20 +157,6 @@ protected: p_coro = ostd_make_fcontext(sp, asize, callp); p_sa = new (sp) SA(std::move(sa)); - p_free = &free_stack_call; - } - - template - static void free_stack_call(void *data) { - auto &self = *(static_cast(data)); - auto &sa = *(static_cast(self.p_sa)); - sa.deallocate(self.p_stack); - } - - void free_stack() { - if (p_free) { - p_free(this); - } } stack_context p_stack; @@ -168,7 +165,6 @@ protected: std::exception_ptr p_except; state p_state = state::HOLD; void *p_sa = nullptr; - void (*p_free)(void *) = nullptr; }; } /* namespace detail */