Commit 2c76542e authored by Andrii Grynenko's avatar Andrii Grynenko Committed by Facebook Github Bot

Make fibers::Baton awaitable

Summary:
This extends fibers::Baton to support any awaiter (not just folly::Fiber) which makes it possible to integrate it with coroutines.
This will make it easy to integrate existing synchronization primitives that use fibers::Baton with coroutines.

Reviewed By: lewissbaker

Differential Revision: D8580775

fbshipit-source-id: 8f8593793d3012e0470b8f9b3bcf2713145d3573
parent 4bfcf7b4
...@@ -278,4 +278,27 @@ TEST(Coro, AwaitableWithMemberOperator) { ...@@ -278,4 +278,27 @@ TEST(Coro, AwaitableWithMemberOperator) {
.getVia(&executor)); .getVia(&executor));
} }
coro::Task<int> taskBaton(fibers::Baton& baton) {
co_await baton;
co_return 42;
}
TEST(Coro, Baton) {
ManualExecutor executor;
fibers::Baton baton;
auto future = via(&executor, taskBaton(baton));
EXPECT_FALSE(future.await_ready());
executor.run();
EXPECT_FALSE(future.await_ready());
baton.post();
executor.run();
EXPECT_TRUE(future.await_ready());
EXPECT_EQ(42, future.get());
}
#endif #endif
...@@ -19,6 +19,21 @@ ...@@ -19,6 +19,21 @@
namespace folly { namespace folly {
namespace fibers { namespace fibers {
class Baton::FiberWaiter : public Baton::Waiter {
public:
void setFiber(Fiber& fiber) {
DCHECK(!fiber_);
fiber_ = &fiber;
}
void post() override {
fiber_->resume();
}
private:
Fiber* fiber_{nullptr};
};
inline Baton::Baton() : Baton(NO_WAITER) { inline Baton::Baton() : Baton(NO_WAITER) {
assert(Baton(NO_WAITER).futex_.futex == static_cast<uint32_t>(NO_WAITER)); assert(Baton(NO_WAITER).futex_.futex == static_cast<uint32_t>(NO_WAITER));
assert(Baton(POSTED).futex_.futex == static_cast<uint32_t>(POSTED)); assert(Baton(POSTED).futex_.futex == static_cast<uint32_t>(POSTED));
...@@ -28,7 +43,7 @@ inline Baton::Baton() : Baton(NO_WAITER) { ...@@ -28,7 +43,7 @@ inline Baton::Baton() : Baton(NO_WAITER) {
static_cast<uint32_t>(THREAD_WAITING)); static_cast<uint32_t>(THREAD_WAITING));
assert(futex_.futex.is_lock_free()); assert(futex_.futex.is_lock_free());
assert(waitingFiber_.is_lock_free()); assert(waiter_.is_lock_free());
} }
template <typename F> template <typename F>
...@@ -44,20 +59,10 @@ void Baton::wait(F&& mainContextFunc) { ...@@ -44,20 +59,10 @@ void Baton::wait(F&& mainContextFunc) {
template <typename F> template <typename F>
void Baton::waitFiber(FiberManager& fm, F&& mainContextFunc) { void Baton::waitFiber(FiberManager& fm, F&& mainContextFunc) {
auto& waitingFiber = waitingFiber_; FiberWaiter waiter;
auto f = [&mainContextFunc, &waitingFiber](Fiber& fiber) mutable { auto f = [this, &mainContextFunc, &waiter](Fiber& fiber) mutable {
auto baton_fiber = waitingFiber.load(); waiter.setFiber(fiber);
do { setWaiter(waiter);
if (LIKELY(baton_fiber == NO_WAITER)) {
continue;
} else if (baton_fiber == POSTED || baton_fiber == TIMEOUT) {
fiber.resume();
break;
} else {
throw std::logic_error("Some Fiber is already waiting on this Baton.");
}
} while (!waitingFiber.compare_exchange_weak(
baton_fiber, reinterpret_cast<intptr_t>(&fiber)));
mainContextFunc(); mainContextFunc();
}; };
...@@ -89,7 +94,7 @@ bool Baton::try_wait_for( ...@@ -89,7 +94,7 @@ bool Baton::try_wait_for(
waitFiber(*fm, static_cast<F&&>(mainContextFunc)); waitFiber(*fm, static_cast<F&&>(mainContextFunc));
auto posted = waitingFiber_ == POSTED; auto posted = waiter_ == POSTED;
if (!canceled) { if (!canceled) {
fm->timeoutManager_->cancel(id); fm->timeoutManager_->cancel(id);
......
...@@ -24,6 +24,21 @@ ...@@ -24,6 +24,21 @@
namespace folly { namespace folly {
namespace fibers { namespace fibers {
void Baton::setWaiter(Waiter& waiter) {
auto curr_waiter = waiter_.load();
do {
if (LIKELY(curr_waiter == NO_WAITER)) {
continue;
} else if (curr_waiter == POSTED || curr_waiter == TIMEOUT) {
waiter.post();
break;
} else {
throw std::logic_error("Some waiter is already waiting on this Baton.");
}
} while (!waiter_.compare_exchange_weak(
curr_waiter, reinterpret_cast<intptr_t>(&waiter)));
}
void Baton::wait() { void Baton::wait() {
wait([]() {}); wait([]() {});
} }
...@@ -43,34 +58,34 @@ void Baton::wait(TimeoutHandler& timeoutHandler) { ...@@ -43,34 +58,34 @@ void Baton::wait(TimeoutHandler& timeoutHandler) {
void Baton::waitThread() { void Baton::waitThread() {
if (spinWaitForEarlyPost()) { if (spinWaitForEarlyPost()) {
assert(waitingFiber_.load(std::memory_order_acquire) == POSTED); assert(waiter_.load(std::memory_order_acquire) == POSTED);
return; return;
} }
auto fiber = waitingFiber_.load(); auto waiter = waiter_.load();
if (LIKELY( if (LIKELY(
fiber == NO_WAITER && waiter == NO_WAITER &&
waitingFiber_.compare_exchange_strong(fiber, THREAD_WAITING))) { waiter_.compare_exchange_strong(waiter, THREAD_WAITING))) {
do { do {
folly::detail::MemoryIdler::futexWait( folly::detail::MemoryIdler::futexWait(
futex_.futex, uint32_t(THREAD_WAITING)); futex_.futex, uint32_t(THREAD_WAITING));
fiber = waitingFiber_.load(std::memory_order_acquire); waiter = waiter_.load(std::memory_order_acquire);
} while (fiber == THREAD_WAITING); } while (waiter == THREAD_WAITING);
} }
if (LIKELY(fiber == POSTED)) { if (LIKELY(waiter == POSTED)) {
return; return;
} }
// Handle errors // Handle errors
if (fiber == TIMEOUT) { if (waiter == TIMEOUT) {
throw std::logic_error("Thread baton can't have timeout status"); throw std::logic_error("Thread baton can't have timeout status");
} }
if (fiber == THREAD_WAITING) { if (waiter == THREAD_WAITING) {
throw std::logic_error("Other thread is already waiting on this baton"); throw std::logic_error("Other thread is already waiting on this baton");
} }
throw std::logic_error("Other fiber is already waiting on this baton"); throw std::logic_error("Other waiter is already waiting on this baton");
} }
bool Baton::spinWaitForEarlyPost() { bool Baton::spinWaitForEarlyPost() {
...@@ -94,15 +109,15 @@ bool Baton::spinWaitForEarlyPost() { ...@@ -94,15 +109,15 @@ bool Baton::spinWaitForEarlyPost() {
bool Baton::timedWaitThread(TimeoutController::Duration timeout) { bool Baton::timedWaitThread(TimeoutController::Duration timeout) {
if (spinWaitForEarlyPost()) { if (spinWaitForEarlyPost()) {
assert(waitingFiber_.load(std::memory_order_acquire) == POSTED); assert(waiter_.load(std::memory_order_acquire) == POSTED);
return true; return true;
} }
auto fiber = waitingFiber_.load(); auto waiter = waiter_.load();
if (LIKELY( if (LIKELY(
fiber == NO_WAITER && waiter == NO_WAITER &&
waitingFiber_.compare_exchange_strong(fiber, THREAD_WAITING))) { waiter_.compare_exchange_strong(waiter, THREAD_WAITING))) {
auto deadline = TimeoutController::Clock::now() + timeout; auto deadline = TimeoutController::Clock::now() + timeout;
do { do {
const auto wait_rv = const auto wait_rv =
...@@ -110,22 +125,22 @@ bool Baton::timedWaitThread(TimeoutController::Duration timeout) { ...@@ -110,22 +125,22 @@ bool Baton::timedWaitThread(TimeoutController::Duration timeout) {
if (wait_rv == folly::detail::FutexResult::TIMEDOUT) { if (wait_rv == folly::detail::FutexResult::TIMEDOUT) {
return false; return false;
} }
fiber = waitingFiber_.load(std::memory_order_relaxed); waiter = waiter_.load(std::memory_order_relaxed);
} while (fiber == THREAD_WAITING); } while (waiter == THREAD_WAITING);
} }
if (LIKELY(fiber == POSTED)) { if (LIKELY(waiter == POSTED)) {
return true; return true;
} }
// Handle errors // Handle errors
if (fiber == TIMEOUT) { if (waiter == TIMEOUT) {
throw std::logic_error("Thread baton can't have timeout status"); throw std::logic_error("Thread baton can't have timeout status");
} }
if (fiber == THREAD_WAITING) { if (waiter == THREAD_WAITING) {
throw std::logic_error("Other thread is already waiting on this baton"); throw std::logic_error("Other thread is already waiting on this baton");
} }
throw std::logic_error("Other fiber is already waiting on this baton"); throw std::logic_error("Other waiter is already waiting on this baton");
} }
void Baton::post() { void Baton::post() {
...@@ -133,22 +148,22 @@ void Baton::post() { ...@@ -133,22 +148,22 @@ void Baton::post() {
} }
void Baton::postHelper(intptr_t new_value) { void Baton::postHelper(intptr_t new_value) {
auto fiber = waitingFiber_.load(); auto waiter = waiter_.load();
do { do {
if (fiber == THREAD_WAITING) { if (waiter == THREAD_WAITING) {
assert(new_value == POSTED); assert(new_value == POSTED);
return postThread(); return postThread();
} }
if (fiber == POSTED || fiber == TIMEOUT) { if (waiter == POSTED || waiter == TIMEOUT) {
return; return;
} }
} while (!waitingFiber_.compare_exchange_weak(fiber, new_value)); } while (!waiter_.compare_exchange_weak(waiter, new_value));
if (fiber != NO_WAITER) { if (waiter != NO_WAITER) {
reinterpret_cast<Fiber*>(fiber)->resume(); reinterpret_cast<Waiter*>(waiter)->post();
} }
} }
...@@ -159,7 +174,7 @@ bool Baton::try_wait() { ...@@ -159,7 +174,7 @@ bool Baton::try_wait() {
void Baton::postThread() { void Baton::postThread() {
auto expected = THREAD_WAITING; auto expected = THREAD_WAITING;
if (!waitingFiber_.compare_exchange_strong(expected, POSTED)) { if (!waiter_.compare_exchange_strong(expected, POSTED)) {
return; return;
} }
...@@ -167,8 +182,7 @@ void Baton::postThread() { ...@@ -167,8 +182,7 @@ void Baton::postThread() {
} }
void Baton::reset() { void Baton::reset() {
waitingFiber_.store(NO_WAITER, std::memory_order_relaxed); waiter_.store(NO_WAITER, std::memory_order_relaxed);
;
} }
void Baton::TimeoutHandler::scheduleTimeout( void Baton::TimeoutHandler::scheduleTimeout(
......
...@@ -17,9 +17,14 @@ ...@@ -17,9 +17,14 @@
#include <atomic> #include <atomic>
#include <folly/Portability.h>
#include <folly/detail/Futex.h> #include <folly/detail/Futex.h>
#include <folly/fibers/TimeoutController.h> #include <folly/fibers/TimeoutController.h>
#if FOLLY_HAS_COROUTINES
#include <experimental/coroutine>
#endif
namespace folly { namespace folly {
namespace fibers { namespace fibers {
...@@ -36,15 +41,28 @@ class Baton { ...@@ -36,15 +41,28 @@ class Baton {
public: public:
class TimeoutHandler; class TimeoutHandler;
class Waiter {
public:
virtual void post() = 0;
virtual ~Waiter() {}
};
Baton(); Baton();
~Baton() {} ~Baton() {}
bool ready() const { bool ready() const {
auto state = waitingFiber_.load(); auto state = waiter_.load();
return state == POSTED; return state == POSTED;
} }
/**
* Registers a waiter for the baton. The waiter will be notified when
* the baton is posted.
*/
void setWaiter(Waiter& waiter);
/** /**
* Puts active fiber to sleep. Returns when post is called. * Puts active fiber to sleep. Returns when post is called.
*/ */
...@@ -213,6 +231,8 @@ class Baton { ...@@ -213,6 +231,8 @@ class Baton {
}; };
private: private:
class FiberWaiter;
enum { enum {
/** /**
* Must be positive. If multiple threads are actively using a * Must be positive. If multiple threads are actively using a
...@@ -232,7 +252,7 @@ class Baton { ...@@ -232,7 +252,7 @@ class Baton {
PreBlockAttempts = 300, PreBlockAttempts = 300,
}; };
explicit Baton(intptr_t state) : waitingFiber_(state) {} explicit Baton(intptr_t state) : waiter_(state) {}
void postHelper(intptr_t new_value); void postHelper(intptr_t new_value);
void postThread(); void postThread();
...@@ -256,13 +276,48 @@ class Baton { ...@@ -256,13 +276,48 @@ class Baton {
static constexpr intptr_t THREAD_WAITING = -3; static constexpr intptr_t THREAD_WAITING = -3;
union { union {
std::atomic<intptr_t> waitingFiber_; std::atomic<intptr_t> waiter_;
struct { struct {
folly::detail::Futex<> futex{}; folly::detail::Futex<> futex{};
int32_t _unused_packing; int32_t _unused_packing;
} futex_; } futex_;
}; };
}; };
#if FOLLY_HAS_COROUTINES
namespace detail {
class BatonAwaitableWaiter : public Baton::Waiter {
public:
explicit BatonAwaitableWaiter(Baton& baton) : baton_(baton) {}
void post() override {
assert(h_);
h_();
}
bool await_ready() const {
return baton_.ready();
}
void await_resume() {}
void await_suspend(std::experimental::coroutine_handle<> h) {
assert(!h_);
h_ = std::move(h);
baton_.setWaiter(*this);
}
private:
std::experimental::coroutine_handle<> h_;
Baton& baton_;
};
} // namespace detail
inline detail::BatonAwaitableWaiter /* implicit */ operator co_await(
Baton& baton) {
return detail::BatonAwaitableWaiter(baton);
}
#endif
} // namespace fibers } // namespace fibers
} // namespace folly } // namespace folly
......
...@@ -21,9 +21,6 @@ ...@@ -21,9 +21,6 @@
#include <memory> #include <memory>
#include <type_traits> #include <type_traits>
#include <vector> #include <vector>
#if FOLLY_HAS_COROUTINES
#include <experimental/coroutine>
#endif
#include <folly/Optional.h> #include <folly/Optional.h>
#include <folly/Portability.h> #include <folly/Portability.h>
...@@ -38,6 +35,10 @@ ...@@ -38,6 +35,10 @@
#include <folly/futures/detail/Types.h> #include <folly/futures/detail/Types.h>
#include <folly/lang/Exception.h> #include <folly/lang/Exception.h>
#if FOLLY_HAS_COROUTINES
#include <experimental/coroutine>
#endif
// boring predeclarations and details // boring predeclarations and details
#include <folly/futures/Future-pre.h> #include <folly/futures/Future-pre.h>
...@@ -1989,13 +1990,13 @@ class FutureRefAwaitable { ...@@ -1989,13 +1990,13 @@ class FutureRefAwaitable {
} // namespace detail } // namespace detail
template <typename T> template <typename T>
detail::FutureRefAwaitable<T> inline detail::FutureRefAwaitable<T>
/* implicit */ operator co_await(Future<T>& future) { /* implicit */ operator co_await(Future<T>& future) {
return detail::FutureRefAwaitable<T>(future); return detail::FutureRefAwaitable<T>(future);
} }
template <typename T> template <typename T>
detail::FutureRefAwaitable<T> inline detail::FutureRefAwaitable<T>
/* implicit */ operator co_await(Future<T>&& future) { /* implicit */ operator co_await(Future<T>&& future) {
return detail::FutureRefAwaitable<T>(future); return detail::FutureRefAwaitable<T>(future);
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment