Commit 8f26eb68 authored by Shai Szulanski's avatar Shai Szulanski Committed by Facebook GitHub Bot

Add error hint to CoroutineHandle

Summary: Provies single-await co_nothrow wrapper

Reviewed By: capickett

Differential Revision: D31450124

fbshipit-source-id: e8d18c49248d85ef864abad2eecffe69de58a546
parent f5a43a13
......@@ -45,6 +45,10 @@
#if FOLLY_HAS_COROUTINES
namespace folly {
class exception_wrapper;
}
namespace folly::coro {
#if __has_include(<coroutine>)
......@@ -170,6 +174,11 @@ class variant_awaitable : private std::variant<A...> {
class ExtendedCoroutinePromise {
public:
virtual coroutine_handle<> getHandle() = 0;
// Types may provide a more efficient resumption path when they know they will
// be receiving an error result from the awaitee.
virtual coroutine_handle<> getErrorHandle(exception_wrapper&) {
return getHandle();
}
protected:
~ExtendedCoroutinePromise() = default;
......@@ -212,6 +221,13 @@ class ExtendedCoroutineHandle {
ExtendedCoroutinePromise* getPromise() const noexcept { return extended_; }
coroutine_handle<> getErrorHandle(exception_wrapper& ex) {
if (extended_) {
return extended_->getErrorHandle(ex);
}
return basic_;
}
explicit operator bool() const noexcept { return !!extended_; }
private:
......
......@@ -66,7 +66,7 @@ class TaskPromiseBase {
template <typename Promise>
FOLLY_CORO_AWAIT_SUSPEND_NONTRIVIAL_ATTRIBUTES coroutine_handle<>
await_suspend(coroutine_handle<Promise> coro) noexcept {
TaskPromiseBase& promise = coro.promise();
auto& promise = coro.promise();
// If the continuation has been exchanged, then we expect that the
// exchanger will handle the lifetime of the async stack. See
// ScopeExitTaskPromise's FinalAwaiter for more details.
......@@ -76,6 +76,10 @@ class TaskPromiseBase {
if (promise.ownsAsyncFrame_) {
folly::popAsyncStackFrameCallee(promise.asyncFrame_);
}
if (promise.result_.hasException()) {
return promise.continuation_.getErrorHandle(
promise.result_.exception());
}
return promise.continuation_.getHandle();
}
......@@ -85,7 +89,8 @@ class TaskPromiseBase {
friend class FinalAwaiter;
protected:
TaskPromiseBase() noexcept {}
TaskPromiseBase() noexcept = default;
~TaskPromiseBase() = default;
template <typename Promise>
variant_awaitable<FinalAwaiter, ready_awaitable<>> do_safe_point(
......@@ -111,12 +116,23 @@ class TaskPromiseBase {
template <typename Awaitable>
auto await_transform(Awaitable&& awaitable) {
bypassExceptionThrowing_ =
bypassExceptionThrowing_ == BypassExceptionThrowing::REQUESTED
? BypassExceptionThrowing::ACTIVE
: BypassExceptionThrowing::INACTIVE;
return folly::coro::co_withAsyncStack(folly::coro::co_viaIfAsync(
executor_.get_alias(),
folly::coro::co_withCancellation(
cancelToken_, static_cast<Awaitable&&>(awaitable))));
}
template <typename Awaitable>
auto await_transform(NothrowAwaitable<Awaitable>&& awaitable) {
bypassExceptionThrowing_ = BypassExceptionThrowing::REQUESTED;
return await_transform(awaitable.unwrap());
}
auto await_transform(co_current_executor_t) noexcept {
return ready_awaitable<folly::Executor*>{executor_.get()};
}
......@@ -161,6 +177,13 @@ class TaskPromiseBase {
folly::CancellationToken cancelToken_;
bool hasCancelTokenOverride_ = false;
bool ownsAsyncFrame_ = true;
protected:
enum class BypassExceptionThrowing : uint8_t {
INACTIVE,
ACTIVE,
REQUESTED,
} bypassExceptionThrowing_{BypassExceptionThrowing::INACTIVE};
};
template <typename T>
......@@ -220,6 +243,16 @@ class TaskPromise final : public TaskPromiseBase,
return do_safe_point(*this);
}
coroutine_handle<> getErrorHandle(exception_wrapper& ex) override {
if (bypassExceptionThrowing_ == BypassExceptionThrowing::ACTIVE) {
auto finalAwaiter = yield_value(co_error(std::move(ex)));
DCHECK(!finalAwaiter.await_ready());
return finalAwaiter.await_suspend(
coroutine_handle<TaskPromise>::from_promise(*this));
}
return coroutine_handle<TaskPromise>::from_promise(*this);
}
private:
Try<StorageType> result_;
};
......@@ -265,6 +298,14 @@ class TaskPromise<void> final
return do_safe_point(*this);
}
coroutine_handle<> getErrorHandle(exception_wrapper& ex) override {
if (bypassExceptionThrowing_ == BypassExceptionThrowing::ACTIVE) {
return yield_value(co_error(std::move(ex)))
.await_suspend(coroutine_handle<TaskPromise>::from_promise(*this));
}
return coroutine_handle<TaskPromise>::from_promise(*this);
}
private:
Try<void> result_;
};
......
......@@ -39,7 +39,7 @@ namespace coro {
namespace detail {
class ViaCoroutinePromiseBase {
class ViaCoroutinePromiseBase : public ExtendedCoroutinePromise {
public:
static void* operator new(std::size_t size) {
return ::folly_coro_async_malloc(size);
......@@ -61,7 +61,7 @@ class ViaCoroutinePromiseBase {
executor_ = std::move(executor);
}
void setContinuation(coroutine_handle<> continuation) noexcept {
void setContinuation(ExtendedCoroutineHandle continuation) noexcept {
continuation_ = continuation;
}
......@@ -83,15 +83,24 @@ class ViaCoroutinePromiseBase {
void executeContinuation() noexcept {
RequestContextScopeGuard contextScope{std::move(context_)};
if (asyncFrame_ != nullptr) {
folly::resumeCoroutineWithNewAsyncStackRoot(continuation_, *asyncFrame_);
folly::resumeCoroutineWithNewAsyncStackRoot(
continuation_.getHandle(), *asyncFrame_);
} else {
continuation_.resume();
}
}
public:
coroutine_handle<> getHandle() final { return continuation_.getHandle(); }
coroutine_handle<> getErrorHandle(exception_wrapper& ex) final {
return continuation_.getErrorHandle(ex);
}
protected:
virtual ~ViaCoroutinePromiseBase() = default;
folly::Executor::KeepAlive<> executor_;
coroutine_handle<> continuation_;
ExtendedCoroutineHandle continuation_;
folly::AsyncStackFrame* asyncFrame_ = nullptr;
std::shared_ptr<RequestContext> context_;
};
......@@ -99,7 +108,7 @@ class ViaCoroutinePromiseBase {
template <bool IsStackAware>
class ViaCoroutine {
public:
class promise_type : public ViaCoroutinePromiseBase {
class promise_type final : public ViaCoroutinePromiseBase {
struct FinalAwaiter {
bool await_ready() noexcept { return false; }
......@@ -155,7 +164,7 @@ class ViaCoroutine {
coro_.promise().setExecutor(std::move(executor));
}
void setContinuation(coroutine_handle<> continuation) noexcept {
void setContinuation(ExtendedCoroutineHandle continuation) noexcept {
coro_.promise().setContinuation(continuation);
}
......@@ -610,6 +619,81 @@ using semi_await_try_result_t =
std::declval<folly::Executor::KeepAlive<>>(),
folly::coro::co_awaitTry(std::declval<T>())))>;
namespace detail {
template <typename T>
class NothrowAwaitable {
public:
template <typename T2>
explicit NothrowAwaitable(T2&& awaitable) noexcept(
std::is_nothrow_constructible_v<T, T2>)
: inner_(static_cast<T2&&>(awaitable)) {}
template <typename Factory>
explicit NothrowAwaitable(std::in_place_t, Factory&& factory)
: inner_(factory()) {}
T&& unwrap() { return std::move(inner_); }
template <
typename T2 = T,
typename Result = decltype(folly::coro::co_withCancellation(
std::declval<const folly::CancellationToken&>(), std::declval<T2>()))>
friend NothrowAwaitable<Result> co_withCancellation(
const folly::CancellationToken& cancelToken,
NothrowAwaitable&& awaitable) {
return NothrowAwaitable<Result>{std::in_place, [&]() -> decltype(auto) {
return folly::coro::co_withCancellation(
cancelToken,
static_cast<T&&>(awaitable.inner_));
}};
}
template <
typename T2 = T,
typename Result =
decltype(folly::coro::co_withAsyncStack(std::declval<T2>()))>
friend NothrowAwaitable<Result>
tag_invoke(cpo_t<co_withAsyncStack>, NothrowAwaitable&& awaitable) noexcept(
noexcept(folly::coro::co_withAsyncStack(std::declval<T2>()))) {
return NothrowAwaitable<Result>{std::in_place, [&]() -> decltype(auto) {
return folly::coro::co_withAsyncStack(
static_cast<T&&>(awaitable.inner_));
}};
}
template <
typename T2 = T,
typename Result = decltype(folly::coro::co_viaIfAsync(
std::declval<folly::Executor::KeepAlive<>>(), std::declval<T2>()))>
friend NothrowAwaitable<Result> co_viaIfAsync(
folly::Executor::KeepAlive<> executor,
NothrowAwaitable&&
awaitable) noexcept(noexcept(folly::coro::
co_viaIfAsync(
std::declval<folly::Executor::
KeepAlive<>>(),
std::declval<T2>()))) {
return NothrowAwaitable<Result>{std::in_place, [&]() -> decltype(auto) {
return folly::coro::co_viaIfAsync(
std::move(executor),
static_cast<T&&>(awaitable.inner_));
}};
}
private:
T inner_;
};
} // namespace detail
template <typename Awaitable>
detail::NothrowAwaitable<remove_cvref_t<Awaitable>> co_nothrow(
Awaitable&& awaitable) {
return detail::NothrowAwaitable<remove_cvref_t<Awaitable>>{
static_cast<Awaitable&&>(awaitable)};
}
} // namespace coro
} // namespace folly
......
......@@ -630,4 +630,72 @@ TEST_F(TaskTest, SafePoint) {
}());
}
TEST_F(TaskTest, CoAwaitNothrow) {
auto res =
folly::coro::blockingWait(co_awaitTry([]() -> folly::coro::Task<void> {
auto t = []() -> folly::coro::Task<int> { co_return 42; }();
int result = co_await folly::coro::co_nothrow(std::move(t));
EXPECT_EQ(42, result);
t = []() -> folly::coro::Task<int> {
co_yield folly::coro::co_error(std::runtime_error(""));
}();
try {
result = co_await folly::coro::co_nothrow(std::move(t));
} catch (...) {
ADD_FAILURE();
}
ADD_FAILURE();
}()));
EXPECT_TRUE(res.hasException<std::runtime_error>());
}
TEST_F(TaskTest, CoAwaitNothrowWithScheduleOn) {
auto res =
folly::coro::blockingWait(co_awaitTry([]() -> folly::coro::Task<void> {
auto t = []() -> folly::coro::Task<int> { co_return 42; }();
int result = co_await folly::coro::co_nothrow(
std::move(t).scheduleOn(folly::getGlobalCPUExecutor()));
EXPECT_EQ(42, result);
t = []() -> folly::coro::Task<int> {
co_yield folly::coro::co_error(std::runtime_error(""));
}();
try {
result = co_await folly::coro::co_nothrow(
std::move(t).scheduleOn(folly::getGlobalCPUExecutor()));
} catch (...) {
ADD_FAILURE();
}
ADD_FAILURE();
}()));
EXPECT_TRUE(res.hasException<std::runtime_error>());
}
TEST_F(TaskTest, CoAwaitThrowAfterNothrow) {
auto res =
folly::coro::blockingWait(co_awaitTry([]() -> folly::coro::Task<void> {
auto t = []() -> folly::coro::Task<int> { co_return 42; }();
int result = co_await folly::coro::co_nothrow(std::move(t));
EXPECT_EQ(42, result);
t = []() -> folly::coro::Task<int> {
co_yield folly::coro::co_error(std::runtime_error(""));
}();
try {
result = co_await std::move(t);
ADD_FAILURE();
} catch (...) {
throw std::logic_error("translated");
}
}()));
EXPECT_TRUE(res.hasException<std::logic_error>());
}
#endif
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment