From 8ecaa338bc2485baddb7bf8ddc054e1682c4f0d6 Mon Sep 17 00:00:00 2001 From: q66 Date: Sat, 4 Mar 2017 18:25:09 +0100 Subject: [PATCH] add initial coroutine module (slow ucontext_t, POSIX only, WiP) --- ostd/coroutine.hh | 201 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 201 insertions(+) create mode 100644 ostd/coroutine.hh diff --git a/ostd/coroutine.hh b/ostd/coroutine.hh new file mode 100644 index 0000000..2b3c9f7 --- /dev/null +++ b/ostd/coroutine.hh @@ -0,0 +1,201 @@ +/* Coroutines for OctaSTD. + * + * This file is part of OctaSTD. See COPYING.md for further information. + */ + +#ifndef OSTD_COROUTINE_HH +#define OSTD_COROUTINE_HH + +/* currently there is only POSIX support using obsolete ucontext stuff... + * we will want to implement Windows support using its fibers and also + * lightweight custom context switching with handwritten asm where we + * want this to run, but ucontext will stay as a fallback + */ + +#include +#include + +#include +#include +#include +#include +#include + +#include "ostd/types.hh" + +namespace ostd { + +constexpr size_t COROUTINE_DEFAULT_STACK_SIZE = SIGSTKSZ; + +struct coroutine_error: std::runtime_error { + using std::runtime_error::runtime_error; +}; + +struct coroutine_context { + coroutine_context(size_t ss, void (*callp)(void *), void *data): + p_stack(new byte[ss]) + { + getcontext(&p_coro); + p_coro.uc_link = &p_orig; + p_coro.uc_stack.ss_sp = p_stack.get(); + p_coro.uc_stack.ss_size = ss; + using mcfp = void (*)(); + using cpfp = void (*)(void *); + if constexpr(sizeof(void *) > sizeof(int)) { + union intu { + struct { int p1, p2; }; + void *p; + cpfp fp; + }; + intu ud, uf; + ud.p = data; + uf.fp = callp; + using amcfp = void (*)(int, int, int, int); + amcfp mcf = [](int f1, int f2, int d1, int d2) -> void { + intu ud2, uf2; + uf2.p1 = f1, uf2.p2 = f2; + ud2.p1 = d1, ud2.p2 = d2; + (uf2.fp)(ud2.p); + }; + makecontext( + &p_coro, reinterpret_cast(mcf), 4, + uf.p1, uf.p2, ud.p1, ud.p2 + ); + } else { + using amcfp = void (*)(int, int); + amcfp mcf = [](int f1, int d1) { + reinterpret_cast(f1)(reinterpret_cast(d1)); + }; + makecontext(&p_coro, reinterpret_cast(mcf), 2, callp, data); + } + } + + void call() { + if (p_finished) { + throw coroutine_error{"dead coroutine"}; + } + swapcontext(&p_orig, &p_coro); + if (p_except) { + std::rethrow_exception(std::move(p_except)); + } + } + + void yield_swap() { + swapcontext(&p_coro, &p_orig); + } + + void set_eh() { + p_except = std::current_exception(); + } + + void set_done() { + p_finished = true; + } + + bool is_done() const { + return p_finished; + } + +private: + /* TODO: new'ing the stack is sub-optimal */ + std::unique_ptr p_stack; + ucontext_t p_coro; + ucontext_t p_orig; + std::exception_ptr p_except; + bool p_finished = false; +}; + +template +struct coroutine; + +namespace detail { + /* we need this because yield is specialized based on result */ + template + struct coro_base { + coro_base(void (*callp)(void *), size_t ss): + p_ctx(ss, callp, this) + {} + + std::tuple yield(R &&ret) { + p_result = std::forward(ret); + p_ctx.yield_swap(); + return std::move(p_args); + } + + protected: + R ctx_call(A ...args) { + p_args = std::forward_as_tuple(std::forward(args)...); + p_ctx.call(); + return std::forward(p_result); + } + + std::tuple p_args; + R p_result; + coroutine_context p_ctx; + }; + + template + struct coro_base { + coro_base(void (*callp)(void *), size_t ss): + p_ctx(ss, callp, this) + {} + + std::tuple yield() { + p_ctx.yield_swap(); + return std::move(p_args); + } + + protected: + void ctx_call(A ...args) { + p_args = std::forward_as_tuple(std::forward(args)...); + p_ctx.call(); + } + + std::tuple p_args; + coroutine_context p_ctx; + }; +} /* namespace detail */ + +template +struct coroutine: detail::coro_base { + coroutine( + std::function &, A...)> func, + size_t ss = COROUTINE_DEFAULT_STACK_SIZE + ): + detail::coro_base(&ctx_func, ss), p_func(std::move(func)) + {} + + operator bool() const { + return this->p_ctx.is_done(); + } + + R operator()(A ...args) { + return this->ctx_call(std::forward(args)...); + } +private: + template + R call(std::index_sequence) { + return p_func(*this, std::forward(std::get(this->p_args))...); + } + + static void ctx_func(void *data) { + coroutine &self = *(static_cast(data)); + try { + using indices = std::index_sequence_for; + if constexpr(std::is_same_v) { + self.call(indices{}); + } else { + self.p_result = self.call(indices{}); + } + } catch (...) { + self.p_ctx.set_eh(); + } + self.p_ctx.set_done(); + } + + std::function &, A...)> p_func; +}; + +} /* namespace ostd */ + +#endif