Commit 2c76542e authored by Andrii Grynenko's avatar Andrii Grynenko Committed by Facebook Github Bot

Make fibers::Baton awaitable

Summary:
This extends fibers::Baton to support any awaiter (not just folly::Fiber) which makes it possible to integrate it with coroutines.
This will make it easy to integrate existing synchronization primitives that use fibers::Baton with coroutines.

Reviewed By: lewissbaker

Differential Revision: D8580775

fbshipit-source-id: 8f8593793d3012e0470b8f9b3bcf2713145d3573
parent 4bfcf7b4
......@@ -278,4 +278,27 @@ TEST(Coro, AwaitableWithMemberOperator) {
.getVia(&executor));
}
coro::Task<int> taskBaton(fibers::Baton& baton) {
co_await baton;
co_return 42;
}
TEST(Coro, Baton) {
ManualExecutor executor;
fibers::Baton baton;
auto future = via(&executor, taskBaton(baton));
EXPECT_FALSE(future.await_ready());
executor.run();
EXPECT_FALSE(future.await_ready());
baton.post();
executor.run();
EXPECT_TRUE(future.await_ready());
EXPECT_EQ(42, future.get());
}
#endif
......@@ -19,6 +19,21 @@
namespace folly {
namespace fibers {
class Baton::FiberWaiter : public Baton::Waiter {
public:
void setFiber(Fiber& fiber) {
DCHECK(!fiber_);
fiber_ = &fiber;
}
void post() override {
fiber_->resume();
}
private:
Fiber* fiber_{nullptr};
};
inline Baton::Baton() : Baton(NO_WAITER) {
assert(Baton(NO_WAITER).futex_.futex == static_cast<uint32_t>(NO_WAITER));
assert(Baton(POSTED).futex_.futex == static_cast<uint32_t>(POSTED));
......@@ -28,7 +43,7 @@ inline Baton::Baton() : Baton(NO_WAITER) {
static_cast<uint32_t>(THREAD_WAITING));
assert(futex_.futex.is_lock_free());
assert(waitingFiber_.is_lock_free());
assert(waiter_.is_lock_free());
}
template <typename F>
......@@ -44,20 +59,10 @@ void Baton::wait(F&& mainContextFunc) {
template <typename F>
void Baton::waitFiber(FiberManager& fm, F&& mainContextFunc) {
auto& waitingFiber = waitingFiber_;
auto f = [&mainContextFunc, &waitingFiber](Fiber& fiber) mutable {
auto baton_fiber = waitingFiber.load();
do {
if (LIKELY(baton_fiber == NO_WAITER)) {
continue;
} else if (baton_fiber == POSTED || baton_fiber == TIMEOUT) {
fiber.resume();
break;
} else {
throw std::logic_error("Some Fiber is already waiting on this Baton.");
}
} while (!waitingFiber.compare_exchange_weak(
baton_fiber, reinterpret_cast<intptr_t>(&fiber)));
FiberWaiter waiter;
auto f = [this, &mainContextFunc, &waiter](Fiber& fiber) mutable {
waiter.setFiber(fiber);
setWaiter(waiter);
mainContextFunc();
};
......@@ -89,7 +94,7 @@ bool Baton::try_wait_for(
waitFiber(*fm, static_cast<F&&>(mainContextFunc));
auto posted = waitingFiber_ == POSTED;
auto posted = waiter_ == POSTED;
if (!canceled) {
fm->timeoutManager_->cancel(id);
......
......@@ -24,6 +24,21 @@
namespace folly {
namespace fibers {
void Baton::setWaiter(Waiter& waiter) {
auto curr_waiter = waiter_.load();
do {
if (LIKELY(curr_waiter == NO_WAITER)) {
continue;
} else if (curr_waiter == POSTED || curr_waiter == TIMEOUT) {
waiter.post();
break;
} else {
throw std::logic_error("Some waiter is already waiting on this Baton.");
}
} while (!waiter_.compare_exchange_weak(
curr_waiter, reinterpret_cast<intptr_t>(&waiter)));
}
void Baton::wait() {
wait([]() {});
}
......@@ -43,34 +58,34 @@ void Baton::wait(TimeoutHandler& timeoutHandler) {
void Baton::waitThread() {
if (spinWaitForEarlyPost()) {
assert(waitingFiber_.load(std::memory_order_acquire) == POSTED);
assert(waiter_.load(std::memory_order_acquire) == POSTED);
return;
}
auto fiber = waitingFiber_.load();
auto waiter = waiter_.load();
if (LIKELY(
fiber == NO_WAITER &&
waitingFiber_.compare_exchange_strong(fiber, THREAD_WAITING))) {
waiter == NO_WAITER &&
waiter_.compare_exchange_strong(waiter, THREAD_WAITING))) {
do {
folly::detail::MemoryIdler::futexWait(
futex_.futex, uint32_t(THREAD_WAITING));
fiber = waitingFiber_.load(std::memory_order_acquire);
} while (fiber == THREAD_WAITING);
waiter = waiter_.load(std::memory_order_acquire);
} while (waiter == THREAD_WAITING);
}
if (LIKELY(fiber == POSTED)) {
if (LIKELY(waiter == POSTED)) {
return;
}
// Handle errors
if (fiber == TIMEOUT) {
if (waiter == TIMEOUT) {
throw std::logic_error("Thread baton can't have timeout status");
}
if (fiber == THREAD_WAITING) {
if (waiter == THREAD_WAITING) {
throw std::logic_error("Other thread is already waiting on this baton");
}
throw std::logic_error("Other fiber is already waiting on this baton");
throw std::logic_error("Other waiter is already waiting on this baton");
}
bool Baton::spinWaitForEarlyPost() {
......@@ -94,15 +109,15 @@ bool Baton::spinWaitForEarlyPost() {
bool Baton::timedWaitThread(TimeoutController::Duration timeout) {
if (spinWaitForEarlyPost()) {
assert(waitingFiber_.load(std::memory_order_acquire) == POSTED);
assert(waiter_.load(std::memory_order_acquire) == POSTED);
return true;
}
auto fiber = waitingFiber_.load();
auto waiter = waiter_.load();
if (LIKELY(
fiber == NO_WAITER &&
waitingFiber_.compare_exchange_strong(fiber, THREAD_WAITING))) {
waiter == NO_WAITER &&
waiter_.compare_exchange_strong(waiter, THREAD_WAITING))) {
auto deadline = TimeoutController::Clock::now() + timeout;
do {
const auto wait_rv =
......@@ -110,22 +125,22 @@ bool Baton::timedWaitThread(TimeoutController::Duration timeout) {
if (wait_rv == folly::detail::FutexResult::TIMEDOUT) {
return false;
}
fiber = waitingFiber_.load(std::memory_order_relaxed);
} while (fiber == THREAD_WAITING);
waiter = waiter_.load(std::memory_order_relaxed);
} while (waiter == THREAD_WAITING);
}
if (LIKELY(fiber == POSTED)) {
if (LIKELY(waiter == POSTED)) {
return true;
}
// Handle errors
if (fiber == TIMEOUT) {
if (waiter == TIMEOUT) {
throw std::logic_error("Thread baton can't have timeout status");
}
if (fiber == THREAD_WAITING) {
if (waiter == THREAD_WAITING) {
throw std::logic_error("Other thread is already waiting on this baton");
}
throw std::logic_error("Other fiber is already waiting on this baton");
throw std::logic_error("Other waiter is already waiting on this baton");
}
void Baton::post() {
......@@ -133,22 +148,22 @@ void Baton::post() {
}
void Baton::postHelper(intptr_t new_value) {
auto fiber = waitingFiber_.load();
auto waiter = waiter_.load();
do {
if (fiber == THREAD_WAITING) {
if (waiter == THREAD_WAITING) {
assert(new_value == POSTED);
return postThread();
}
if (fiber == POSTED || fiber == TIMEOUT) {
if (waiter == POSTED || waiter == TIMEOUT) {
return;
}
} while (!waitingFiber_.compare_exchange_weak(fiber, new_value));
} while (!waiter_.compare_exchange_weak(waiter, new_value));
if (fiber != NO_WAITER) {
reinterpret_cast<Fiber*>(fiber)->resume();
if (waiter != NO_WAITER) {
reinterpret_cast<Waiter*>(waiter)->post();
}
}
......@@ -159,7 +174,7 @@ bool Baton::try_wait() {
void Baton::postThread() {
auto expected = THREAD_WAITING;
if (!waitingFiber_.compare_exchange_strong(expected, POSTED)) {
if (!waiter_.compare_exchange_strong(expected, POSTED)) {
return;
}
......@@ -167,8 +182,7 @@ void Baton::postThread() {
}
void Baton::reset() {
waitingFiber_.store(NO_WAITER, std::memory_order_relaxed);
;
waiter_.store(NO_WAITER, std::memory_order_relaxed);
}
void Baton::TimeoutHandler::scheduleTimeout(
......
......@@ -17,9 +17,14 @@
#include <atomic>
#include <folly/Portability.h>
#include <folly/detail/Futex.h>
#include <folly/fibers/TimeoutController.h>
#if FOLLY_HAS_COROUTINES
#include <experimental/coroutine>
#endif
namespace folly {
namespace fibers {
......@@ -36,15 +41,28 @@ class Baton {
public:
class TimeoutHandler;
class Waiter {
public:
virtual void post() = 0;
virtual ~Waiter() {}
};
Baton();
~Baton() {}
bool ready() const {
auto state = waitingFiber_.load();
auto state = waiter_.load();
return state == POSTED;
}
/**
* Registers a waiter for the baton. The waiter will be notified when
* the baton is posted.
*/
void setWaiter(Waiter& waiter);
/**
* Puts active fiber to sleep. Returns when post is called.
*/
......@@ -213,6 +231,8 @@ class Baton {
};
private:
class FiberWaiter;
enum {
/**
* Must be positive. If multiple threads are actively using a
......@@ -232,7 +252,7 @@ class Baton {
PreBlockAttempts = 300,
};
explicit Baton(intptr_t state) : waitingFiber_(state) {}
explicit Baton(intptr_t state) : waiter_(state) {}
void postHelper(intptr_t new_value);
void postThread();
......@@ -256,13 +276,48 @@ class Baton {
static constexpr intptr_t THREAD_WAITING = -3;
union {
std::atomic<intptr_t> waitingFiber_;
std::atomic<intptr_t> waiter_;
struct {
folly::detail::Futex<> futex{};
int32_t _unused_packing;
} futex_;
};
};
#if FOLLY_HAS_COROUTINES
namespace detail {
class BatonAwaitableWaiter : public Baton::Waiter {
public:
explicit BatonAwaitableWaiter(Baton& baton) : baton_(baton) {}
void post() override {
assert(h_);
h_();
}
bool await_ready() const {
return baton_.ready();
}
void await_resume() {}
void await_suspend(std::experimental::coroutine_handle<> h) {
assert(!h_);
h_ = std::move(h);
baton_.setWaiter(*this);
}
private:
std::experimental::coroutine_handle<> h_;
Baton& baton_;
};
} // namespace detail
inline detail::BatonAwaitableWaiter /* implicit */ operator co_await(
Baton& baton) {
return detail::BatonAwaitableWaiter(baton);
}
#endif
} // namespace fibers
} // namespace folly
......
......@@ -21,9 +21,6 @@
#include <memory>
#include <type_traits>
#include <vector>
#if FOLLY_HAS_COROUTINES
#include <experimental/coroutine>
#endif
#include <folly/Optional.h>
#include <folly/Portability.h>
......@@ -38,6 +35,10 @@
#include <folly/futures/detail/Types.h>
#include <folly/lang/Exception.h>
#if FOLLY_HAS_COROUTINES
#include <experimental/coroutine>
#endif
// boring predeclarations and details
#include <folly/futures/Future-pre.h>
......@@ -1989,13 +1990,13 @@ class FutureRefAwaitable {
} // namespace detail
template <typename T>
detail::FutureRefAwaitable<T>
inline detail::FutureRefAwaitable<T>
/* implicit */ operator co_await(Future<T>& future) {
return detail::FutureRefAwaitable<T>(future);
}
template <typename T>
detail::FutureRefAwaitable<T>
inline detail::FutureRefAwaitable<T>
/* implicit */ operator co_await(Future<T>&& future) {
return detail::FutureRefAwaitable<T>(future);
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment