/* Coroutines for OctaSTD. * * This file is part of OctaSTD. See COPYING.md for further information. */ #ifndef OSTD_COROUTINE_HH #define OSTD_COROUTINE_HH #include #include #include #include #include #include #include #include #include "ostd/types.hh" #include "ostd/platform.hh" #include "ostd/range.hh" /* from boost.context */ #ifdef OSTD_PLATFORM_WIN32 # if (defined(i386) || defined(__i386__) || defined(__i386) || \ defined(__i486__) || defined(__i586__) || defined(__i686__) || \ defined(__X86__) || defined(_X86_) || defined(__THW_INTEL__) || \ defined(__I86__) || defined(__INTEL__) || defined(__IA32__) || \ defined(_M_IX86) || defined(_I86_)) # define OSTD_CONTEXT_CDECL __cdecl # endif #endif #ifndef OSTD_CONTEXT_CDECL #define OSTD_CONTEXT_CDECL #endif namespace ostd { constexpr size_t COROUTINE_DEFAULT_STACK_SIZE = SIGSTKSZ; struct coroutine_error: std::runtime_error { using std::runtime_error::runtime_error; }; namespace detail { /* from boost.fcontext */ using fcontext_t = void *; struct transfer_t { fcontext_t ctx; void *data; }; extern "C" OSTD_EXPORT transfer_t OSTD_CONTEXT_CDECL ostd_jump_fcontext( fcontext_t const to, void *vp ); extern "C" OSTD_EXPORT fcontext_t OSTD_CONTEXT_CDECL ostd_make_fcontext( void *sp, size_t size, void (*fn)(transfer_t) ); extern "C" OSTD_EXPORT transfer_t OSTD_CONTEXT_CDECL ostd_ontop_fcontext( fcontext_t const to, void *vp, transfer_t (*fn)(transfer_t) ); struct forced_unwind { fcontext_t ctx; forced_unwind(fcontext_t c): ctx(c) {} }; } struct coroutine_context { coroutine_context(size_t ss, void (*callp)(void *), void *data): p_stack(new byte[ss]), p_callp(callp), p_data(data) { p_coro = detail::ostd_make_fcontext(p_stack.get() + ss, ss, context_call); } coroutine_context(coroutine_context const &) = delete; coroutine_context(coroutine_context &&c): p_stack(std::move(c.p_stack)), p_coro(c.p_coro), p_orig(c.p_orig), p_except(std::move(c.p_except)), p_callp(c.p_callp), p_data(c.p_data), p_finished(c.p_finished) { c.p_coro = c.p_orig = nullptr; c.p_data = nullptr; c.p_callp = nullptr; /* make sure it's not unwound */ c.p_finished = true; } coroutine_context &operator=(coroutine_context const &) = delete; coroutine_context &operator=(coroutine_context &&c) { swap(c); /* make sure it's not unwound */ c.p_finished = true; return *this; } void call() { if (p_finished) { throw coroutine_error{"dead coroutine"}; } coro_jump(); if (p_except) { std::rethrow_exception(std::move(p_except)); } } void unwind() { if (p_finished) { return; } detail::ostd_ontop_fcontext( std::exchange(p_coro, nullptr), nullptr, [](detail::transfer_t t) -> detail::transfer_t { throw detail::forced_unwind{t.ctx}; } ); } void coro_jump() { p_coro = detail::ostd_jump_fcontext(p_coro, this).ctx; } void yield_jump() { p_orig = detail::ostd_jump_fcontext(p_orig, nullptr).ctx; } bool is_done() const { return p_finished; } void set_data(void *data) { p_data = data; } void swap(coroutine_context &other) noexcept { std::swap(p_stack, other.p_stack); std::swap(p_coro, other.p_coro); std::swap(p_orig, other.p_orig); std::swap(p_except, other.p_except); std::swap(p_callp, other.p_callp); std::swap(p_data, other.p_data); std::swap(p_finished, other.p_finished); } private: static void context_call(detail::transfer_t t) { auto &self = *(static_cast(t.data)); self.p_orig = t.ctx; try { self.p_callp(self.p_data); } catch (detail::forced_unwind v) { self.p_orig = v.ctx; } catch (...) { self.p_except = std::current_exception(); } self.p_finished = true; self.yield_jump(); } /* TODO: new'ing the stack is sub-optimal */ std::unique_ptr p_stack; detail::fcontext_t p_coro; detail::fcontext_t p_orig; std::exception_ptr p_except; void (*p_callp)(void *); void *p_data; bool p_finished = false; }; inline void swap(coroutine_context &a, coroutine_context &b) { a.swap(b); } template struct coroutine; namespace detail { /* like reference_wrapper but for any value */ template struct arg_wrapper { arg_wrapper() = default; arg_wrapper(T arg): p_arg(std::move(arg)) {} void operator=(T arg) { p_arg = std::move(arg); } operator T &&() { return std::move(p_arg); } void swap(arg_wrapper &other) { std::swap(p_arg, other.p_arg); } private: T p_arg = T{}; }; template struct arg_wrapper { arg_wrapper() = default; arg_wrapper(T &&arg): p_arg(&arg) {} void operator=(T &&arg) { p_arg = &arg; } operator T &&() { return *p_arg; } void swap(arg_wrapper &other) { std::swap(p_arg, other.p_arg); } private: T *p_arg = nullptr; }; template struct arg_wrapper { arg_wrapper() = default; arg_wrapper(T &arg): p_arg(&arg) {} void operator=(T &arg) { p_arg = &arg; } operator T &() { return *p_arg; } void swap(arg_wrapper &other) { std::swap(p_arg, other.p_arg); } private: T *p_arg = nullptr; }; template inline void swap(arg_wrapper &a, arg_wrapper &b) { a.swap(b); } template struct coro_types { using yield_type = std::tuple; }; template struct coro_types { using yield_type = A; }; template struct coro_types { using yield_type = std::pair; }; template using coro_args = typename coro_types::yield_type; template inline coro_args yield_ret( std::tuple...> &args, std::index_sequence ) { if constexpr(sizeof...(A) == 1) { return std::forward(std::get<0>(args)); } else if constexpr(sizeof...(A) == 2) { return std::make_pair(std::forward(std::get(args))...); } else { return std::move(args); } } /* default case, yield returns args and takes a value */ template struct coro_base { protected: coro_base(void (*callp)(void *), size_t ss): p_ctx(ss, callp, this) {} coro_base(coro_base const &) = delete; coro_base(coro_base &&c): p_args(std::move(c.p_args)), p_result(std::move(c.p_result)), p_ctx(std::move(c.p_ctx)) { p_ctx.set_data(this); } coro_base &operator=(coro_base const &) = delete; coro_base &operator=(coro_base &&c) { std::swap(p_args, c.p_args); std::swap(p_result, c.p_result); std::swap(p_ctx, c.p_ctx); p_ctx.set_data(this); return *this; } struct yielder { yielder(coro_base &coro): p_coro(coro) {} coro_args operator()(R &&ret) { p_coro.p_result = std::forward(ret); p_coro.p_ctx.yield_jump(); return yield_ret( p_coro.p_args, std::make_index_sequence{} ); } private: coro_base &p_coro; }; template void call_helper(F &func, std::index_sequence) { p_result = std::forward( func(yielder{*this}, std::forward(std::get(p_args))...) ); } R call(A ...args) { p_args = std::make_tuple(arg_wrapper(std::forward(args))...); p_ctx.call(); return std::forward(p_result); } void swap(coro_base &other) { std::swap(p_args, other.p_args); std::swap(p_result, other.p_result); std::swap(p_ctx, other.p_ctx); } std::tuple...> p_args; arg_wrapper p_result; coroutine_context p_ctx; }; /* yield takes a value but doesn't return any args */ template struct coro_base { protected: coro_base(void (*callp)(void *), size_t ss): p_ctx(ss, callp, this) {} coro_base(coro_base const &) = delete; coro_base(coro_base &&c): p_result(std::move(c.p_result)), p_ctx(std::move(c.p_ctx)) { p_ctx.set_data(this); } coro_base &operator=(coro_base const &) = delete; coro_base &operator=(coro_base &&c) { std::swap(p_result, c.p_result); std::swap(p_ctx, c.p_ctx); p_ctx.set_data(this); return *this; } struct yielder { yielder(coro_base &coro): p_coro(coro) {} void operator()(R &&ret) { p_coro.p_result = std::forward(ret); p_coro.p_ctx.yield_jump(); } private: coro_base &p_coro; }; template void call_helper(F &func, std::index_sequence) { p_result = std::forward(func(yielder{*this})); } R call() { p_ctx.call(); return std::forward(this->p_result); } void swap(coro_base &other) { std::swap(p_result, other.p_result); std::swap(p_ctx, other.p_ctx); } arg_wrapper p_result; coroutine_context p_ctx; }; /* yield doesn't take a value and returns args */ template struct coro_base { protected: coro_base(void (*callp)(void *), size_t ss): p_ctx(ss, callp, this) {} coro_base(coro_base const &) = delete; coro_base(coro_base &&c): p_args(std::move(c.p_args)), p_ctx(std::move(c.p_ctx)) { p_ctx.set_data(this); } coro_base &operator=(coro_base const &) = delete; coro_base &operator=(coro_base &&c) { std::swap(p_args, c.p_args); std::swap(p_ctx, c.p_ctx); p_ctx.set_data(this); return *this; } struct yielder { yielder(coro_base &coro): p_coro(coro) {} coro_args operator()() { p_coro.p_ctx.yield_jump(); return yield_ret( p_coro.p_args, std::make_index_sequence{} ); } private: coro_base &p_coro; }; template void call_helper(F &func, std::index_sequence) { func(yielder{*this}, std::forward(std::get(p_args))...); } void call(A ...args) { p_args = std::make_tuple(arg_wrapper(std::forward(args))...); p_ctx.call(); } void swap(coro_base &other) { std::swap(p_args, other.p_args); std::swap(p_ctx, other.p_ctx); } std::tuple...> p_args; coroutine_context p_ctx; }; /* yield doesn't take a value or return any args */ template<> struct coro_base { protected: coro_base(void (*callp)(void *), size_t ss): p_ctx(ss, callp, this) {} coro_base(coro_base const &) = delete; coro_base(coro_base &&c): p_ctx(std::move(c.p_ctx)) { p_ctx.set_data(this); } coro_base &operator=(coro_base const &) = delete; coro_base &operator=(coro_base &&c) { std::swap(p_ctx, c.p_ctx); p_ctx.set_data(this); return *this; } struct yielder { yielder(coro_base &coro): p_coro(coro) {} void operator()() { p_coro.p_ctx.yield_jump(); } private: coro_base &p_coro; }; template void call_helper(F &func, std::index_sequence) { func(yielder{*this}); } void call() { p_ctx.call(); } void swap(coro_base &other) { std::swap(p_ctx, other.p_ctx); } coroutine_context p_ctx; }; } /* namespace detail */ template struct coroutine: detail::coro_base { private: using base_t = detail::coro_base; public: using yield_type = typename detail::coro_base::yielder; template coroutine(F func, size_t ss = COROUTINE_DEFAULT_STACK_SIZE): detail::coro_base(&context_call, ss), p_func(std::move(func)) {} coroutine(coroutine const &) = delete; coroutine(coroutine &&) = default; coroutine &operator=(coroutine const &) = delete; coroutine &operator=(coroutine &&) = default; ~coroutine() { this->p_ctx.unwind(); } operator bool() const { return !this->p_ctx.is_done(); } R resume(A ...args) { return this->call(std::forward(args)...); } R operator()(A ...args) { return this->call(std::forward(args)...); } void swap(coroutine &other) { std::swap(p_func, other.p_func); base_t::swap(other); } private: static void context_call(void *data) { coroutine &self = *(static_cast(data)); self.call_helper(self.p_func, std::index_sequence_for{}); } std::function p_func; }; template inline void swap(coroutine &a, coroutine &b) { a.swap(b); } template struct generator: input_range> { using range_category = input_range_tag; using value_type = T; using reference = T &; using size_type = size_t; using difference_type = stream_off_t; generator() = default; template generator(F &&func, size_t ss = COROUTINE_DEFAULT_STACK_SIZE): p_ptr(new coroutine{std::forward(func), ss}) { pop_front(); } bool empty() const { return !p_item; } void pop_front() { if (*p_ptr) { p_item = (*p_ptr)(); } else { p_item = std::nullopt; } } reference front() const { return p_item.value(); } bool equals_front(generator const &g) { return p_ptr == g.p_ptr; } private: std::shared_ptr> p_ptr; mutable std::optional p_item; }; } /* namespace ostd */ #endif