/* A thread pool that can be used standalone or within a more elaborate module.
 *
 * This file is part of OctaSTD. See COPYING.md for further information.
 */

#ifndef OSTD_THREAD_POOL_HH
#define OSTD_THREAD_POOL_HH

#include <condition_variable>
#include <cstddef>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <stdexcept>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>

namespace ostd {

namespace detail {
    /* What thread_pool::push hands back for a task whose result type is T:
     * nothing for void tasks, a std::future<T> for everything else.
     */
    template<typename T>
    using task_result_of = std::conditional_t<
        std::is_same_v<T, void>, void, std::future<T>
    >;
}

/* A set of worker threads consuming a shared FIFO task queue.
 *
 * Call start() to spawn workers, push() to enqueue work and destroy() (or
 * let the destructor run) to stop. Stopping drains the queue: tasks pushed
 * before destroy() still run before the workers exit. A stopped pool may
 * be started again.
 */
struct thread_pool {
    /* Spawns `size` worker threads and marks the pool as running.
     *
     * std::thread::hardware_concurrency() may return 0 when the value is
     * not computable; a pool without workers would accept tasks and never
     * run them, so at least one worker is always started.
     *
     * Throws std::system_error (derived from std::runtime_error) if a
     * thread cannot be created.
     */
    void start(std::size_t size = std::thread::hardware_concurrency()) {
        if (size == 0) {
            size = 1;
        }
        /* allocate up front so no spawned thread can be lost to a failed
         * vector reallocation */
        p_thrs.reserve(p_thrs.size() + size);
        {
            std::lock_guard<std::mutex> l{p_lock};
            p_running = true;
        }
        for (std::size_t i = 0; i < size; ++i) {
            /* std::thread's constructor throws on failure, so every
             * worker that gets stored is joinable */
            p_thrs.emplace_back([this]() { thread_run(); });
        }
    }

    ~thread_pool() {
        destroy();
    }

    /* Stops the pool and joins every worker.
     *
     * Tasks still queued are executed before the workers exit. Calling this
     * on a pool that is not running is a no-op. Must not be called from one
     * of the pool's own worker threads (a thread cannot join itself).
     */
    void destroy() {
        {
            std::unique_lock<std::mutex> l{p_lock};
            if (!p_running) {
                return;
            }
            p_running = false;
        }
        /* p_running is flipped under the lock, so one broadcast is enough
         * for every waiting worker to observe it */
        p_cond.notify_all();
        for (auto &tid: p_thrs) {
            tid.join();
        }
        /* joined threads are no longer joinable; dropping them lets the
         * pool go through another start()/destroy() cycle */
        p_thrs.clear();
    }

    /* Enqueues `func(args...)` for execution on a worker thread.
     *
     * Arguments are bound by value (std::bind semantics). For a void result
     * nothing is returned; otherwise a std::future carrying the result (or
     * the exception the task threw) is returned.
     *
     * Throws std::runtime_error if the pool is not running.
     */
    template<typename F, typename ...A>
    auto push(F &&func, A &&...args) ->
        detail::task_result_of<std::invoke_result_t<F, A...>>
    {
        using R = std::invoke_result_t<F, A...>;
        if constexpr(std::is_same_v<R, void>) {
            /* void-returning funcs return void */
            std::unique_lock<std::mutex> l{p_lock};
            if (!p_running) {
                throw std::runtime_error{"push on stopped thread_pool"};
            }
            if constexpr(sizeof...(A) == 0) {
                p_tasks.push(std::forward<F>(func));
            } else {
                p_tasks.push(
                    std::bind(std::forward<F>(func), std::forward<A>(args)...)
                );
            }
            p_cond.notify_one();
        } else {
            /* non-void-returning funcs return a future; the packaged_task is
             * move-only, while std::function needs a copyable callable, so
             * the task is shared through a shared_ptr */
            std::shared_ptr<std::packaged_task<R()>> t;
            if constexpr(sizeof...(A) == 0) {
                t = std::make_shared<std::packaged_task<R()>>(
                    std::forward<F>(func)
                );
            } else {
                t = std::make_shared<std::packaged_task<R()>>(
                    std::bind(std::forward<F>(func), std::forward<A>(args)...)
                );
            }
            auto ret = t->get_future();
            std::unique_lock<std::mutex> l{p_lock};
            if (!p_running) {
                throw std::runtime_error{"push on stopped thread_pool"};
            }
            p_tasks.emplace([task = std::move(t)]() { (*task)(); });
            p_cond.notify_one();
            return ret;
        }
    }

private:
    /* Worker loop: waits for a task or for the pool to stop, runs the task
     * outside the lock, and exits once the pool is stopped and the queue is
     * empty.
     *
     * An exception escaping a void task leaves the thread function and ends
     * in std::terminate; non-void tasks store it in their future instead.
     */
    void thread_run() {
        for (;;) {
            std::unique_lock<std::mutex> l{p_lock};
            p_cond.wait(l, [this]() {
                return !p_running || !p_tasks.empty();
            });
            if (!p_running && p_tasks.empty()) {
                return;
            }
            auto t = std::move(p_tasks.front());
            p_tasks.pop();
            l.unlock();
            t();
        }
    }

    std::condition_variable p_cond;
    std::mutex p_lock;
    std::vector<std::thread> p_thrs;
    std::queue<std::function<void()>> p_tasks;
    bool p_running = false;
};

} /* namespace ostd */

#endif