diff --git a/ostd/coroutine.hh b/ostd/coroutine.hh index 2b3c9f7..6d99b79 100644 --- a/ostd/coroutine.hh +++ b/ostd/coroutine.hh @@ -6,14 +6,7 @@ #ifndef OSTD_COROUTINE_HH #define OSTD_COROUTINE_HH -/* currently there is only POSIX support using obsolete ucontext stuff... - * we will want to implement Windows support using its fibers and also - * lightweight custom context switching with handwritten asm where we - * want this to run, but ucontext will stay as a fallback - */ - #include -#include #include #include @@ -22,6 +15,22 @@ #include #include "ostd/types.hh" +#include "ostd/platform.hh" + +/* from boost.context */ +#ifdef OSTD_PLATFORM_WIN32 +# if (defined(i386) || defined(__i386__) || defined(__i386) || \ + defined(__i486__) || defined(__i586__) || defined(__i686__) || \ + defined(__X86__) || defined(_X86_) || defined(__THW_INTEL__) || \ + defined(__I86__) || defined(__INTEL__) || defined(__IA32__) || \ + defined(_M_IX86) || defined(_I86_)) +# define OSTD_CONTEXT_CDECL __cdecl +# endif +#endif + +#ifndef OSTD_CONTEXT_CDECL +#define OSTD_CONTEXT_CDECL +#endif namespace ostd { @@ -31,65 +40,68 @@ struct coroutine_error: std::runtime_error { using std::runtime_error::runtime_error; }; +namespace detail { + /* from boost.fcontext */ + using fcontext_t = void *; + + struct transfer_t { + fcontext_t ctx; + void *data; + }; + + extern "C" OSTD_EXPORT + transfer_t OSTD_CONTEXT_CDECL ostd_jump_fcontext( + fcontext_t const to, void *vp + ); + + extern "C" OSTD_EXPORT + fcontext_t OSTD_CONTEXT_CDECL ostd_make_fcontext( + void *sp, size_t size, void (*fn)(transfer_t) + ); + + extern "C" OSTD_EXPORT + transfer_t OSTD_CONTEXT_CDECL ostd_ontop_fcontext( + fcontext_t const to, void *vp, transfer_t (*fn)(transfer_t) + ); + + struct forced_unwind { + fcontext_t ctx; + forced_unwind(fcontext_t c): ctx(c) {} + }; +} + struct coroutine_context { coroutine_context(size_t ss, void (*callp)(void *), void *data): - p_stack(new byte[ss]) + p_stack(new byte[ss]), p_callp(callp), p_data(data) { - getcontext(&p_coro); - p_coro.uc_link = &p_orig; - p_coro.uc_stack.ss_sp = p_stack.get(); - p_coro.uc_stack.ss_size = ss; - using mcfp = void (*)(); - using cpfp = void (*)(void *); - if constexpr(sizeof(void *) > sizeof(int)) { - union intu { - struct { int p1, p2; }; - void *p; - cpfp fp; - }; - intu ud, uf; - ud.p = data; - uf.fp = callp; - using amcfp = void (*)(int, int, int, int); - amcfp mcf = [](int f1, int f2, int d1, int d2) -> void { - intu ud2, uf2; - uf2.p1 = f1, uf2.p2 = f2; - ud2.p1 = d1, ud2.p2 = d2; - (uf2.fp)(ud2.p); - }; - makecontext( - &p_coro, reinterpret_cast(mcf), 4, - uf.p1, uf.p2, ud.p1, ud.p2 - ); - } else { - using amcfp = void (*)(int, int); - amcfp mcf = [](int f1, int d1) { - reinterpret_cast(f1)(reinterpret_cast(d1)); - }; - makecontext(&p_coro, reinterpret_cast(mcf), 2, callp, data); - } + p_coro = detail::ostd_make_fcontext(p_stack.get() + ss, ss, context_call); } void call() { if (p_finished) { throw coroutine_error{"dead coroutine"}; } - swapcontext(&p_orig, &p_coro); + auto t = detail::ostd_jump_fcontext(p_coro, this); + p_coro = t.ctx; if (p_except) { std::rethrow_exception(std::move(p_except)); } } - void yield_swap() { - swapcontext(&p_coro, &p_orig); + void unwind() { + if (p_finished) { + return; + } + detail::ostd_ontop_fcontext( + std::exchange(p_coro, nullptr), nullptr, + [](detail::transfer_t t) -> detail::transfer_t { + throw detail::forced_unwind{t.ctx}; + } + ); } - void set_eh() { - p_except = std::current_exception(); - } - - void set_done() { - p_finished = true; + void yield_jump() { + p_coro = detail::ostd_jump_fcontext(p_orig, nullptr).ctx; } bool is_done() const { @@ -97,11 +109,27 @@ struct coroutine_context { } private: + static void context_call(detail::transfer_t t) { + auto &self = *(static_cast(t.data)); + self.p_orig = t.ctx; + try { + self.p_callp(self.p_data); + } catch (detail::forced_unwind v) { + self.p_orig = v.ctx; + } catch (...) { + self.p_except = std::current_exception(); + } + self.p_finished = true; + self.yield_jump(); + } + /* TODO: new'ing the stack is sub-optimal */ std::unique_ptr p_stack; - ucontext_t p_coro; - ucontext_t p_orig; + detail::fcontext_t p_coro; + detail::fcontext_t p_orig; std::exception_ptr p_except; + void (*p_callp)(void *); + void *p_data; bool p_finished = false; }; @@ -118,12 +146,12 @@ namespace detail { std::tuple yield(R &&ret) { p_result = std::forward(ret); - p_ctx.yield_swap(); + p_ctx.yield_jump(); return std::move(p_args); } protected: - R ctx_call(A ...args) { + R call(A ...args) { p_args = std::forward_as_tuple(std::forward(args)...); p_ctx.call(); return std::forward(p_result); @@ -141,12 +169,12 @@ namespace detail { {} std::tuple yield() { - p_ctx.yield_swap(); + p_ctx.yield_jump(); return std::move(p_args); } protected: - void ctx_call(A ...args) { + void call(A ...args) { p_args = std::forward_as_tuple(std::forward(args)...); p_ctx.call(); } @@ -162,35 +190,34 @@ struct coroutine: detail::coro_base { std::function &, A...)> func, size_t ss = COROUTINE_DEFAULT_STACK_SIZE ): - detail::coro_base(&ctx_func, ss), p_func(std::move(func)) + detail::coro_base(&context_call, ss), p_func(std::move(func)) {} + ~coroutine() { + this->p_ctx.unwind(); + } + operator bool() const { return this->p_ctx.is_done(); } R operator()(A ...args) { - return this->ctx_call(std::forward(args)...); + return this->call(std::forward(args)...); } private: template - R call(std::index_sequence) { + R call_helper(std::index_sequence) { return p_func(*this, std::forward(std::get(this->p_args))...); } - static void ctx_func(void *data) { + static void context_call(void *data) { + using indices = std::index_sequence_for; coroutine &self = *(static_cast(data)); - try { - using indices = std::index_sequence_for; - if constexpr(std::is_same_v) { - self.call(indices{}); - } else { - self.p_result = self.call(indices{}); - } - } catch (...) { - self.p_ctx.set_eh(); + if constexpr(std::is_same_v) { + self.call_helper(indices{}); + } else { + self.p_result = self.call_helper(indices{}); } - self.p_ctx.set_done(); } std::function &, A...)> p_func;