diff --git a/ostd/concurrency.hh b/ostd/concurrency.hh index d9beaab..a92719c 100644 --- a/ostd/concurrency.hh +++ b/ostd/concurrency.hh @@ -21,6 +21,24 @@ namespace ostd { struct scheduler { +private: + struct stack_allocator { + stack_allocator() = delete; + stack_allocator(scheduler &s) noexcept: p_sched(&s) {} + + stack_context allocate() { + return p_sched->allocate_stack(); + } + + void deaallocate(stack_context &st) noexcept { + p_sched->deallocate_stack(st); + } + + private: + scheduler *p_sched; + }; + +public: scheduler() {} scheduler(scheduler const &) = delete; @@ -32,6 +50,13 @@ struct scheduler { virtual void yield() noexcept = 0; virtual generic_condvar make_condition() = 0; + virtual stack_context allocate_stack() = 0; + virtual void deallocate_stack(stack_context &st) noexcept = 0; + + stack_allocator get_stack_allocator() noexcept { + return stack_allocator{*this}; + } + template void spawn(F &&func, A &&...args) { if constexpr(sizeof...(A) == 0) { @@ -49,6 +74,16 @@ struct scheduler { return make_condition(); }}; } + + template + coroutine make_coroutine(F &&func) { + return coroutine{std::forward(func), get_stack_allocator()}; + } + + template + generator make_generator(F &&func) { + return generator{std::forward(func), get_stack_allocator()}; + } }; namespace detail { @@ -108,6 +143,15 @@ struct thread_scheduler: scheduler { return generic_condvar{}; } + stack_context allocate_stack() { + /* TODO: store the allocator properly, for now it's fine */ + return fixedsize_stack{}.allocate(); + } + + void deallocate_stack(stack_context &st) noexcept { + fixedsize_stack{}.deallocate(st); + } + private: void remove_thread(typename std::list::iterator it) { std::lock_guard l{p_lock}; @@ -284,6 +328,15 @@ public: return coro_cond{*this}; }}; } + + stack_context allocate_stack() { + return p_stacks.allocate(); + } + + void deallocate_stack(stack_context &st) noexcept { + p_stacks.deallocate(st); + } + private: void dispatch() { while (!p_coros.empty()) { @@ -443,6 +496,15 @@ public: return task_cond{*this}; }}; } + + stack_context allocate_stack() { + return p_stacks.allocate(); + } + + void deallocate_stack(stack_context &st) noexcept { + p_stacks.deallocate(st); + } + private: template void spawn_add(SA &&sa, F &&func, A &&...args) { @@ -585,6 +647,16 @@ inline channel make_channel() { return detail::current_scheduler->make_channel(); } +template +coroutine make_coroutine(F &&func) { + return detail::current_scheduler->make_coroutine(std::forward(func)); +} + +template +generator make_generator(F &&func) { + return detail::current_scheduler->make_generator(std::forward(func)); +} + } /* namespace ostd */ #endif