Commit c1990c7c authored by Andrii Grynenko's avatar Andrii Grynenko Committed by Facebook Github Bot

Add coroutine support to fibers::Semaphore

Reviewed By: yfeldblum

Differential Revision: D13515209

fbshipit-source-id: 6d4688242a586b6e5558c62c1c6f3bb7c6595dfb
parent 43ea7bf4
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include <folly/experimental/coro/BlockingWait.h> #include <folly/experimental/coro/BlockingWait.h>
#include <folly/experimental/coro/Task.h> #include <folly/experimental/coro/Task.h>
#include <folly/experimental/coro/Utils.h> #include <folly/experimental/coro/Utils.h>
#include <folly/fibers/Semaphore.h>
#include <folly/io/async/ScopedEventBaseThread.h> #include <folly/io/async/ScopedEventBaseThread.h>
#include <folly/portability/GTest.h> #include <folly/portability/GTest.h>
...@@ -379,4 +380,61 @@ TEST(Coro, lambda) { ...@@ -379,4 +380,61 @@ TEST(Coro, lambda) {
executor.run(); executor.run();
EXPECT_TRUE(coroFuture.isReady()); EXPECT_TRUE(coroFuture.isReady());
} }
TEST(Coro, Semaphore) {
static constexpr size_t kTasks = 10;
static constexpr size_t kIterations = 10000;
static constexpr size_t kNumTokens = 10;
static constexpr size_t kNumThreads = 16;
fibers::Semaphore sem(kNumTokens);
struct Worker {
explicit Worker(fibers::Semaphore& s) : sem(s), t([&] { run(); }) {}
void run() {
folly::EventBase evb;
{
std::shared_ptr<folly::EventBase> completionCounter(
&evb, [](folly::EventBase* evb_) { evb_->terminateLoopSoon(); });
for (size_t i = 0; i < kTasks; ++i) {
coro::lambda([&, completionCounter]() -> coro::Task<void> {
for (size_t j = 0; j < kIterations; ++j) {
co_await sem.co_wait();
++counter;
sem.signal();
--counter;
EXPECT_LT(counter, kNumTokens);
EXPECT_GE(counter, 0);
}
})
.scheduleOn(&evb)
.start();
}
}
evb.loopForever();
}
fibers::Semaphore& sem;
int counter{0};
std::thread t;
};
std::vector<Worker> workers;
workers.reserve(kNumThreads);
for (size_t i = 0; i < kNumThreads; ++i) {
workers.emplace_back(sem);
}
for (auto& worker : workers) {
worker.t.join();
}
for (auto& worker : workers) {
EXPECT_EQ(0, worker.counter);
}
}
#endif #endif
...@@ -57,10 +57,8 @@ void Semaphore::signal() { ...@@ -57,10 +57,8 @@ void Semaphore::signal() {
std::memory_order_acquire)); std::memory_order_acquire));
} }
bool Semaphore::waitSlow() { bool Semaphore::waitSlow(folly::fibers::Baton& waitBaton) {
// Slow path, create a baton and acquire a mutex to update the wait list // Slow path, create a baton and acquire a mutex to update the wait list
folly::fibers::Baton waitBaton;
{ {
auto waitListLock = waitList_.wlock(); auto waitListLock = waitList_.wlock();
auto& waitList = *waitListLock; auto& waitList = *waitListLock;
...@@ -72,9 +70,7 @@ bool Semaphore::waitSlow() { ...@@ -72,9 +70,7 @@ bool Semaphore::waitSlow() {
// prepare baton and add to queue // prepare baton and add to queue
waitList.push(&waitBaton); waitList.push(&waitBaton);
} }
// If we managed to create a baton, wait on it // Signal to caller that we managed to push a baton
// This has to be done here so the mutex has been released
waitBaton.wait();
return true; return true;
} }
...@@ -82,9 +78,11 @@ void Semaphore::wait() { ...@@ -82,9 +78,11 @@ void Semaphore::wait() {
auto oldVal = tokens_.load(std::memory_order_acquire); auto oldVal = tokens_.load(std::memory_order_acquire);
do { do {
while (oldVal == 0) { while (oldVal == 0) {
folly::fibers::Baton waitBaton;
// If waitSlow fails it is because the token is non-zero by the time // If waitSlow fails it is because the token is non-zero by the time
// the lock is taken, so we can just continue round the loop // the lock is taken, so we can just continue round the loop
if (waitSlow()) { if (waitSlow(waitBaton)) {
waitBaton.wait();
return; return;
} }
oldVal = tokens_.load(std::memory_order_acquire); oldVal = tokens_.load(std::memory_order_acquire);
...@@ -96,6 +94,30 @@ void Semaphore::wait() { ...@@ -96,6 +94,30 @@ void Semaphore::wait() {
std::memory_order_acquire)); std::memory_order_acquire));
} }
#if FOLLY_HAS_COROUTINES
coro::Task<void> Semaphore::co_wait() {
auto oldVal = tokens_.load(std::memory_order_acquire);
do {
while (oldVal == 0) {
folly::fibers::Baton waitBaton;
// If waitSlow fails it is because the token is non-zero by the time
// the lock is taken, so we can just continue round the loop
if (waitSlow(waitBaton)) {
co_await waitBaton;
co_return;
}
oldVal = tokens_.load(std::memory_order_acquire);
}
} while (!tokens_.compare_exchange_weak(
oldVal,
oldVal - 1,
std::memory_order_release,
std::memory_order_acquire));
}
#endif
size_t Semaphore::getCapacity() const { size_t Semaphore::getCapacity() const {
return capacity_; return capacity_;
} }
......
...@@ -17,6 +17,9 @@ ...@@ -17,6 +17,9 @@
#include <folly/Synchronized.h> #include <folly/Synchronized.h>
#include <folly/fibers/Baton.h> #include <folly/fibers/Baton.h>
#if FOLLY_HAS_COROUTINES
#include <folly/experimental/coro/Task.h>
#endif
namespace folly { namespace folly {
namespace fibers { namespace fibers {
...@@ -45,10 +48,19 @@ class Semaphore { ...@@ -45,10 +48,19 @@ class Semaphore {
*/ */
void wait(); void wait();
#if FOLLY_HAS_COROUTINES
/*
* Wait for capacity in the semaphore.
*/
coro::Task<void> co_wait();
#endif
size_t getCapacity() const; size_t getCapacity() const;
private: private:
bool waitSlow(); bool waitSlow(folly::fibers::Baton& waitBaton);
bool signalSlow(); bool signalSlow();
size_t capacity_; size_t capacity_;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment