diff --git a/examples/coroutine2.cc b/examples/coroutine2.cc index 54ee855..eac0406 100644 --- a/examples/coroutine2.cc +++ b/examples/coroutine2.cc @@ -4,12 +4,10 @@ using namespace ostd; int main() { - coroutine g = [](auto yield) { - yield(5); - yield(10); - yield(15); - yield(20); - return 25; + generator g = [](auto yield) -> void { + for (int i: range(5, 26, 5)) { + yield(i); + } }; writeln("generator test"); diff --git a/ostd/coroutine.hh b/ostd/coroutine.hh index a308fe0..47b37b1 100644 --- a/ostd/coroutine.hh +++ b/ostd/coroutine.hh @@ -26,9 +26,6 @@ struct coroutine_error: std::runtime_error { template struct coroutine; -template -struct coroutine_range; - namespace detail { /* like reference_wrapper but for any value */ template @@ -114,45 +111,42 @@ namespace detail { using yield_type = std::pair; }; + template<> + struct coro_types<> { + using yield_type = void; + }; + template using coro_args = typename coro_types::yield_type; - template - inline coro_args yield_ret( - std::tuple...> &args, std::index_sequence - ) { - if constexpr(sizeof...(A) == 1) { - return std::forward(std::get<0>(args)); - } else if constexpr(sizeof...(A) == 2) { - return std::make_pair(std::forward(std::get(args))...); - } else { - return std::move(args); - } - } + template + struct coro_yielder; /* default case, yield returns args and takes a value */ template struct coro_base: coroutine_context { protected: - struct yielder { - yielder(coro_base &coro): p_coro(coro) {} + friend struct coro_yielder; - coro_args operator()(R &&ret) { - p_coro.p_result = std::forward(ret); - p_coro.yield_jump(); - return yield_ret( - p_coro.p_args, std::make_index_sequence{} - ); + template + coro_args get_args(std::index_sequence) { + if constexpr(sizeof...(A) == 0) { + return; + } else if constexpr(sizeof...(A) == 1) { + return std::forward(std::get<0>(p_args)); + } else if constexpr(sizeof...(A) == 2) { + return std::make_pair(std::forward(std::get(p_args))...); + } else { + return std::move(p_args); } - private: - coro_base &p_coro; - }; + } - template - void call_helper(F &func, std::index_sequence) { - p_result = std::forward( - func(yielder{*this}, std::forward(std::get(p_args))...) - ); + template + void call_helper(Y &&yielder, F &func, std::index_sequence) { + p_result = std::forward(func( + std::forward(yielder), + std::forward(std::get(p_args))... + )); } R call(A ...args) { @@ -175,23 +169,15 @@ namespace detail { /* yield takes a value but doesn't return any args */ template struct coro_base: coroutine_context { - coroutine_range iter(); - protected: - struct yielder { - yielder(coro_base &coro): p_coro(coro) {} + friend struct coro_yielder; - void operator()(R &&ret) { - p_coro.p_result = std::forward(ret); - p_coro.yield_jump(); - } - private: - coro_base &p_coro; - }; + template + void get_args(std::index_sequence) {} - template - void call_helper(F &func, std::index_sequence) { - p_result = std::forward(func(yielder{*this})); + template + void call_helper(Y &&yielder, F &func, std::index_sequence) { + p_result = std::forward(func(std::forward(yielder))); } R call() { @@ -212,22 +198,27 @@ namespace detail { template struct coro_base: coroutine_context { protected: - struct yielder { - yielder(coro_base &coro): p_coro(coro) {} + friend struct coro_yielder; - coro_args operator()() { - p_coro.yield_jump(); - return yield_ret( - p_coro.p_args, std::make_index_sequence{} - ); + template + coro_args get_args(std::index_sequence) { + if constexpr(sizeof...(A) == 0) { + return; + } else if constexpr(sizeof...(A) == 1) { + return std::forward(std::get<0>(p_args)); + } else if constexpr(sizeof...(A) == 2) { + return std::make_pair(std::forward(std::get(p_args))...); + } else { + return std::move(p_args); } - private: - coro_base &p_coro; - }; + } - template - void call_helper(F &func, std::index_sequence) { - func(yielder{*this}, std::forward(std::get(p_args))...); + template + void call_helper(Y &&yielder, F &func, std::index_sequence) { + func( + std::forward(yielder), + std::forward(std::get(p_args))... + ); } void call(A ...args) { @@ -248,19 +239,14 @@ namespace detail { template<> struct coro_base: coroutine_context { protected: - struct yielder { - yielder(coro_base &coro): p_coro(coro) {} + friend struct coro_yielder; - void operator()() { - p_coro.yield_jump(); - } - private: - coro_base &p_coro; - }; + template + void get_args(std::index_sequence) {} - template - void call_helper(F &func, std::index_sequence) { - func(yielder{*this}); + template + void call_helper(Y &&yielder, F &func, std::index_sequence) { + func(std::forward(yielder)); } void call() { @@ -271,6 +257,67 @@ namespace detail { coroutine_context::swap(other); } }; + + template + struct coro_yielder { + coro_yielder(coro_base &coro): p_coro(coro) {} + + coro_args operator()(R &&ret) { + p_coro.p_result = std::forward(ret); + p_coro.yield_jump(); + return p_coro.get_args(std::make_index_sequence{}); + } + + coro_args operator()(R const &ret) { + p_coro.p_result = ret; + p_coro.yield_jump(); + return p_coro.get_args(std::make_index_sequence{}); + } + + private: + coro_base &p_coro; + }; + + template + struct coro_yielder { + coro_yielder(coro_base &coro): p_coro(coro) {} + + coro_args operator()(R &ret) { + p_coro.p_result = ret; + p_coro.yield_jump(); + return p_coro.get_args(std::make_index_sequence{}); + } + + private: + coro_base &p_coro; + }; + + template + struct coro_yielder { + coro_yielder(coro_base &coro): p_coro(coro) {} + + coro_args operator()(R &&ret) { + p_coro.p_result = std::forward(ret); + p_coro.yield_jump(); + return p_coro.get_args(std::make_index_sequence{}); + } + + private: + coro_base &p_coro; + }; + + template + struct coro_yielder { + coro_yielder(coro_base &coro): p_coro(coro) {} + + coro_args operator()() { + p_coro.yield_jump(); + return p_coro.get_args(std::make_index_sequence{}); + } + + private: + coro_base &p_coro; + }; } /* namespace detail */ template @@ -279,7 +326,7 @@ private: using base_t = detail::coro_base; public: - using yield_type = typename detail::coro_base::yielder; + using yield_type = detail::coro_yielder; /* we have no way to assign a function anyway... */ coroutine() = delete; @@ -349,7 +396,9 @@ private: goto release; } try { - self.call_helper(self.p_func, std::index_sequence_for{}); + self.call_helper( + yield_type{self}, self.p_func, std::index_sequence_for{} + ); } catch (detail::coroutine_context::forced_unwind v) { /* forced_unwind is unique */ self.p_orig = v.ctx; @@ -371,50 +420,174 @@ inline void swap(coroutine &a, coroutine &b) { a.swap(b); } +template struct generator_range; + template -struct coroutine_range: input_range> { +struct generator: detail::coroutine_context { +private: + struct yielder { + yielder(generator &g): p_gen(g) {} + + void operator()(T &&ret) { + p_gen.p_result = &ret; + p_gen.yield_jump(); + } + + void operator()(T &ret) { + p_gen.p_result = &ret; + p_gen.yield_jump(); + } + private: + generator &p_gen; + }; + +public: + using range = generator_range; + + using yield_type = yielder; + + generator() = delete; + + template + generator(F func, SA sa = SA{0}): p_func(std::move(func)) { + /* that way there is no context creation/stack allocation */ + if (!p_func) { + return; + } + this->make_context(sa, &context_call); + } + + generator(generator const &) = delete; + generator(generator &&c): + p_func(std::move(c.p_func)), p_result(c.p_result) + { + c.p_func = nullptr; + c.p_result = nullptr; + } + + generator &operator=(generator const &) = delete; + generator &operator=(generator &&c) { + p_func = std::move(c.p_func); + p_result = c.p_result; + c.p_func = nullptr; + c.p_result = nullptr; + } + + ~generator() { + if (this->p_state == detail::coroutine_context::state::TERM) { + return; + } + this->unwind(); + } + + explicit operator bool() const { + return (this->p_state != detail::coroutine_context::state::TERM); + } + + void resume() { + if (this->p_state == detail::coroutine_context::state::TERM) { + throw coroutine_error{"dead generator"}; + } + detail::coroutine_context::call(); + } + + T &value() { + if (!p_result) { + throw coroutine_error{"no value"}; + } + return *p_result; + } + + T const &value() const { + if (!p_result) { + throw coroutine_error{"no value"}; + } + return *p_result; + } + + bool empty() const { + return (!p_result || this->p_state == detail::coroutine_context::state::TERM); + } + + generator_range iter(); + + void swap(generator &other) { + using std::swap; + swap(p_func, other.p_func); + swap(p_result, other.p_result); + detail::coroutine_context::swap(other); + } + +private: + /* the main entry point of the generator */ + template + static void context_call(detail::transfer_t t) { + auto &self = *(static_cast(t.data)); + self.p_orig = t.ctx; + if (self.p_state == detail::coroutine_context::state::INIT) { + goto release; + } + try { + self.p_func(yield_type{self}); + } catch (detail::coroutine_context::forced_unwind v) { + self.p_orig = v.ctx; + } catch (...) { + self.p_except = std::current_exception(); + } +release: + self.p_state = detail::coroutine_context::state::TERM; + self.p_result = nullptr; + self.template finish(); + } + + std::function p_func; + /* we can use a pointer because even stack values are alive + * as long as the coroutine is alive (and it is on every yield) + */ + T *p_result = nullptr; +}; + +template +inline void swap(generator &a, generator &b) { + a.swap(b); +} + +template +struct generator_range: input_range> { using range_category = input_range_tag; using value_type = T; using reference = T &; using size_type = size_t; using difference_type = stream_off_t; - coroutine_range() = delete; - coroutine_range(coroutine &c): p_coro(&c) { + generator_range() = delete; + generator_range(generator &g): p_gen(&g) { pop_front(); } - coroutine_range(coroutine_range const &r): - p_coro(r.p_coro), p_item(r.p_item) {} + generator_range(generator_range const &r): p_gen(r.p_gen) {} bool empty() const { - return !p_item; + return p_gen->empty(); } void pop_front() { - if (*p_coro) { - p_item = (*p_coro)(); - } else { - p_item = std::nullopt; - } + p_gen->resume(); } reference front() const { - return p_item.value(); + return p_gen->value(); } - bool equals_front(coroutine_range const &g) const { - return p_coro == g.p_coro; + bool equals_front(generator_range const &g) const { + return p_gen == g.p_gen; } private: - coroutine *p_coro; - mutable std::optional p_item; + generator *p_gen; }; -namespace detail { - template - coroutine_range coro_base::iter() { - return coroutine_range{static_cast &>(*this)}; - } +template +generator_range generator::iter() { + return generator_range{*this}; } } /* namespace ostd */