diff --git a/examples/coroutine2.cc b/examples/coroutine2.cc index fcdf61c..18203dd 100644 --- a/examples/coroutine2.cc +++ b/examples/coroutine2.cc @@ -14,6 +14,22 @@ int main() { for (int i: generator{f}) { writefln("generated: %s", i); } + + generator x{f}; + /* coroutine_context exists as a base type for any coroutine */ + coroutine_context &ctx = x; + + writefln( + "generator is coroutine: %s", + bool(ctx.coroutine>()) + ); + writefln( + "generator is generator: %s", + bool(ctx.coroutine>()) + ); + + generator &gr = *ctx.coroutine>(); + writefln("value of cast back generator: %s", gr.value()); } /* diff --git a/ostd/coroutine.hh b/ostd/coroutine.hh index c7c3a9f..1370bab 100644 --- a/ostd/coroutine.hh +++ b/ostd/coroutine.hh @@ -254,11 +254,11 @@ namespace detail { } /* namespace detail */ template -struct coroutine: detail::coroutine_context { +struct coroutine: coroutine_context { private: - using base_t = detail::coroutine_context; + using base_t = coroutine_context; /* necessary so that context callback can access privates */ - friend struct detail::coroutine_context; + friend struct coroutine_context; template struct yielder { @@ -389,6 +389,20 @@ public: base_t::swap(other); } + std::type_info const &target_type() const { + return p_stor.p_func.target_type(); + } + + template + F *target() { return p_stor.p_func.target(); } + + template + F const *target() const { return p_stor.p_func.target(); } + + std::type_info const &coroutine_type() const { + return typeid(coroutine); + } + private: void resume_call() { p_stor.call_helper(*this); @@ -409,10 +423,10 @@ namespace detail { } template -struct generator: detail::coroutine_context { +struct generator: coroutine_context { private: - using base_t = detail::coroutine_context; - friend struct detail::coroutine_context; + using base_t = coroutine_context; + friend struct coroutine_context; template struct yielder { @@ -514,7 +528,7 @@ public: if (this->is_dead()) { throw coroutine_error{"dead generator"}; } - detail::coroutine_context::call(); + coroutine_context::call(); } T &value() { @@ -545,7 +559,21 @@ public: using std::swap; swap(p_func, other.p_func); swap(p_result, other.p_result); - detail::coroutine_context::swap(other); + coroutine_context::swap(other); + } + + std::type_info const &target_type() const { + return p_func.target_type(); + } + + template + F *target() { return p_func.target(); } + + template + F const *target() const { return p_func.target(); } + + std::type_info const &coroutine_type() const { + return typeid(generator); } private: diff --git a/ostd/internal/context.hh b/ostd/internal/context.hh index f719641..7a8bff3 100644 --- a/ostd/internal/context.hh +++ b/ostd/internal/context.hh @@ -8,6 +8,7 @@ #include #include +#include #include "ostd/types.hh" #include "ostd/platform.hh" @@ -16,30 +17,50 @@ namespace ostd { namespace detail { -/* from boost.fcontext */ -using fcontext_t = void *; + /* from boost.fcontext */ + using fcontext_t = void *; -struct transfer_t { - fcontext_t ctx; - void *data; -}; + struct transfer_t { + fcontext_t ctx; + void *data; + }; -extern "C" OSTD_EXPORT -transfer_t OSTD_CDECL ostd_jump_fcontext( - fcontext_t const to, void *vp -); + extern "C" OSTD_EXPORT + transfer_t OSTD_CDECL ostd_jump_fcontext( + fcontext_t const to, void *vp + ); -extern "C" OSTD_EXPORT -fcontext_t OSTD_CDECL ostd_make_fcontext( - void *sp, size_t size, void (*fn)(transfer_t) -); + extern "C" OSTD_EXPORT + fcontext_t OSTD_CDECL ostd_make_fcontext( + void *sp, size_t size, void (*fn)(transfer_t) + ); -extern "C" OSTD_EXPORT -transfer_t OSTD_CDECL ostd_ontop_fcontext( - fcontext_t const to, void *vp, transfer_t (*fn)(transfer_t) -); + extern "C" OSTD_EXPORT + transfer_t OSTD_CDECL ostd_ontop_fcontext( + fcontext_t const to, void *vp, transfer_t (*fn)(transfer_t) + ); + +} /* namespace detail */ struct coroutine_context { + virtual std::type_info const &coroutine_type() const = 0; + + template + T *coroutine() { + if (coroutine_type() == typeid(T)) { + return reinterpret_cast(this); + } + return nullptr; + } + + template + T const *coroutine() const { + if (coroutine_type() == typeid(T)) { + return reinterpret_cast(this); + } + return nullptr; + } + protected: coroutine_context() {} ~coroutine_context() { @@ -71,12 +92,12 @@ protected: } void coro_jump() { - p_coro = ostd_jump_fcontext(p_coro, this).ctx; + p_coro = detail::ostd_jump_fcontext(p_coro, this).ctx; } void yield_jump() { p_state = state::HOLD; - p_orig = ostd_jump_fcontext(p_orig, nullptr).ctx; + p_orig = detail::ostd_jump_fcontext(p_orig, nullptr).ctx; } bool is_hold() const { @@ -108,7 +129,7 @@ protected: size_t asize = p_stack.size - (static_cast(p_stack.ptr) - static_cast(sp)); - p_coro = ostd_make_fcontext(sp, asize, &context_call); + p_coro = detail::ostd_make_fcontext(sp, asize, &context_call); new (sp) SA(std::move(sa)); } @@ -131,8 +152,8 @@ private: } struct forced_unwind { - fcontext_t ctx; - forced_unwind(fcontext_t c): ctx(c) {} + detail::fcontext_t ctx; + forced_unwind(detail::fcontext_t c): ctx(c) {} }; enum class state { @@ -154,9 +175,9 @@ private: coro_jump(); return; } - ostd_ontop_fcontext( + detail::ostd_ontop_fcontext( std::exchange(p_coro, nullptr), nullptr, - [](transfer_t t) -> transfer_t { + [](detail::transfer_t t) -> detail::transfer_t { throw forced_unwind{t.ctx}; } ); @@ -165,15 +186,17 @@ private: template void finish() { set_dead(); - ostd_ontop_fcontext(p_orig, this, [](transfer_t t) -> transfer_t { - auto &self = *(static_cast(t.data)); - auto &sa = *(static_cast(self.get_stack_ptr())); - SA dsa{std::move(sa)}; - /* in case it holds any state that needs destroying */ - sa.~SA(); - dsa.deallocate(self.p_stack); - return { nullptr, nullptr }; - }); + detail::ostd_ontop_fcontext( + p_orig, this, [](detail::transfer_t t) -> detail::transfer_t { + auto &self = *(static_cast(t.data)); + auto &sa = *(static_cast(self.get_stack_ptr())); + SA dsa{std::move(sa)}; + /* in case it holds any state that needs destroying */ + sa.~SA(); + dsa.deallocate(self.p_stack); + return { nullptr, nullptr }; + } + ); } template @@ -188,7 +211,7 @@ private: } try { self.resume_call(); - } catch (detail::coroutine_context::forced_unwind v) { + } catch (coroutine_context::forced_unwind v) { /* forced_unwind is unique */ self.p_orig = v.ctx; } catch (...) { @@ -201,13 +224,12 @@ release: } stack_context p_stack; - fcontext_t p_coro; - fcontext_t p_orig; + detail::fcontext_t p_coro; + detail::fcontext_t p_orig; std::exception_ptr p_except; state p_state = state::HOLD; }; -} /* namespace detail */ } /* namespace ostd */ #endif