Commit e8eb0719 authored by Lewis Baker's avatar Lewis Baker Committed by Facebook GitHub Bot

Add support for saving/restoring async stack frames to default co_viaIfAsync() implementation

Summary:
This should avoid creating two temporary wrapper coroutines to apply
both co_viaIfAsync() and co_withAsyncStack() to awaitables when awaited
within a Task/AsyncGenerator coroutine.

The default co_viaIfAsync() implementation now customises the
co_withAsyncStack() CPO and saves and restores the async frame
itself if the wrapped awaitable does not support async stacks,
or otherwise passes the async frame through to the child awaitable
if it does support async stacks.

Reviewed By: andriigrynenko

Differential Revision: D24464513

fbshipit-source-id: f9751030f3c1d2725a00cb803ee46ee12dddb2ed
parent beb3e640
......@@ -22,10 +22,12 @@
#include <folly/Executor.h>
#include <folly/Traits.h>
#include <folly/experimental/coro/Traits.h>
#include <folly/experimental/coro/WithAsyncStack.h>
#include <folly/experimental/coro/WithCancellation.h>
#include <folly/experimental/coro/detail/Malloc.h>
#include <folly/io/async/Request.h>
#include <folly/lang/CustomizationPoint.h>
#include <folly/tracing/AsyncStack.h>
#include <glog/logging.h>
......@@ -37,14 +39,8 @@ namespace coro {
namespace detail {
class ViaCoroutine {
public:
class promise_type {
class ViaCoroutinePromiseBase {
public:
// Passed as lvalue by compiler, but should have no other dependencies
promise_type(folly::Executor::KeepAlive<>& executor) noexcept
: executor_(std::move(executor)) {}
static void* operator new(std::size_t size) {
return ::folly_coro_async_malloc(size);
}
......@@ -53,134 +49,224 @@ class ViaCoroutine {
::folly_coro_async_free(ptr, size);
}
ViaCoroutine get_return_object() noexcept {
return ViaCoroutine{
std::experimental::coroutine_handle<promise_type>::from_promise(
*this)};
}
std::experimental::suspend_always initial_suspend() noexcept { return {}; }
auto final_suspend() noexcept {
struct Awaiter {
bool await_ready() noexcept { return false; }
FOLLY_CORO_AWAIT_SUSPEND_NONTRIVIAL_ATTRIBUTES void await_suspend(
std::experimental::coroutine_handle<promise_type> coro) noexcept {
// Schedule resumption of the coroutine on the executor.
auto& promise = coro.promise();
if (!promise.context_) {
promise.context_ = RequestContext::saveContext();
}
promise.executor_->add([&promise]() noexcept {
RequestContextScopeGuard contextScope{std::move(promise.context_)};
promise.continuation_.resume();
});
}
void await_resume() noexcept {}
};
return Awaiter{};
}
void return_void() noexcept {}
[[noreturn]] void unhandled_exception() noexcept {
LOG(FATAL) << "ViaCoroutine threw an unhandled exception";
folly::assume_unreachable();
}
void return_void() noexcept {}
void setExecutor(folly::Executor::KeepAlive<> executor) noexcept {
executor_ = std::move(executor);
}
void setContinuation(
std::experimental::coroutine_handle<> continuation) noexcept {
DCHECK(!continuation_);
continuation_ = continuation;
}
void setContext(std::shared_ptr<RequestContext> context) noexcept {
void setAsyncFrame(folly::AsyncStackFrame& frame) noexcept {
asyncFrame_ = &frame;
}
void setRequestContext(
std::shared_ptr<folly::RequestContext> context) noexcept {
context_ = std::move(context);
}
protected:
void scheduleContinuation() noexcept {
executor_->add([this]() noexcept { this->executeContinuation(); });
}
private:
void executeContinuation() noexcept {
RequestContextScopeGuard contextScope{std::move(context_)};
if (asyncFrame_ != nullptr) {
folly::resumeCoroutineWithNewAsyncStackRoot(continuation_, *asyncFrame_);
} else {
continuation_.resume();
}
}
protected:
folly::Executor::KeepAlive<> executor_;
std::experimental::coroutine_handle<> continuation_;
folly::AsyncStackFrame* asyncFrame_ = nullptr;
std::shared_ptr<RequestContext> context_;
};
template <bool IsStackAware>
class ViaCoroutine {
public:
class promise_type : public ViaCoroutinePromiseBase {
struct FinalAwaiter {
bool await_ready() noexcept { return false; }
FOLLY_CORO_AWAIT_SUSPEND_NONTRIVIAL_ATTRIBUTES void await_suspend(
std::experimental::coroutine_handle<promise_type> h) noexcept {
auto& promise = h.promise();
if (!promise.context_) {
promise.setRequestContext(RequestContext::saveContext());
}
if constexpr (IsStackAware) {
folly::deactivateAsyncStackFrame(promise.getAsyncFrame());
}
promise.scheduleContinuation();
}
[[noreturn]] void await_resume() noexcept { folly::assume_unreachable(); }
};
public:
ViaCoroutine get_return_object() noexcept {
return ViaCoroutine{
std::experimental::coroutine_handle<promise_type>::from_promise(
*this)};
}
FinalAwaiter final_suspend() noexcept { return {}; }
template <
bool IsStackAware2 = IsStackAware,
std::enable_if_t<IsStackAware2, int> = 0>
folly::AsyncStackFrame& getAsyncFrame() noexcept {
DCHECK(this->asyncFrame_ != nullptr);
return *this->asyncFrame_;
}
};
ViaCoroutine(ViaCoroutine&& other) noexcept
: coro_(std::exchange(other.coro_, {})) {}
~ViaCoroutine() { destroy(); }
~ViaCoroutine() {
if (coro_) {
coro_.destroy();
}
}
ViaCoroutine& operator=(ViaCoroutine other) noexcept {
swap(other);
return *this;
static ViaCoroutine create(folly::Executor::KeepAlive<> executor) {
ViaCoroutine coroutine = createImpl();
coroutine.setExecutor(std::move(executor));
return coroutine;
}
void swap(ViaCoroutine& other) noexcept { std::swap(coro_, other.coro_); }
void setExecutor(folly::Executor::KeepAlive<> executor) noexcept {
coro_.promise().setExecutor(std::move(executor));
}
std::experimental::coroutine_handle<> getWrappedCoroutine(
void setContinuation(
std::experimental::coroutine_handle<> continuation) noexcept {
if (coro_) {
coro_.promise().setContinuation(continuation);
return coro_;
} else {
return continuation;
}
}
std::experimental::coroutine_handle<> getWrappedCoroutineWithSavedContext(
std::experimental::coroutine_handle<> continuation) noexcept {
coro_.promise().setContext(RequestContext::saveContext());
return getWrappedCoroutine(continuation);
void setAsyncFrame(folly::AsyncStackFrame& frame) noexcept {
coro_.promise().setAsyncFrame(frame);
}
void destroy() {
void destroy() noexcept {
if (coro_) {
std::exchange(coro_, {}).destroy();
}
}
static ViaCoroutine create(folly::Executor::KeepAlive<> executor) {
co_return;
void saveContext() noexcept {
coro_.promise().setRequestContext(folly::RequestContext::saveContext());
}
static ViaCoroutine createInline() noexcept {
return ViaCoroutine{std::experimental::coroutine_handle<promise_type>{}};
std::experimental::coroutine_handle<promise_type> getHandle() noexcept {
return coro_;
}
private:
friend class promise_type;
explicit ViaCoroutine(
std::experimental::coroutine_handle<promise_type> coro) noexcept
: coro_(coro) {}
static ViaCoroutine createImpl() { co_return; }
std::experimental::coroutine_handle<promise_type> coro_;
};
} // namespace detail
template <typename Awaiter>
class ViaIfAsyncAwaiter {
using await_suspend_result_t =
decltype(std::declval<Awaiter&>().await_suspend(
std::declval<std::experimental::coroutine_handle<>>()));
template <typename Awaitable>
class StackAwareViaIfAsyncAwaiter {
using WithAsyncStackAwaitable =
decltype(folly::coro::co_withAsyncStack(std::declval<Awaitable>()));
using Awaiter = folly::coro::awaiter_type_t<WithAsyncStackAwaitable>;
using CoroutineType = detail::ViaCoroutine<true>;
using CoroutinePromise = typename CoroutineType::promise_type;
using WrapperHandle = std::experimental::coroutine_handle<CoroutinePromise>;
using await_suspend_result_t = decltype(
std::declval<Awaiter&>().await_suspend(std::declval<WrapperHandle>()));
public:
static_assert(
folly::coro::is_awaiter_v<Awaiter>,
"Awaiter type does not implement the Awaiter interface.");
explicit StackAwareViaIfAsyncAwaiter(
folly::Executor::KeepAlive<> executor,
Awaitable&& awaitable)
: viaCoroutine_(CoroutineType::create(std::move(executor))),
awaitable_(folly::coro::co_withAsyncStack(
static_cast<Awaitable&&>(awaitable))),
awaiter_(folly::coro::get_awaiter(
static_cast<WithAsyncStackAwaitable&&>(awaitable_))) {}
template <typename Awaitable>
decltype(auto) await_ready() noexcept(noexcept(awaiter_.await_ready())) {
return awaiter_.await_ready();
}
template <typename Promise>
auto await_suspend(std::experimental::coroutine_handle<Promise> h) noexcept(
noexcept(std::declval<Awaiter&>().await_suspend(
std::declval<WrapperHandle>()))) -> await_suspend_result_t {
auto& promise = h.promise();
auto& asyncFrame = promise.getAsyncFrame();
viaCoroutine_.setContinuation(h);
viaCoroutine_.setAsyncFrame(asyncFrame);
if constexpr (!detail::_is_coroutine_handle<
await_suspend_result_t>::value) {
viaCoroutine_.saveContext();
}
return awaiter_.await_suspend(viaCoroutine_.getHandle());
}
decltype(auto) await_resume() noexcept(noexcept(awaiter_.await_resume())) {
viaCoroutine_.destroy();
return awaiter_.await_resume();
}
private:
CoroutineType viaCoroutine_;
WithAsyncStackAwaitable awaitable_;
Awaiter awaiter_;
};
template <bool IsCallerAsyncStackAware, typename Awaitable>
class ViaIfAsyncAwaiter {
using Awaiter = folly::coro::awaiter_type_t<Awaitable>;
using CoroutineType = detail::ViaCoroutine<false>;
using CoroutinePromise = typename CoroutineType::promise_type;
using WrapperHandle = std::experimental::coroutine_handle<CoroutinePromise>;
using await_suspend_result_t = decltype(
std::declval<Awaiter&>().await_suspend(std::declval<WrapperHandle>()));
public:
explicit ViaIfAsyncAwaiter(
folly::Executor::KeepAlive<> executor,
Awaitable&& awaitable)
: viaCoroutine_(detail::ViaCoroutine::create(std::move(executor))),
: viaCoroutine_(CoroutineType::create(std::move(executor))),
awaiter_(
folly::coro::get_awaiter(static_cast<Awaitable&&>(awaitable))) {}
bool await_ready() noexcept(
noexcept(std::declval<Awaiter&>().await_ready())) {
DCHECK(true);
decltype(auto) await_ready() noexcept(noexcept(awaiter_.await_ready())) {
return awaiter_.await_ready();
}
......@@ -212,30 +298,46 @@ class ViaIfAsyncAwaiter {
// correctly captures the RequestContext to get correct behaviour in this
// case.
template <
typename Result = await_suspend_result_t,
std::enable_if_t<
folly::coro::detail::_is_coroutine_handle<Result>::value,
int> = 0>
auto
await_suspend(std::experimental::coroutine_handle<> continuation) noexcept(
noexcept(std::declval<Awaiter&>().await_suspend(continuation)))
-> Result {
return awaiter_.await_suspend(
viaCoroutine_.getWrappedCoroutine(continuation));
template <typename Promise>
auto await_suspend(
std::experimental::coroutine_handle<Promise>
continuation) noexcept(noexcept(awaiter_
.await_suspend(std::declval<
WrapperHandle>())))
-> await_suspend_result_t {
viaCoroutine_.setContinuation(continuation);
if constexpr (!detail::_is_coroutine_handle<
await_suspend_result_t>::value) {
viaCoroutine_.saveContext();
}
template <
typename Result = await_suspend_result_t,
std::enable_if_t<
!folly::coro::detail::_is_coroutine_handle<Result>::value,
int> = 0>
auto
await_suspend(std::experimental::coroutine_handle<> continuation) noexcept(
noexcept(std::declval<Awaiter&>().await_suspend(continuation)))
-> Result {
return awaiter_.await_suspend(
viaCoroutine_.getWrappedCoroutineWithSavedContext(continuation));
if constexpr (IsCallerAsyncStackAware) {
auto& asyncFrame = continuation.promise().getAsyncFrame();
auto& stackRoot = *asyncFrame.getStackRoot();
viaCoroutine_.setAsyncFrame(asyncFrame);
folly::deactivateAsyncStackFrame(asyncFrame);
try {
if constexpr (std::is_same_v<await_suspend_result_t, bool>) {
if (!awaiter_.await_suspend(viaCoroutine_.getHandle())) {
// Reactivate the stack-frame before we resume.
folly::activateAsyncStackFrame(stackRoot, asyncFrame);
return false;
}
return true;
} else {
return awaiter_.await_suspend(viaCoroutine_.getHandle());
}
} catch (...) {
folly::activateAsyncStackFrame(stackRoot, asyncFrame);
throw;
}
} else {
return awaiter_.await_suspend(viaCoroutine_.getHandle());
}
}
auto await_resume() noexcept(
......@@ -254,14 +356,15 @@ class ViaIfAsyncAwaiter {
return awaiter_.await_resume_try();
}
detail::ViaCoroutine viaCoroutine_;
private:
CoroutineType viaCoroutine_;
Awaiter awaiter_;
};
template <typename Awaitable>
class ViaIfAsyncAwaitable {
class StackAwareViaIfAsyncAwaitable {
public:
explicit ViaIfAsyncAwaitable(
explicit StackAwareViaIfAsyncAwaitable(
folly::Executor::KeepAlive<> executor,
Awaitable&&
awaitable) noexcept(std::is_nothrow_move_constructible<Awaitable>::
......@@ -269,23 +372,15 @@ class ViaIfAsyncAwaitable {
: executor_(std::move(executor)),
awaitable_(static_cast<Awaitable&&>(awaitable)) {}
template <typename Awaitable2>
friend auto operator co_await(ViaIfAsyncAwaitable<Awaitable2>&& awaitable)
-> ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<Awaitable2>>;
template <typename Awaitable2>
friend auto operator co_await(ViaIfAsyncAwaitable<Awaitable2>& awaitable)
-> ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<Awaitable2&>>;
template <typename Awaitable2>
friend auto operator co_await(
const ViaIfAsyncAwaitable<Awaitable2>&& awaitable)
-> ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<const Awaitable2&&>>;
template <typename Awaitable2>
friend auto operator co_await(
const ViaIfAsyncAwaitable<Awaitable2>& awaitable)
-> ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<const Awaitable2&>>;
auto operator co_await() && {
if constexpr (is_awaitable_async_stack_aware_v<Awaitable>) {
return StackAwareViaIfAsyncAwaiter<Awaitable>{
std::move(executor_), static_cast<Awaitable&&>(awaitable_)};
} else {
return ViaIfAsyncAwaiter<true, Awaitable>{
std::move(executor_), static_cast<Awaitable&&>(awaitable_)};
}
}
private:
folly::Executor::KeepAlive<> executor_;
......@@ -293,34 +388,32 @@ class ViaIfAsyncAwaitable {
};
template <typename Awaitable>
auto operator co_await(ViaIfAsyncAwaitable<Awaitable>&& awaitable)
-> ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<Awaitable>> {
return ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<Awaitable>>{
std::move(awaitable.executor_),
static_cast<Awaitable&&>(awaitable.awaitable_)};
}
class ViaIfAsyncAwaitable {
public:
explicit ViaIfAsyncAwaitable(
folly::Executor::KeepAlive<> executor,
Awaitable&&
awaitable) noexcept(std::is_nothrow_move_constructible<Awaitable>::
value)
: executor_(std::move(executor)),
awaitable_(static_cast<Awaitable&&>(awaitable)) {}
template <typename Awaitable>
auto operator co_await(ViaIfAsyncAwaitable<Awaitable>& awaitable)
-> ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<Awaitable&>> {
return ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<Awaitable&>>{
awaitable.executor_, awaitable.awaitable_};
}
ViaIfAsyncAwaiter<false, Awaitable> operator co_await() && {
return ViaIfAsyncAwaiter<false, Awaitable>{
std::move(executor_), static_cast<Awaitable&&>(awaitable_)};
}
template <typename Awaitable>
auto operator co_await(const ViaIfAsyncAwaitable<Awaitable>&& awaitable)
-> ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<const Awaitable&&>> {
return ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<const Awaitable&&>>{
std::move(awaitable.executor_),
static_cast<const Awaitable&&>(awaitable.awaitable_)};
}
friend StackAwareViaIfAsyncAwaitable<Awaitable> tag_invoke(
cpo_t<co_withAsyncStack>,
ViaIfAsyncAwaitable&& self) {
return StackAwareViaIfAsyncAwaitable<Awaitable>{
std::move(self.executor_), static_cast<Awaitable&&>(self.awaitable_)};
}
template <typename Awaitable>
auto operator co_await(const ViaIfAsyncAwaitable<Awaitable>& awaitable)
-> ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<const Awaitable&>> {
return ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<const Awaitable&>>{
awaitable.executor_, awaitable.awaitable_};
}
private:
folly::Executor::KeepAlive<> executor_;
Awaitable awaitable_;
};
namespace detail {
......@@ -354,9 +447,6 @@ template <
int> = 0>
auto co_viaIfAsync(folly::Executor::KeepAlive<> executor, Awaitable&& awaitable)
-> ViaIfAsyncAwaitable<Awaitable> {
static_assert(
folly::coro::is_awaitable_v<Awaitable>,
"co_viaIfAsync() argument 2 is not awaitable.");
return ViaIfAsyncAwaitable<Awaitable>{std::move(executor),
static_cast<Awaitable&&>(awaitable)};
}
......
......@@ -206,6 +206,10 @@ struct WithAsyncStackFunction {
} // namespace detail
template <typename Awaitable>
inline constexpr bool is_awaitable_async_stack_aware_v =
folly::is_tag_invocable_v<detail::WithAsyncStackFunction, Awaitable>;
// Coroutines that support the AsyncStack protocol will apply the
// co_withAsyncStack() customisation-point to an awaitable inside its
// await_transform() to ensure that the current coroutine's AsyncStackFrame
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment