/* A thread pool that can be used standalone or within a more elaborate module.
 *
 * This file is part of OctaSTD. See COPYING.md for further information.
 */

#ifndef OSTD_THREAD_POOL_HH
#define OSTD_THREAD_POOL_HH

#include <cstddef>
#include <new>
#include <stdexcept>
#include <type_traits>
#include <functional>
#include <utility>
#include <vector>
#include <queue>
#include <memory>
#include <thread>
#include <future>
#include <mutex>
#include <condition_variable>

namespace ostd {

namespace detail {
    /* can be used as a base for any custom thread pool, internal
     *
     * The derived pool type B is handed to start<B>() and must provide
     * `bool empty() const` and `void task_run(std::unique_lock<std::mutex> &)`
     * (accessible to this base, e.g. through friendship). Both are called
     * with the pool lock held; task_run is passed that lock.
     *
     * A derived type that keeps its task storage as a data member must call
     * destroy() from its own destructor: members of the derived object are
     * destroyed before this base destructor runs, and the workers keep
     * calling into the derived object until they have been joined.
     */
    struct thread_pool_base {
        virtual ~thread_pool_base() {
            destroy();
        }

        /* Stops the pool and joins every worker. The workers only return
         * once the pool is stopped and B::empty() is true, so tasks that
         * are already queued still get run. Calling this on a pool that
         * is not running is a no-op.
         */
        void destroy() {
            {
                std::lock_guard<std::mutex> l{p_lock};
                if (!p_running) {
                    return;
                }
                p_running = false;
            }
            p_cond.notify_all();
            for (auto &tid: p_thrs) {
                if (tid.joinable()) {
                    tid.join();
                }
                p_cond.notify_all();
            }
            /* drop the joined handles; without this a later start() followed
             * by destroy() would try to join these threads a second time
             */
            p_thrs.clear();
        }

    protected:
        thread_pool_base() {}

        thread_pool_base(thread_pool_base const &) = delete;
        thread_pool_base(thread_pool_base &&) = delete;
        thread_pool_base &operator=(thread_pool_base const &) = delete;
        thread_pool_base &operator=(thread_pool_base &&) = delete;

        /* Marks the pool as running and spawns `size` workers, each of
         * which executes thread_run<B>().
         *
         * Throws std::runtime_error{"thread_pool worker failed"} when a
         * freshly constructed worker is not joinable; exceptions thrown by
         * the std::thread constructor itself propagate unchanged.
         */
        template<typename B>
        void start(std::size_t size) {
            {
                std::lock_guard<std::mutex> l{p_lock};
                p_running = true;
            }
            /* reserve up front so that push_back below cannot throw while
             * holding a joinable std::thread (destroying one terminates)
             */
            p_thrs.reserve(p_thrs.size() + size);
            auto tf = [this]() {
                thread_run<B>();
            };
            for (std::size_t i = 0; i < size; ++i) {
                std::thread tid{tf};
                if (!tid.joinable()) {
                    throw std::runtime_error{"thread_pool worker failed"};
                }
                p_thrs.push_back(std::move(tid));
            }
        }

        /* Invokes `func` with the pool lock held (it is expected to put a
         * task into the derived pool's storage) and then wakes one worker.
         *
         * Throws std::runtime_error{"push on stopped thread_pool"} when the
         * pool is not running; `func` is not invoked in that case.
         */
        template<typename F>
        void push_task(F &&func) {
            {
                std::lock_guard<std::mutex> l{this->p_lock};
                if (!p_running) {
                    throw std::runtime_error{"push on stopped thread_pool"};
                }
                func();
            }
            p_cond.notify_one();
        }

    private:
        /* worker loop: sleep until there is work or the pool is stopped,
         * return once it is stopped and drained, otherwise hand the lock
         * to B::task_run
         */
        template<typename B>
        void thread_run() {
            B &self = *static_cast<B *>(this);
            for (;;) {
                std::unique_lock<std::mutex> l{p_lock};
                while (p_running && self.empty()) {
                    p_cond.wait(l);
                }
                if (!p_running && self.empty()) {
                    return;
                }
                self.task_run(l);
            }
        }

        std::condition_variable p_cond;
        std::mutex p_lock;
        std::vector<std::thread> p_thrs;
        bool p_running = false;
    };

    /* regular thread pool task, as lightweight as possible */
    struct tpool_func_base {
        tpool_func_base() {}
        virtual ~tpool_func_base() {}
        /* move-constructs this object into the storage pointed to by func */
        virtual void clone(tpool_func_base *func) = 0;
        /* invokes the stored callable */
        virtual void call() = 0;
    };

    /* concrete task holding a callable of type F by value */
    template<typename F>
    struct tpool_func_impl: tpool_func_base {
        tpool_func_impl(F &&func): p_func(std::move(func)) {}

        void clone(tpool_func_base *p) override {
            ::new (p) tpool_func_impl(std::move(p_func));
        }

        void call() override {
            p_func();
        }

    private:
        F p_func;
    };

    /* move-only type-erased nullary callable; the implementation object
     * lives in an inline buffer when it fits and on the heap otherwise
     */
    struct tpool_func {
        tpool_func() = delete;
        tpool_func(tpool_func const &) = delete;
        tpool_func &operator=(tpool_func const &) = delete;

        /* An inline task is move-constructed into this object's own buffer
         * (the source keeps a moved-from task that its destructor still
         * destroys); a heap task is taken over by stealing the pointer.
         */
        tpool_func(tpool_func &&func) {
            if (static_cast<void *>(func.p_func) == &func.p_buf) {
                p_func = reinterpret_cast<tpool_func_base *>(&p_buf);
                func.p_func->clone(p_func);
            } else {
                p_func = func.p_func;
                func.p_func = nullptr;
            }
        }

        /* Wraps `func`, which is always moved from (pass an rvalue). */
        template<typename F>
        tpool_func(F &&func) {
            using impl_t = tpool_func_impl<F>;
            /* what gets placed in the buffer is the impl object (callable
             * plus vtable pointer), so that is what has to fit, not F
             */
            if constexpr(
                (sizeof(impl_t) <= sizeof(p_buf)) &&
                (alignof(impl_t) <= alignof(decltype(p_buf)))
            ) {
                p_func = ::new(static_cast<void *>(&p_buf))
                    impl_t{std::move(func)};
            } else {
                p_func = new impl_t{std::move(func)};
            }
        }

        ~tpool_func() {
            if (static_cast<void *>(p_func) == &p_buf) {
                /* inline storage: run the destructor only */
                p_func->~tpool_func_base();
            } else {
                /* heap storage, or null after being moved from */
                delete p_func;
            }
        }

        void operator()() {
            p_func->call();
        }

    private:
        /* sized so that a wrapped std::packaged_task plus the vtable
         * pointer of the impl object can be stored inline
         */
        std::aligned_storage_t<
            sizeof(std::packaged_task<void()>) + sizeof(void *)
        > p_buf;
        tpool_func_base *p_func;
    };
} /* namespace detail */

/* A FIFO thread pool. Tasks are queued with push() and executed by the
 * workers created in start(); each push() returns a std::future for the
 * result of its task.
 */
struct thread_pool: detail::thread_pool_base {
private:
    friend struct detail::thread_pool_base;

    using base_t = detail::thread_pool_base;

    /* std::thread::hardware_concurrency() is allowed to return 0; a pool
     * without workers would never run anything, so use at least one
     */
    static std::size_t default_size() {
        std::size_t n = std::thread::hardware_concurrency();
        return (n != 0) ? n : 1;
    }

public:
    thread_pool(): base_t() {}

    /* The workers use p_tasks, which is destroyed before the base class
     * destructor runs, so they have to be joined here already.
     */
    ~thread_pool() {
        destroy();
    }

    /* Starts `size` worker threads. The default is the hardware
     * concurrency, or 1 when that is reported as 0.
     */
    void start(std::size_t size = default_size()) {
        base_t::start<thread_pool>(size);
    }

    /* Queues `func(args...)` for execution and returns a future for its
     * result. Arguments are bound with std::bind.
     *
     * Throws std::runtime_error{"push on stopped thread_pool"} when the
     * pool has not been started or has been destroyed.
     */
    template<typename F, typename ...A>
    auto push(F &&func, A &&...args) ->
        std::future<std::result_of_t<F(A...)>>
    {
        using R = std::result_of_t<F(A...)>;
        std::packaged_task<R()> t;
        if constexpr(sizeof...(A) == 0) {
            t = std::packaged_task<R()>{std::forward<F>(func)};
        } else {
            t = std::packaged_task<R()>{
                std::bind(std::forward<F>(func), std::forward<A>(args)...)
            };
        }
        auto ret = t.get_future();
        /* capturing t by reference is fine: push_task runs the lambda
         * before it returns
         */
        this->push_task([this, &t]() {
            p_tasks.emplace(std::move(t));
        });
        return ret;
    }

private:
    /* called by the base class with the pool lock held */
    bool empty() const {
        return p_tasks.empty();
    }

    /* called by the base class with the pool lock held; takes the oldest
     * task, releases the lock and only then runs the task
     */
    void task_run(std::unique_lock<std::mutex> &l) {
        auto t{std::move(p_tasks.front())};
        p_tasks.pop();
        l.unlock();
        t();
    }

    std::queue<detail::tpool_func> p_tasks;
};

} /* namespace ostd */

#endif