From 023af03361e889fe095dd2a7e35a43e9fed34136 Mon Sep 17 00:00:00 2001 From: q66 Date: Thu, 23 Mar 2017 11:35:23 +0100 Subject: [PATCH] spawn/make_channel/yield without explicitly specifying scheduler --- examples/concurrency.cc | 14 ++-- ostd/concurrency.hh | 143 +++++++++++++++++++++++++++++++++++++--- src/concurrency.cc | 1 + 3 files changed, 143 insertions(+), 15 deletions(-) diff --git a/examples/concurrency.cc b/examples/concurrency.cc index 3ee092a..c219b2d 100644 --- a/examples/concurrency.cc +++ b/examples/concurrency.cc @@ -8,15 +8,15 @@ int main() { * task, which may or may not run in parallel with the other one depending * on the scheduler currently in use - several schedulers are shown */ - auto foo = [](auto &sched) { + auto foo = []() { auto arr = ostd::iter({ 150, 38, 76, 25, 67, 18, -15, 215, 25, -10 }); - auto c = make_channel(sched); + auto c = make_channel(); auto f = [](auto c, auto half) { c.put(foldl(half, 0)); }; - spawn(sched, f, c, arr.slice(0, arr.size() / 2)); - spawn(sched, f, c, arr + (arr.size() / 2)); + spawn(f, c, arr.slice(0, arr.size() / 2)); + spawn(f, c, arr + (arr.size() / 2)); int a = c.get(); int b = c.get(); @@ -30,7 +30,7 @@ int main() { thread_scheduler tsched; tsched.start([&tsched, &foo]() { writeln("(1) 1:1 scheduler: starting..."); - foo(tsched); + foo(); writeln("(1) 1:1 scheduler: finishing..."); }); writeln(); @@ -42,7 +42,7 @@ int main() { simple_coroutine_scheduler scsched; scsched.start([&scsched, &foo]() { writeln("(2) N:1 scheduler: starting..."); - foo(scsched); + foo(); writeln("(2) N:1 scheduler: finishing..."); }); writeln(); @@ -55,7 +55,7 @@ int main() { coroutine_scheduler csched; csched.start([&csched, &foo]() { writeln("(3) M:N scheduler: starting..."); - foo(csched); + foo(); writeln("(3) M:N scheduler: finishing..."); }); } diff --git a/ostd/concurrency.hh b/ostd/concurrency.hh index 61a38a3..f82c94b 100644 --- a/ostd/concurrency.hh +++ b/ostd/concurrency.hh @@ -19,14 +19,113 @@ namespace ostd { -struct thread_scheduler { - ~thread_scheduler() { - join_all(); - } +namespace detail { + struct sched_iface_base { + sched_iface_base() {} + virtual ~sched_iface_base() {} + virtual void spawn(std::function &&func) = 0; + virtual void yield() = 0; + virtual generic_condvar make_condition() = 0; + }; + + template + struct sched_iface_impl final: sched_iface_base { + sched_iface_impl(S &sched): p_sched(&sched) {} + + virtual void spawn(std::function &&func) { + p_sched->spawn(std::move(func)); + } + virtual void yield() { + p_sched->yield(); + } + virtual generic_condvar make_condition() { + return p_sched->make_condition(); + } + private: + S *p_sched; + }; + + struct sched_iface { + sched_iface() {} + + template + void set(S &sched) { + if (p_curr) { + p_curr->~sched_iface_base(); + } + new (reinterpret_cast(&p_buf)) + sched_iface_impl(sched); + p_curr = reinterpret_cast(&p_buf); + } + + void unset() { + if (p_curr) { + p_curr->~sched_iface_base(); + p_curr = nullptr; + } + } + + ~sched_iface() { + unset(); + } + + sched_iface(sched_iface const &) = delete; + sched_iface(sched_iface &&) = delete; + sched_iface &operator=(sched_iface const &) = delete; + sched_iface &operator=(sched_iface &&) = delete; + + void spawn(std::function &&func) { + p_curr->spawn(std::move(func)); + } + void yield() { + p_curr->yield(); + } + generic_condvar make_condition() { + return p_curr->make_condition(); + } + + private: + std::aligned_storage_t< + sizeof(detail::sched_iface_impl), + alignof(detail::sched_iface_impl) + > p_buf; + sched_iface_base *p_curr = nullptr; + }; + + OSTD_EXPORT extern sched_iface current_sched_iface; + + struct sched_iface_owner { + sched_iface_owner() = delete; + + template + sched_iface_owner(S &sched) { + current_sched_iface.set(sched); + } + + sched_iface_owner(sched_iface_owner const &) = delete; + sched_iface_owner(sched_iface_owner &&) = delete; + sched_iface_owner &operator=(sched_iface_owner const &) = delete; + sched_iface_owner &operator=(sched_iface_owner &&) = delete; + + ~sched_iface_owner() { + current_sched_iface.unset(); + } + }; +} + +struct thread_scheduler { template auto start(F &&func, A &&...args) -> std::result_of_t { - return func(std::forward(args)...); + detail::sched_iface_owner iface{*this}; + if constexpr(std::is_same_v, void>) { + func(std::forward(args)...); + join_all(); + } else { + auto ret = func(std::forward(args)...); + join_all(); + return ret; + } } void spawn(std::function func) { @@ -65,6 +164,7 @@ private: } for (auto &t: p_threads) { t.join(); + p_threads.pop_front(); } } @@ -180,6 +280,7 @@ public: template auto start(F &&func, A &&...args) -> std::result_of_t { + detail::sched_iface_owner iface{*this}; using R = std::result_of_t; if constexpr(std::is_same_v) { func(std::forward(args)...); @@ -343,6 +444,8 @@ public: template auto start(F func, A &&...args) -> std::result_of_t { + detail::sched_iface_owner iface{*this}; + /* start with one task in the queue, this way we can * say we've finished when the task queue becomes empty */ @@ -511,7 +614,7 @@ using protected_coroutine_scheduler = basic_coroutine_scheduler; template -inline void spawn(S &sched, F &&func, A &&...args) { +inline void spawn_in(S &sched, F &&func, A &&...args) { if constexpr(sizeof...(A) == 0) { sched.spawn(std::forward(func)); } else { @@ -519,18 +622,42 @@ inline void spawn(S &sched, F &&func, A &&...args) { } } +template +inline void spawn(F &&func, A &&...args) { + if constexpr(sizeof...(A) == 0) { + detail::current_sched_iface.spawn(std::forward(func)); + } else { + detail::current_sched_iface.spawn( + std::bind(std::forward(func), std::forward(args)...) + ); + } +} + template -inline void yield(S &sched) { +inline void yield_in(S &sched) { sched.yield(); } +template +inline void yield() { + detail::current_sched_iface.yield(); +} + template -inline channel make_channel(S &sched) { +inline channel make_channel_in(S &sched) { return channel{[&sched]() { return sched.make_condition(); }}; } +template +inline channel make_channel() { + auto &sciface = detail::current_sched_iface; + return channel{[&sciface]() { + return sciface.make_condition(); + }}; +} + } /* namespace ostd */ #endif diff --git a/src/concurrency.cc b/src/concurrency.cc index 35f7678..6f313d8 100644 --- a/src/concurrency.cc +++ b/src/concurrency.cc @@ -8,6 +8,7 @@ namespace ostd { namespace detail { +OSTD_EXPORT sched_iface current_sched_iface; OSTD_EXPORT thread_local csched_task *current_csched_task = nullptr; } /* namespace detail */