diff --git a/examples/concurrency.cc b/examples/concurrency.cc index 5fb4eb3..3ee092a 100644 --- a/examples/concurrency.cc +++ b/examples/concurrency.cc @@ -18,7 +18,8 @@ int main() { spawn(sched, f, c, arr.slice(0, arr.size() / 2)); spawn(sched, f, c, arr + (arr.size() / 2)); - int a = c.get(), b = c.get(); + int a = c.get(); + int b = c.get(); writefln("%s + %s = %s", a, b, a + b); }; diff --git a/ostd/channel.hh b/ostd/channel.hh index bee7323..26817f5 100644 --- a/ostd/channel.hh +++ b/ostd/channel.hh @@ -6,6 +6,7 @@ #ifndef OSTD_CHANNEL_HH #define OSTD_CHANNEL_HH +#include #include #include #include diff --git a/ostd/concurrency.hh b/ostd/concurrency.hh index 8ddf19f..4023168 100644 --- a/ostd/concurrency.hh +++ b/ostd/concurrency.hh @@ -328,6 +328,12 @@ private: template void wait(L &l) noexcept { + /* lock until the task has been added to the wait queue, + * that ensures that any notify/notify_any has to wait + * until after the task has fully blocked... we can't + * use unique_lock or lock_guard because they're scoped + */ + p_sched.p_lock.lock(); l.unlock(); task *curr = task::current(); curr->waiting_on = this; @@ -493,8 +499,8 @@ private: task &c = *it; l.unlock(); c(); - l.lock(); if (c.dead()) { + l.lock(); p_running.erase(it); /* we're dead, notify all threads so they can be joined * we check all three, saves the other threads some re-waiting @@ -507,6 +513,7 @@ private: } } else if (!c.waiting_on) { /* reschedule to the end of the queue */ + l.lock(); p_available.splice(p_available.cend(), p_running, it); l.unlock(); p_cond.notify_one(); @@ -514,6 +521,8 @@ private: p_waiting.splice(p_waiting.cbegin(), p_running, it); c.next_waiting = c.waiting_on->p_waiting; c.waiting_on->p_waiting = &c; + /* wait locks the mutex, so manually unlock it here */ + p_lock.unlock(); } }