diff --git a/ostd/channel.hh b/ostd/channel.hh index 26817f5..7c851e0 100644 --- a/ostd/channel.hh +++ b/ostd/channel.hh @@ -6,7 +6,8 @@ #ifndef OSTD_CHANNEL_HH #define OSTD_CHANNEL_HH -#include +#include +#include #include #include #include @@ -19,10 +20,8 @@ struct channel_error: std::logic_error { using std::logic_error::logic_error; }; -template +template struct channel { - using condition_variable_type = C; - /* default ctor works for default C */ channel(): p_state(new impl) {} @@ -67,11 +66,68 @@ struct channel { } private: + struct cond_iface { + cond_iface() {} + virtual ~cond_iface() {} + virtual void notify_one() = 0; + virtual void notify_all() = 0; + virtual void wait(std::unique_lock &) = 0; + }; + + template + struct cond_impl: cond_iface { + cond_impl(): p_cond() {} + template + cond_impl(F &func): p_cond(func()) {} + void notify_one() { + p_cond.notify_one(); + } + void notify_all() { + p_cond.notify_all(); + } + void wait(std::unique_lock &l) { + p_cond.wait(l); + } + private: + C p_cond; + }; + + struct cond { + cond() { + new (reinterpret_cast(&p_condbuf)) + cond_impl(); + } + template + cond(F &func) { + new (reinterpret_cast(&p_condbuf)) + cond_impl>(func); + } + ~cond() { + reinterpret_cast(&p_condbuf)->~cond_iface(); + } + + void notify_one() { + reinterpret_cast(&p_condbuf)->notify_one(); + } + void notify_all() { + reinterpret_cast(&p_condbuf)->notify_all(); + } + void wait(std::unique_lock &l) { + reinterpret_cast(&p_condbuf)->wait(l); + } + + private: + std::aligned_storage_t) + ), alignof(std::condition_variable)> p_condbuf; + }; + struct impl { - impl() {} + impl() { + } template - impl(F &func): p_lock(), p_cond(func()) {} + impl(F &func): p_lock(), p_cond(func) {} template void put(U &&val) { @@ -118,7 +174,7 @@ private: std::list p_messages; mutable std::mutex p_lock; - C p_cond; + cond p_cond; bool p_closed = false; }; diff --git a/ostd/concurrency.hh b/ostd/concurrency.hh index 6016feb..5319eba 100644 --- a/ostd/concurrency.hh +++ b/ostd/concurrency.hh @@ -19,9 +19,6 @@ namespace ostd { struct thread_scheduler { - template - using channel_type = channel; - ~thread_scheduler() { join_all(); } @@ -174,9 +171,6 @@ private: }; public: - template - using channel_type = channel; - basic_simple_coroutine_scheduler( size_t ss = TR::default_size(), size_t cs = basic_stack_pool::DEFAULT_CHUNK_SIZE @@ -228,8 +222,8 @@ public: } template - channel make_channel() { - return channel{[this]() { + channel make_channel() { + return channel{[this]() { return coro_cond{*this}; }}; } @@ -354,9 +348,6 @@ private: }; public: - template - using channel_type = channel; - basic_coroutine_scheduler( size_t ss = TR::default_size(), size_t cs = basic_stack_pool::DEFAULT_CHUNK_SIZE @@ -409,8 +400,8 @@ public: } template - channel make_channel() { - return channel{[this]() { + channel make_channel() { + return channel{[this]() { return task_cond{*this}; }}; } @@ -551,7 +542,7 @@ inline void yield(S &sched) { } template -inline auto make_channel(S &sched) -> typename S::template channel_type { +inline channel make_channel(S &sched) { return sched.template make_channel(); } diff --git a/ostd/thread_pool.hh b/ostd/thread_pool.hh index 0a54b50..ada48b2 100644 --- a/ostd/thread_pool.hh +++ b/ostd/thread_pool.hh @@ -60,7 +60,7 @@ namespace detail { template tpool_func(F &&func) { - if (sizeof(F) <= sizeof(p_buf)) { + if (sizeof(tpool_func_impl) <= sizeof(p_buf)) { p_func = ::new(reinterpret_cast(&p_buf)) tpool_func_impl{std::move(func)}; } else { @@ -81,7 +81,7 @@ namespace detail { } private: std::aligned_storage_t< - sizeof(std::packaged_task) + sizeof(void *) + sizeof(tpool_func_impl>) > p_buf; tpool_func_base *p_func; };