Commit 8f26eb68 authored by Shai Szulanski's avatar Shai Szulanski Committed by Facebook GitHub Bot

Add error hint to CoroutineHandle

Summary: Provies single-await co_nothrow wrapper

Reviewed By: capickett

Differential Revision: D31450124

fbshipit-source-id: e8d18c49248d85ef864abad2eecffe69de58a546
parent f5a43a13
...@@ -45,6 +45,10 @@ ...@@ -45,6 +45,10 @@
#if FOLLY_HAS_COROUTINES #if FOLLY_HAS_COROUTINES
namespace folly {
class exception_wrapper;
}
namespace folly::coro { namespace folly::coro {
#if __has_include(<coroutine>) #if __has_include(<coroutine>)
...@@ -170,6 +174,11 @@ class variant_awaitable : private std::variant<A...> { ...@@ -170,6 +174,11 @@ class variant_awaitable : private std::variant<A...> {
class ExtendedCoroutinePromise { class ExtendedCoroutinePromise {
public: public:
virtual coroutine_handle<> getHandle() = 0; virtual coroutine_handle<> getHandle() = 0;
// Types may provide a more efficient resumption path when they know they will
// be receiving an error result from the awaitee.
virtual coroutine_handle<> getErrorHandle(exception_wrapper&) {
return getHandle();
}
protected: protected:
~ExtendedCoroutinePromise() = default; ~ExtendedCoroutinePromise() = default;
...@@ -212,6 +221,13 @@ class ExtendedCoroutineHandle { ...@@ -212,6 +221,13 @@ class ExtendedCoroutineHandle {
ExtendedCoroutinePromise* getPromise() const noexcept { return extended_; } ExtendedCoroutinePromise* getPromise() const noexcept { return extended_; }
coroutine_handle<> getErrorHandle(exception_wrapper& ex) {
if (extended_) {
return extended_->getErrorHandle(ex);
}
return basic_;
}
explicit operator bool() const noexcept { return !!extended_; } explicit operator bool() const noexcept { return !!extended_; }
private: private:
......
...@@ -66,7 +66,7 @@ class TaskPromiseBase { ...@@ -66,7 +66,7 @@ class TaskPromiseBase {
template <typename Promise> template <typename Promise>
FOLLY_CORO_AWAIT_SUSPEND_NONTRIVIAL_ATTRIBUTES coroutine_handle<> FOLLY_CORO_AWAIT_SUSPEND_NONTRIVIAL_ATTRIBUTES coroutine_handle<>
await_suspend(coroutine_handle<Promise> coro) noexcept { await_suspend(coroutine_handle<Promise> coro) noexcept {
TaskPromiseBase& promise = coro.promise(); auto& promise = coro.promise();
// If the continuation has been exchanged, then we expect that the // If the continuation has been exchanged, then we expect that the
// exchanger will handle the lifetime of the async stack. See // exchanger will handle the lifetime of the async stack. See
// ScopeExitTaskPromise's FinalAwaiter for more details. // ScopeExitTaskPromise's FinalAwaiter for more details.
...@@ -76,6 +76,10 @@ class TaskPromiseBase { ...@@ -76,6 +76,10 @@ class TaskPromiseBase {
if (promise.ownsAsyncFrame_) { if (promise.ownsAsyncFrame_) {
folly::popAsyncStackFrameCallee(promise.asyncFrame_); folly::popAsyncStackFrameCallee(promise.asyncFrame_);
} }
if (promise.result_.hasException()) {
return promise.continuation_.getErrorHandle(
promise.result_.exception());
}
return promise.continuation_.getHandle(); return promise.continuation_.getHandle();
} }
...@@ -85,7 +89,8 @@ class TaskPromiseBase { ...@@ -85,7 +89,8 @@ class TaskPromiseBase {
friend class FinalAwaiter; friend class FinalAwaiter;
protected: protected:
TaskPromiseBase() noexcept {} TaskPromiseBase() noexcept = default;
~TaskPromiseBase() = default;
template <typename Promise> template <typename Promise>
variant_awaitable<FinalAwaiter, ready_awaitable<>> do_safe_point( variant_awaitable<FinalAwaiter, ready_awaitable<>> do_safe_point(
...@@ -111,12 +116,23 @@ class TaskPromiseBase { ...@@ -111,12 +116,23 @@ class TaskPromiseBase {
template <typename Awaitable> template <typename Awaitable>
auto await_transform(Awaitable&& awaitable) { auto await_transform(Awaitable&& awaitable) {
bypassExceptionThrowing_ =
bypassExceptionThrowing_ == BypassExceptionThrowing::REQUESTED
? BypassExceptionThrowing::ACTIVE
: BypassExceptionThrowing::INACTIVE;
return folly::coro::co_withAsyncStack(folly::coro::co_viaIfAsync( return folly::coro::co_withAsyncStack(folly::coro::co_viaIfAsync(
executor_.get_alias(), executor_.get_alias(),
folly::coro::co_withCancellation( folly::coro::co_withCancellation(
cancelToken_, static_cast<Awaitable&&>(awaitable)))); cancelToken_, static_cast<Awaitable&&>(awaitable))));
} }
template <typename Awaitable>
auto await_transform(NothrowAwaitable<Awaitable>&& awaitable) {
bypassExceptionThrowing_ = BypassExceptionThrowing::REQUESTED;
return await_transform(awaitable.unwrap());
}
auto await_transform(co_current_executor_t) noexcept { auto await_transform(co_current_executor_t) noexcept {
return ready_awaitable<folly::Executor*>{executor_.get()}; return ready_awaitable<folly::Executor*>{executor_.get()};
} }
...@@ -161,6 +177,13 @@ class TaskPromiseBase { ...@@ -161,6 +177,13 @@ class TaskPromiseBase {
folly::CancellationToken cancelToken_; folly::CancellationToken cancelToken_;
bool hasCancelTokenOverride_ = false; bool hasCancelTokenOverride_ = false;
bool ownsAsyncFrame_ = true; bool ownsAsyncFrame_ = true;
protected:
enum class BypassExceptionThrowing : uint8_t {
INACTIVE,
ACTIVE,
REQUESTED,
} bypassExceptionThrowing_{BypassExceptionThrowing::INACTIVE};
}; };
template <typename T> template <typename T>
...@@ -220,6 +243,16 @@ class TaskPromise final : public TaskPromiseBase, ...@@ -220,6 +243,16 @@ class TaskPromise final : public TaskPromiseBase,
return do_safe_point(*this); return do_safe_point(*this);
} }
coroutine_handle<> getErrorHandle(exception_wrapper& ex) override {
if (bypassExceptionThrowing_ == BypassExceptionThrowing::ACTIVE) {
auto finalAwaiter = yield_value(co_error(std::move(ex)));
DCHECK(!finalAwaiter.await_ready());
return finalAwaiter.await_suspend(
coroutine_handle<TaskPromise>::from_promise(*this));
}
return coroutine_handle<TaskPromise>::from_promise(*this);
}
private: private:
Try<StorageType> result_; Try<StorageType> result_;
}; };
...@@ -265,6 +298,14 @@ class TaskPromise<void> final ...@@ -265,6 +298,14 @@ class TaskPromise<void> final
return do_safe_point(*this); return do_safe_point(*this);
} }
coroutine_handle<> getErrorHandle(exception_wrapper& ex) override {
if (bypassExceptionThrowing_ == BypassExceptionThrowing::ACTIVE) {
return yield_value(co_error(std::move(ex)))
.await_suspend(coroutine_handle<TaskPromise>::from_promise(*this));
}
return coroutine_handle<TaskPromise>::from_promise(*this);
}
private: private:
Try<void> result_; Try<void> result_;
}; };
......
...@@ -39,7 +39,7 @@ namespace coro { ...@@ -39,7 +39,7 @@ namespace coro {
namespace detail { namespace detail {
class ViaCoroutinePromiseBase { class ViaCoroutinePromiseBase : public ExtendedCoroutinePromise {
public: public:
static void* operator new(std::size_t size) { static void* operator new(std::size_t size) {
return ::folly_coro_async_malloc(size); return ::folly_coro_async_malloc(size);
...@@ -61,7 +61,7 @@ class ViaCoroutinePromiseBase { ...@@ -61,7 +61,7 @@ class ViaCoroutinePromiseBase {
executor_ = std::move(executor); executor_ = std::move(executor);
} }
void setContinuation(coroutine_handle<> continuation) noexcept { void setContinuation(ExtendedCoroutineHandle continuation) noexcept {
continuation_ = continuation; continuation_ = continuation;
} }
...@@ -83,15 +83,24 @@ class ViaCoroutinePromiseBase { ...@@ -83,15 +83,24 @@ class ViaCoroutinePromiseBase {
void executeContinuation() noexcept { void executeContinuation() noexcept {
RequestContextScopeGuard contextScope{std::move(context_)}; RequestContextScopeGuard contextScope{std::move(context_)};
if (asyncFrame_ != nullptr) { if (asyncFrame_ != nullptr) {
folly::resumeCoroutineWithNewAsyncStackRoot(continuation_, *asyncFrame_); folly::resumeCoroutineWithNewAsyncStackRoot(
continuation_.getHandle(), *asyncFrame_);
} else { } else {
continuation_.resume(); continuation_.resume();
} }
} }
public:
coroutine_handle<> getHandle() final { return continuation_.getHandle(); }
coroutine_handle<> getErrorHandle(exception_wrapper& ex) final {
return continuation_.getErrorHandle(ex);
}
protected: protected:
virtual ~ViaCoroutinePromiseBase() = default;
folly::Executor::KeepAlive<> executor_; folly::Executor::KeepAlive<> executor_;
coroutine_handle<> continuation_; ExtendedCoroutineHandle continuation_;
folly::AsyncStackFrame* asyncFrame_ = nullptr; folly::AsyncStackFrame* asyncFrame_ = nullptr;
std::shared_ptr<RequestContext> context_; std::shared_ptr<RequestContext> context_;
}; };
...@@ -99,7 +108,7 @@ class ViaCoroutinePromiseBase { ...@@ -99,7 +108,7 @@ class ViaCoroutinePromiseBase {
template <bool IsStackAware> template <bool IsStackAware>
class ViaCoroutine { class ViaCoroutine {
public: public:
class promise_type : public ViaCoroutinePromiseBase { class promise_type final : public ViaCoroutinePromiseBase {
struct FinalAwaiter { struct FinalAwaiter {
bool await_ready() noexcept { return false; } bool await_ready() noexcept { return false; }
...@@ -155,7 +164,7 @@ class ViaCoroutine { ...@@ -155,7 +164,7 @@ class ViaCoroutine {
coro_.promise().setExecutor(std::move(executor)); coro_.promise().setExecutor(std::move(executor));
} }
void setContinuation(coroutine_handle<> continuation) noexcept { void setContinuation(ExtendedCoroutineHandle continuation) noexcept {
coro_.promise().setContinuation(continuation); coro_.promise().setContinuation(continuation);
} }
...@@ -610,6 +619,81 @@ using semi_await_try_result_t = ...@@ -610,6 +619,81 @@ using semi_await_try_result_t =
std::declval<folly::Executor::KeepAlive<>>(), std::declval<folly::Executor::KeepAlive<>>(),
folly::coro::co_awaitTry(std::declval<T>())))>; folly::coro::co_awaitTry(std::declval<T>())))>;
namespace detail {
template <typename T>
class NothrowAwaitable {
public:
template <typename T2>
explicit NothrowAwaitable(T2&& awaitable) noexcept(
std::is_nothrow_constructible_v<T, T2>)
: inner_(static_cast<T2&&>(awaitable)) {}
template <typename Factory>
explicit NothrowAwaitable(std::in_place_t, Factory&& factory)
: inner_(factory()) {}
T&& unwrap() { return std::move(inner_); }
template <
typename T2 = T,
typename Result = decltype(folly::coro::co_withCancellation(
std::declval<const folly::CancellationToken&>(), std::declval<T2>()))>
friend NothrowAwaitable<Result> co_withCancellation(
const folly::CancellationToken& cancelToken,
NothrowAwaitable&& awaitable) {
return NothrowAwaitable<Result>{std::in_place, [&]() -> decltype(auto) {
return folly::coro::co_withCancellation(
cancelToken,
static_cast<T&&>(awaitable.inner_));
}};
}
template <
typename T2 = T,
typename Result =
decltype(folly::coro::co_withAsyncStack(std::declval<T2>()))>
friend NothrowAwaitable<Result>
tag_invoke(cpo_t<co_withAsyncStack>, NothrowAwaitable&& awaitable) noexcept(
noexcept(folly::coro::co_withAsyncStack(std::declval<T2>()))) {
return NothrowAwaitable<Result>{std::in_place, [&]() -> decltype(auto) {
return folly::coro::co_withAsyncStack(
static_cast<T&&>(awaitable.inner_));
}};
}
template <
typename T2 = T,
typename Result = decltype(folly::coro::co_viaIfAsync(
std::declval<folly::Executor::KeepAlive<>>(), std::declval<T2>()))>
friend NothrowAwaitable<Result> co_viaIfAsync(
folly::Executor::KeepAlive<> executor,
NothrowAwaitable&&
awaitable) noexcept(noexcept(folly::coro::
co_viaIfAsync(
std::declval<folly::Executor::
KeepAlive<>>(),
std::declval<T2>()))) {
return NothrowAwaitable<Result>{std::in_place, [&]() -> decltype(auto) {
return folly::coro::co_viaIfAsync(
std::move(executor),
static_cast<T&&>(awaitable.inner_));
}};
}
private:
T inner_;
};
} // namespace detail
template <typename Awaitable>
detail::NothrowAwaitable<remove_cvref_t<Awaitable>> co_nothrow(
Awaitable&& awaitable) {
return detail::NothrowAwaitable<remove_cvref_t<Awaitable>>{
static_cast<Awaitable&&>(awaitable)};
}
} // namespace coro } // namespace coro
} // namespace folly } // namespace folly
......
...@@ -630,4 +630,72 @@ TEST_F(TaskTest, SafePoint) { ...@@ -630,4 +630,72 @@ TEST_F(TaskTest, SafePoint) {
}()); }());
} }
TEST_F(TaskTest, CoAwaitNothrow) {
auto res =
folly::coro::blockingWait(co_awaitTry([]() -> folly::coro::Task<void> {
auto t = []() -> folly::coro::Task<int> { co_return 42; }();
int result = co_await folly::coro::co_nothrow(std::move(t));
EXPECT_EQ(42, result);
t = []() -> folly::coro::Task<int> {
co_yield folly::coro::co_error(std::runtime_error(""));
}();
try {
result = co_await folly::coro::co_nothrow(std::move(t));
} catch (...) {
ADD_FAILURE();
}
ADD_FAILURE();
}()));
EXPECT_TRUE(res.hasException<std::runtime_error>());
}
TEST_F(TaskTest, CoAwaitNothrowWithScheduleOn) {
auto res =
folly::coro::blockingWait(co_awaitTry([]() -> folly::coro::Task<void> {
auto t = []() -> folly::coro::Task<int> { co_return 42; }();
int result = co_await folly::coro::co_nothrow(
std::move(t).scheduleOn(folly::getGlobalCPUExecutor()));
EXPECT_EQ(42, result);
t = []() -> folly::coro::Task<int> {
co_yield folly::coro::co_error(std::runtime_error(""));
}();
try {
result = co_await folly::coro::co_nothrow(
std::move(t).scheduleOn(folly::getGlobalCPUExecutor()));
} catch (...) {
ADD_FAILURE();
}
ADD_FAILURE();
}()));
EXPECT_TRUE(res.hasException<std::runtime_error>());
}
TEST_F(TaskTest, CoAwaitThrowAfterNothrow) {
auto res =
folly::coro::blockingWait(co_awaitTry([]() -> folly::coro::Task<void> {
auto t = []() -> folly::coro::Task<int> { co_return 42; }();
int result = co_await folly::coro::co_nothrow(std::move(t));
EXPECT_EQ(42, result);
t = []() -> folly::coro::Task<int> {
co_yield folly::coro::co_error(std::runtime_error(""));
}();
try {
result = co_await std::move(t);
ADD_FAILURE();
} catch (...) {
throw std::logic_error("translated");
}
}()));
EXPECT_TRUE(res.hasException<std::logic_error>());
}
#endif #endif
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment