Commit e8eb0719 authored by Lewis Baker's avatar Lewis Baker Committed by Facebook GitHub Bot

Add support for saving/restoring async stack frames to default co_viaIfAsync() implementation

Summary:
This should avoid creating two temporary wrapper coroutines to apply
both co_viaIfAsync() and co_withAsyncStack() to awaitables when awaited
within a Task/AsyncGenerator coroutine.

The default co_viaIfAsync() implementation now customises the
co_withAsyncStack() CPO and saves and restores the async frame
itself if the wrapped awaitable does not support async stacks,
or otherwise passes the async frame through to the child awaitable
if it does support async stacks.

Reviewed By: andriigrynenko

Differential Revision: D24464513

fbshipit-source-id: f9751030f3c1d2725a00cb803ee46ee12dddb2ed
parent beb3e640
...@@ -22,10 +22,12 @@ ...@@ -22,10 +22,12 @@
#include <folly/Executor.h> #include <folly/Executor.h>
#include <folly/Traits.h> #include <folly/Traits.h>
#include <folly/experimental/coro/Traits.h> #include <folly/experimental/coro/Traits.h>
#include <folly/experimental/coro/WithAsyncStack.h>
#include <folly/experimental/coro/WithCancellation.h> #include <folly/experimental/coro/WithCancellation.h>
#include <folly/experimental/coro/detail/Malloc.h> #include <folly/experimental/coro/detail/Malloc.h>
#include <folly/io/async/Request.h> #include <folly/io/async/Request.h>
#include <folly/lang/CustomizationPoint.h> #include <folly/lang/CustomizationPoint.h>
#include <folly/tracing/AsyncStack.h>
#include <glog/logging.h> #include <glog/logging.h>
...@@ -37,150 +39,234 @@ namespace coro { ...@@ -37,150 +39,234 @@ namespace coro {
namespace detail { namespace detail {
class ViaCoroutine { class ViaCoroutinePromiseBase {
public: public:
class promise_type { static void* operator new(std::size_t size) {
public: return ::folly_coro_async_malloc(size);
// Passed as lvalue by compiler, but should have no other dependencies }
promise_type(folly::Executor::KeepAlive<>& executor) noexcept
: executor_(std::move(executor)) {}
static void* operator new(std::size_t size) { static void operator delete(void* ptr, std::size_t size) {
return ::folly_coro_async_malloc(size); ::folly_coro_async_free(ptr, size);
} }
static void operator delete(void* ptr, std::size_t size) { std::experimental::suspend_always initial_suspend() noexcept { return {}; }
::folly_coro_async_free(ptr, size);
}
ViaCoroutine get_return_object() noexcept { void return_void() noexcept {}
return ViaCoroutine{
std::experimental::coroutine_handle<promise_type>::from_promise( [[noreturn]] void unhandled_exception() noexcept {
*this)}; folly::assume_unreachable();
}
void setExecutor(folly::Executor::KeepAlive<> executor) noexcept {
executor_ = std::move(executor);
}
void setContinuation(
std::experimental::coroutine_handle<> continuation) noexcept {
continuation_ = continuation;
}
void setAsyncFrame(folly::AsyncStackFrame& frame) noexcept {
asyncFrame_ = &frame;
}
void setRequestContext(
std::shared_ptr<folly::RequestContext> context) noexcept {
context_ = std::move(context);
}
protected:
void scheduleContinuation() noexcept {
executor_->add([this]() noexcept { this->executeContinuation(); });
}
private:
void executeContinuation() noexcept {
RequestContextScopeGuard contextScope{std::move(context_)};
if (asyncFrame_ != nullptr) {
folly::resumeCoroutineWithNewAsyncStackRoot(continuation_, *asyncFrame_);
} else {
continuation_.resume();
} }
}
std::experimental::suspend_always initial_suspend() noexcept { return {}; } protected:
folly::Executor::KeepAlive<> executor_;
auto final_suspend() noexcept { std::experimental::coroutine_handle<> continuation_;
struct Awaiter { folly::AsyncStackFrame* asyncFrame_ = nullptr;
bool await_ready() noexcept { return false; } std::shared_ptr<RequestContext> context_;
FOLLY_CORO_AWAIT_SUSPEND_NONTRIVIAL_ATTRIBUTES void await_suspend( };
std::experimental::coroutine_handle<promise_type> coro) noexcept {
// Schedule resumption of the coroutine on the executor.
auto& promise = coro.promise();
if (!promise.context_) {
promise.context_ = RequestContext::saveContext();
}
promise.executor_->add([&promise]() noexcept { template <bool IsStackAware>
RequestContextScopeGuard contextScope{std::move(promise.context_)}; class ViaCoroutine {
promise.continuation_.resume(); public:
}); class promise_type : public ViaCoroutinePromiseBase {
struct FinalAwaiter {
bool await_ready() noexcept { return false; }
FOLLY_CORO_AWAIT_SUSPEND_NONTRIVIAL_ATTRIBUTES void await_suspend(
std::experimental::coroutine_handle<promise_type> h) noexcept {
auto& promise = h.promise();
if (!promise.context_) {
promise.setRequestContext(RequestContext::saveContext());
} }
void await_resume() noexcept {}
};
return Awaiter{}; if constexpr (IsStackAware) {
} folly::deactivateAsyncStackFrame(promise.getAsyncFrame());
}
[[noreturn]] void unhandled_exception() noexcept { promise.scheduleContinuation();
LOG(FATAL) << "ViaCoroutine threw an unhandled exception"; }
}
void return_void() noexcept {} [[noreturn]] void await_resume() noexcept { folly::assume_unreachable(); }
};
void setContinuation( public:
std::experimental::coroutine_handle<> continuation) noexcept { ViaCoroutine get_return_object() noexcept {
DCHECK(!continuation_); return ViaCoroutine{
continuation_ = continuation; std::experimental::coroutine_handle<promise_type>::from_promise(
*this)};
} }
void setContext(std::shared_ptr<RequestContext> context) noexcept { FinalAwaiter final_suspend() noexcept { return {}; }
context_ = std::move(context);
}
private: template <
folly::Executor::KeepAlive<> executor_; bool IsStackAware2 = IsStackAware,
std::experimental::coroutine_handle<> continuation_; std::enable_if_t<IsStackAware2, int> = 0>
std::shared_ptr<RequestContext> context_; folly::AsyncStackFrame& getAsyncFrame() noexcept {
DCHECK(this->asyncFrame_ != nullptr);
return *this->asyncFrame_;
}
}; };
ViaCoroutine(ViaCoroutine&& other) noexcept ViaCoroutine(ViaCoroutine&& other) noexcept
: coro_(std::exchange(other.coro_, {})) {} : coro_(std::exchange(other.coro_, {})) {}
~ViaCoroutine() { destroy(); } ~ViaCoroutine() {
if (coro_) {
coro_.destroy();
}
}
ViaCoroutine& operator=(ViaCoroutine other) noexcept { static ViaCoroutine create(folly::Executor::KeepAlive<> executor) {
swap(other); ViaCoroutine coroutine = createImpl();
return *this; coroutine.setExecutor(std::move(executor));
return coroutine;
} }
void swap(ViaCoroutine& other) noexcept { std::swap(coro_, other.coro_); } void setExecutor(folly::Executor::KeepAlive<> executor) noexcept {
coro_.promise().setExecutor(std::move(executor));
}
std::experimental::coroutine_handle<> getWrappedCoroutine( void setContinuation(
std::experimental::coroutine_handle<> continuation) noexcept { std::experimental::coroutine_handle<> continuation) noexcept {
if (coro_) { coro_.promise().setContinuation(continuation);
coro_.promise().setContinuation(continuation);
return coro_;
} else {
return continuation;
}
} }
std::experimental::coroutine_handle<> getWrappedCoroutineWithSavedContext( void setAsyncFrame(folly::AsyncStackFrame& frame) noexcept {
std::experimental::coroutine_handle<> continuation) noexcept { coro_.promise().setAsyncFrame(frame);
coro_.promise().setContext(RequestContext::saveContext());
return getWrappedCoroutine(continuation);
} }
void destroy() { void destroy() noexcept {
if (coro_) { if (coro_) {
std::exchange(coro_, {}).destroy(); std::exchange(coro_, {}).destroy();
} }
} }
static ViaCoroutine create(folly::Executor::KeepAlive<> executor) { void saveContext() noexcept {
co_return; coro_.promise().setRequestContext(folly::RequestContext::saveContext());
} }
static ViaCoroutine createInline() noexcept { std::experimental::coroutine_handle<promise_type> getHandle() noexcept {
return ViaCoroutine{std::experimental::coroutine_handle<promise_type>{}}; return coro_;
} }
private: private:
friend class promise_type;
explicit ViaCoroutine( explicit ViaCoroutine(
std::experimental::coroutine_handle<promise_type> coro) noexcept std::experimental::coroutine_handle<promise_type> coro) noexcept
: coro_(coro) {} : coro_(coro) {}
static ViaCoroutine createImpl() { co_return; }
std::experimental::coroutine_handle<promise_type> coro_; std::experimental::coroutine_handle<promise_type> coro_;
}; };
} // namespace detail } // namespace detail
template <typename Awaiter> template <typename Awaitable>
class ViaIfAsyncAwaiter { class StackAwareViaIfAsyncAwaiter {
using await_suspend_result_t = using WithAsyncStackAwaitable =
decltype(std::declval<Awaiter&>().await_suspend( decltype(folly::coro::co_withAsyncStack(std::declval<Awaitable>()));
std::declval<std::experimental::coroutine_handle<>>())); using Awaiter = folly::coro::awaiter_type_t<WithAsyncStackAwaitable>;
using CoroutineType = detail::ViaCoroutine<true>;
using CoroutinePromise = typename CoroutineType::promise_type;
using WrapperHandle = std::experimental::coroutine_handle<CoroutinePromise>;
using await_suspend_result_t = decltype(
std::declval<Awaiter&>().await_suspend(std::declval<WrapperHandle>()));
public: public:
static_assert( explicit StackAwareViaIfAsyncAwaiter(
folly::coro::is_awaiter_v<Awaiter>, folly::Executor::KeepAlive<> executor,
"Awaiter type does not implement the Awaiter interface."); Awaitable&& awaitable)
: viaCoroutine_(CoroutineType::create(std::move(executor))),
awaitable_(folly::coro::co_withAsyncStack(
static_cast<Awaitable&&>(awaitable))),
awaiter_(folly::coro::get_awaiter(
static_cast<WithAsyncStackAwaitable&&>(awaitable_))) {}
template <typename Awaitable> decltype(auto) await_ready() noexcept(noexcept(awaiter_.await_ready())) {
return awaiter_.await_ready();
}
template <typename Promise>
auto await_suspend(std::experimental::coroutine_handle<Promise> h) noexcept(
noexcept(std::declval<Awaiter&>().await_suspend(
std::declval<WrapperHandle>()))) -> await_suspend_result_t {
auto& promise = h.promise();
auto& asyncFrame = promise.getAsyncFrame();
viaCoroutine_.setContinuation(h);
viaCoroutine_.setAsyncFrame(asyncFrame);
if constexpr (!detail::_is_coroutine_handle<
await_suspend_result_t>::value) {
viaCoroutine_.saveContext();
}
return awaiter_.await_suspend(viaCoroutine_.getHandle());
}
decltype(auto) await_resume() noexcept(noexcept(awaiter_.await_resume())) {
viaCoroutine_.destroy();
return awaiter_.await_resume();
}
private:
CoroutineType viaCoroutine_;
WithAsyncStackAwaitable awaitable_;
Awaiter awaiter_;
};
template <bool IsCallerAsyncStackAware, typename Awaitable>
class ViaIfAsyncAwaiter {
using Awaiter = folly::coro::awaiter_type_t<Awaitable>;
using CoroutineType = detail::ViaCoroutine<false>;
using CoroutinePromise = typename CoroutineType::promise_type;
using WrapperHandle = std::experimental::coroutine_handle<CoroutinePromise>;
using await_suspend_result_t = decltype(
std::declval<Awaiter&>().await_suspend(std::declval<WrapperHandle>()));
public:
explicit ViaIfAsyncAwaiter( explicit ViaIfAsyncAwaiter(
folly::Executor::KeepAlive<> executor, folly::Executor::KeepAlive<> executor,
Awaitable&& awaitable) Awaitable&& awaitable)
: viaCoroutine_(detail::ViaCoroutine::create(std::move(executor))), : viaCoroutine_(CoroutineType::create(std::move(executor))),
awaiter_( awaiter_(
folly::coro::get_awaiter(static_cast<Awaitable&&>(awaitable))) {} folly::coro::get_awaiter(static_cast<Awaitable&&>(awaitable))) {}
bool await_ready() noexcept( decltype(auto) await_ready() noexcept(noexcept(awaiter_.await_ready())) {
noexcept(std::declval<Awaiter&>().await_ready())) {
DCHECK(true);
return awaiter_.await_ready(); return awaiter_.await_ready();
} }
...@@ -212,30 +298,46 @@ class ViaIfAsyncAwaiter { ...@@ -212,30 +298,46 @@ class ViaIfAsyncAwaiter {
// correctly captures the RequestContext to get correct behaviour in this // correctly captures the RequestContext to get correct behaviour in this
// case. // case.
template < template <typename Promise>
typename Result = await_suspend_result_t, auto await_suspend(
std::enable_if_t< std::experimental::coroutine_handle<Promise>
folly::coro::detail::_is_coroutine_handle<Result>::value, continuation) noexcept(noexcept(awaiter_
int> = 0> .await_suspend(std::declval<
auto WrapperHandle>())))
await_suspend(std::experimental::coroutine_handle<> continuation) noexcept( -> await_suspend_result_t {
noexcept(std::declval<Awaiter&>().await_suspend(continuation))) viaCoroutine_.setContinuation(continuation);
-> Result {
return awaiter_.await_suspend( if constexpr (!detail::_is_coroutine_handle<
viaCoroutine_.getWrappedCoroutine(continuation)); await_suspend_result_t>::value) {
} viaCoroutine_.saveContext();
}
template < if constexpr (IsCallerAsyncStackAware) {
typename Result = await_suspend_result_t, auto& asyncFrame = continuation.promise().getAsyncFrame();
std::enable_if_t< auto& stackRoot = *asyncFrame.getStackRoot();
!folly::coro::detail::_is_coroutine_handle<Result>::value,
int> = 0> viaCoroutine_.setAsyncFrame(asyncFrame);
auto
await_suspend(std::experimental::coroutine_handle<> continuation) noexcept( folly::deactivateAsyncStackFrame(asyncFrame);
noexcept(std::declval<Awaiter&>().await_suspend(continuation)))
-> Result { try {
return awaiter_.await_suspend( if constexpr (std::is_same_v<await_suspend_result_t, bool>) {
viaCoroutine_.getWrappedCoroutineWithSavedContext(continuation)); if (!awaiter_.await_suspend(viaCoroutine_.getHandle())) {
// Reactivate the stack-frame before we resume.
folly::activateAsyncStackFrame(stackRoot, asyncFrame);
return false;
}
return true;
} else {
return awaiter_.await_suspend(viaCoroutine_.getHandle());
}
} catch (...) {
folly::activateAsyncStackFrame(stackRoot, asyncFrame);
throw;
}
} else {
return awaiter_.await_suspend(viaCoroutine_.getHandle());
}
} }
auto await_resume() noexcept( auto await_resume() noexcept(
...@@ -254,14 +356,15 @@ class ViaIfAsyncAwaiter { ...@@ -254,14 +356,15 @@ class ViaIfAsyncAwaiter {
return awaiter_.await_resume_try(); return awaiter_.await_resume_try();
} }
detail::ViaCoroutine viaCoroutine_; private:
CoroutineType viaCoroutine_;
Awaiter awaiter_; Awaiter awaiter_;
}; };
template <typename Awaitable> template <typename Awaitable>
class ViaIfAsyncAwaitable { class StackAwareViaIfAsyncAwaitable {
public: public:
explicit ViaIfAsyncAwaitable( explicit StackAwareViaIfAsyncAwaitable(
folly::Executor::KeepAlive<> executor, folly::Executor::KeepAlive<> executor,
Awaitable&& Awaitable&&
awaitable) noexcept(std::is_nothrow_move_constructible<Awaitable>:: awaitable) noexcept(std::is_nothrow_move_constructible<Awaitable>::
...@@ -269,23 +372,15 @@ class ViaIfAsyncAwaitable { ...@@ -269,23 +372,15 @@ class ViaIfAsyncAwaitable {
: executor_(std::move(executor)), : executor_(std::move(executor)),
awaitable_(static_cast<Awaitable&&>(awaitable)) {} awaitable_(static_cast<Awaitable&&>(awaitable)) {}
template <typename Awaitable2> auto operator co_await() && {
friend auto operator co_await(ViaIfAsyncAwaitable<Awaitable2>&& awaitable) if constexpr (is_awaitable_async_stack_aware_v<Awaitable>) {
-> ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<Awaitable2>>; return StackAwareViaIfAsyncAwaiter<Awaitable>{
std::move(executor_), static_cast<Awaitable&&>(awaitable_)};
template <typename Awaitable2> } else {
friend auto operator co_await(ViaIfAsyncAwaitable<Awaitable2>& awaitable) return ViaIfAsyncAwaiter<true, Awaitable>{
-> ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<Awaitable2&>>; std::move(executor_), static_cast<Awaitable&&>(awaitable_)};
}
template <typename Awaitable2> }
friend auto operator co_await(
const ViaIfAsyncAwaitable<Awaitable2>&& awaitable)
-> ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<const Awaitable2&&>>;
template <typename Awaitable2>
friend auto operator co_await(
const ViaIfAsyncAwaitable<Awaitable2>& awaitable)
-> ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<const Awaitable2&>>;
private: private:
folly::Executor::KeepAlive<> executor_; folly::Executor::KeepAlive<> executor_;
...@@ -293,34 +388,32 @@ class ViaIfAsyncAwaitable { ...@@ -293,34 +388,32 @@ class ViaIfAsyncAwaitable {
}; };
template <typename Awaitable> template <typename Awaitable>
auto operator co_await(ViaIfAsyncAwaitable<Awaitable>&& awaitable) class ViaIfAsyncAwaitable {
-> ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<Awaitable>> { public:
return ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<Awaitable>>{ explicit ViaIfAsyncAwaitable(
std::move(awaitable.executor_), folly::Executor::KeepAlive<> executor,
static_cast<Awaitable&&>(awaitable.awaitable_)}; Awaitable&&
} awaitable) noexcept(std::is_nothrow_move_constructible<Awaitable>::
value)
: executor_(std::move(executor)),
awaitable_(static_cast<Awaitable&&>(awaitable)) {}
template <typename Awaitable> ViaIfAsyncAwaiter<false, Awaitable> operator co_await() && {
auto operator co_await(ViaIfAsyncAwaitable<Awaitable>& awaitable) return ViaIfAsyncAwaiter<false, Awaitable>{
-> ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<Awaitable&>> { std::move(executor_), static_cast<Awaitable&&>(awaitable_)};
return ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<Awaitable&>>{ }
awaitable.executor_, awaitable.awaitable_};
}
template <typename Awaitable> friend StackAwareViaIfAsyncAwaitable<Awaitable> tag_invoke(
auto operator co_await(const ViaIfAsyncAwaitable<Awaitable>&& awaitable) cpo_t<co_withAsyncStack>,
-> ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<const Awaitable&&>> { ViaIfAsyncAwaitable&& self) {
return ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<const Awaitable&&>>{ return StackAwareViaIfAsyncAwaitable<Awaitable>{
std::move(awaitable.executor_), std::move(self.executor_), static_cast<Awaitable&&>(self.awaitable_)};
static_cast<const Awaitable&&>(awaitable.awaitable_)}; }
}
template <typename Awaitable> private:
auto operator co_await(const ViaIfAsyncAwaitable<Awaitable>& awaitable) folly::Executor::KeepAlive<> executor_;
-> ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<const Awaitable&>> { Awaitable awaitable_;
return ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<const Awaitable&>>{ };
awaitable.executor_, awaitable.awaitable_};
}
namespace detail { namespace detail {
...@@ -354,9 +447,6 @@ template < ...@@ -354,9 +447,6 @@ template <
int> = 0> int> = 0>
auto co_viaIfAsync(folly::Executor::KeepAlive<> executor, Awaitable&& awaitable) auto co_viaIfAsync(folly::Executor::KeepAlive<> executor, Awaitable&& awaitable)
-> ViaIfAsyncAwaitable<Awaitable> { -> ViaIfAsyncAwaitable<Awaitable> {
static_assert(
folly::coro::is_awaitable_v<Awaitable>,
"co_viaIfAsync() argument 2 is not awaitable.");
return ViaIfAsyncAwaitable<Awaitable>{std::move(executor), return ViaIfAsyncAwaitable<Awaitable>{std::move(executor),
static_cast<Awaitable&&>(awaitable)}; static_cast<Awaitable&&>(awaitable)};
} }
......
...@@ -206,6 +206,10 @@ struct WithAsyncStackFunction { ...@@ -206,6 +206,10 @@ struct WithAsyncStackFunction {
} // namespace detail } // namespace detail
template <typename Awaitable>
inline constexpr bool is_awaitable_async_stack_aware_v =
folly::is_tag_invocable_v<detail::WithAsyncStackFunction, Awaitable>;
// Coroutines that support the AsyncStack protocol will apply the // Coroutines that support the AsyncStack protocol will apply the
// co_withAsyncStack() customisation-point to an awaitable inside its // co_withAsyncStack() customisation-point to an awaitable inside its
// await_transform() to ensure that the current coroutine's AsyncStackFrame // await_transform() to ensure that the current coroutine's AsyncStackFrame
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment