diff --git a/examples/coroutine2.cc b/examples/coroutine2.cc index 20c1f2f..54ee855 100644 --- a/examples/coroutine2.cc +++ b/examples/coroutine2.cc @@ -4,7 +4,7 @@ using namespace ostd; int main() { - generator g = [](auto yield) { + coroutine g = [](auto yield) { yield(5); yield(10); yield(15); @@ -13,8 +13,8 @@ int main() { }; writeln("generator test"); - for (int i: g) { - writeln("generated: ", i); + for (int i: g.iter()) { + writefln("generated: %s", i); } } diff --git a/ostd/coroutine.hh b/ostd/coroutine.hh index 99ae665..9292e52 100644 --- a/ostd/coroutine.hh +++ b/ostd/coroutine.hh @@ -181,6 +181,9 @@ inline void swap(coroutine_context &a, coroutine_context &b) { template struct coroutine; +template +struct coroutine_range; + namespace detail { /* like reference_wrapper but for any value */ template @@ -345,6 +348,8 @@ namespace detail { /* yield takes a value but doesn't return any args */ template struct coro_base { + coroutine_range iter(); + protected: coro_base(void (*callp)(void *), size_t ss): p_ctx(ss, callp, this) @@ -554,29 +559,27 @@ inline void swap(coroutine &a, coroutine &b) { } template -struct generator: input_range> { +struct coroutine_range: input_range> { using range_category = input_range_tag; using value_type = T; using reference = T &; using size_type = size_t; using difference_type = stream_off_t; - generator() = default; - - template - generator(F &&func, size_t ss = COROUTINE_DEFAULT_STACK_SIZE): - p_ptr(new coroutine{std::forward(func), ss}) - { + coroutine_range() = delete; + coroutine_range(coroutine &c): p_coro(&c) { pop_front(); } + coroutine_range(coroutine_range const &r): + p_coro(r.p_coro), p_item(r.p_item) {} bool empty() const { return !p_item; } void pop_front() { - if (*p_ptr) { - p_item = (*p_ptr)(); + if (*p_coro) { + p_item = (*p_coro)(); } else { p_item = std::nullopt; } @@ -586,14 +589,21 @@ struct generator: input_range> { return p_item.value(); } - bool equals_front(generator const &g) { - return p_ptr == g.p_ptr; + bool equals_front(coroutine_range const &g) { + return p_coro == g.p_coro; } private: - std::shared_ptr> p_ptr; + coroutine *p_coro; mutable std::optional p_item; }; +namespace detail { + template + coroutine_range coro_base::iter() { + return coroutine_range{static_cast &>(*this)}; + } +} + } /* namespace ostd */ #endif