diff --git a/ostd/thread_pool.hh b/ostd/thread_pool.hh index 964b7e2..678a0d0 100644 --- a/ostd/thread_pool.hh +++ b/ostd/thread_pool.hh @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -18,13 +19,6 @@ namespace ostd { -namespace detail { - template - using task_result_of = std::conditional_t< - std::is_same_v, void, std::future - >; -} - struct thread_pool { void start(size_t size = std::thread::hardware_concurrency()) { p_running = true; @@ -61,48 +55,27 @@ struct thread_pool { template auto push(F &&func, A &&...args) -> - detail::task_result_of> + std::future> { using R = std::result_of_t; - if constexpr(std::is_same_v) { - /* void-returning funcs return void */ - { - std::lock_guard l{p_lock}; - if (!p_running) { - throw std::runtime_error{"push on stopped thread_pool"}; - } - if constexpr(sizeof...(A) == 0) { - p_tasks.push(std::forward(func)); - } else { - p_tasks.push( - std::bind(std::forward(func), std::forward(args)...) - ); - } + /* TODO: we can ditch the shared_ptr by implementing our own + * move-only backing representation to replace use of std::function + */ + auto t = std::make_shared>( + std::bind(std::forward(func), std::forward(args)...) + ); + auto ret = t->get_future(); + { + std::lock_guard l{p_lock}; + if (!p_running) { + throw std::runtime_error{"push on stopped thread_pool"}; } - p_cond.notify_one(); - } else { - /* non-void-returning funcs return a future */ - std::packaged_task t; - if constexpr(sizeof...(A) == 0) { - t = std::packaged_task{std::forward(func)}; - } else { - t = std::packaged_task{ - std::bind(std::forward(func), std::forward(args)...) - }; - } - auto ret = t.get_future(); - { - std::lock_guard l{p_lock}; - if (!p_running) { - throw std::runtime_error{"push on stopped thread_pool"}; - } - p_tasks.emplace([t = std::move(t)]() { - t(); - }); - } - p_cond.notify_one(); - return ret; + p_tasks.emplace([t = std::move(t)]() mutable { + (*t)(); + }); } + p_cond.notify_one(); + return ret; } private: