Commit 33421a75 authored by Andrii Grynenko's avatar Andrii Grynenko Committed by Facebook Github Bot

Use KeepAlive in co_viaIfAsync

Summary: This allows us to always use a dummy keep-alive for all awaitables from coro::Task. We also don't need a custom await_transform for coro::Task anymore.

Reviewed By: lewissbaker

Differential Revision: D14491068

fbshipit-source-id: 0e931c2ae9f52da770b08207e551401e4a3778cb
parent 33cf2370
...@@ -298,8 +298,8 @@ class SharedMutexFair { ...@@ -298,8 +298,8 @@ class SharedMutexFair {
public: public:
explicit LockOperation(SharedMutexFair& mutex) noexcept : mutex_(mutex) {} explicit LockOperation(SharedMutexFair& mutex) noexcept : mutex_(mutex) {}
auto viaIfAsync(folly::Executor* executor) const { auto viaIfAsync(folly::Executor::KeepAlive<> executor) const {
return folly::coro::co_viaIfAsync(executor, Awaiter{mutex_}); return folly::coro::co_viaIfAsync(std::move(executor), Awaiter{mutex_});
} }
private: private:
......
...@@ -84,13 +84,10 @@ class TaskPromiseBase { ...@@ -84,13 +84,10 @@ class TaskPromiseBase {
return {}; return {};
} }
template <typename U>
auto await_transform(Task<U>&& t) noexcept;
template <typename Awaitable> template <typename Awaitable>
auto await_transform(Awaitable&& awaitable) noexcept { auto await_transform(Awaitable&& awaitable) noexcept {
return folly::coro::co_viaIfAsync( return folly::coro::co_viaIfAsync(
executor_.get(), static_cast<Awaitable&&>(awaitable)); executor_.copyDummy(), static_cast<Awaitable&&>(awaitable));
} }
auto await_transform(co_current_executor_t) noexcept { auto await_transform(co_current_executor_t) noexcept {
...@@ -364,9 +361,11 @@ class FOLLY_NODISCARD Task { ...@@ -364,9 +361,11 @@ class FOLLY_NODISCARD Task {
}); });
} }
friend auto co_viaIfAsync(Executor* executor, Task<T>&& t) noexcept { friend auto co_viaIfAsync(
Executor::KeepAlive<> executor,
Task<T>&& t) noexcept {
// Child task inherits the awaiting task's executor // Child task inherits the awaiting task's executor
t.coro_.promise().executor_ = getKeepAliveToken(executor); t.coro_.promise().executor_ = std::move(executor);
return Awaiter{std::exchange(t.coro_, {})}; return Awaiter{std::exchange(t.coro_, {})};
} }
...@@ -414,13 +413,6 @@ class FOLLY_NODISCARD Task { ...@@ -414,13 +413,6 @@ class FOLLY_NODISCARD Task {
handle_t coro_; handle_t coro_;
}; };
template <typename T>
auto detail::TaskPromiseBase::await_transform(Task<T>&& t) noexcept {
// Child task inherits the awaiting task's executor
t.coro_.promise().executor_ = executor_.copyDummy();
return typename Task<T>::Awaiter{std::exchange(t.coro_, {})};
}
template <typename T> template <typename T>
Task<T> detail::TaskPromise<T>::get_return_object() noexcept { Task<T> detail::TaskPromise<T>::get_return_object() noexcept {
return Task<T>{ return Task<T>{
......
...@@ -38,7 +38,8 @@ class ViaCoroutine { ...@@ -38,7 +38,8 @@ class ViaCoroutine {
public: public:
class promise_type { class promise_type {
public: public:
promise_type(folly::Executor* executor) noexcept : executor_(executor) {} promise_type(folly::Executor::KeepAlive<> executor) noexcept
: executor_(std::move(executor)) {}
ViaCoroutine get_return_object() noexcept { ViaCoroutine get_return_object() noexcept {
return ViaCoroutine{ return ViaCoroutine{
...@@ -91,7 +92,7 @@ class ViaCoroutine { ...@@ -91,7 +92,7 @@ class ViaCoroutine {
} }
private: private:
folly::Executor* executor_; folly::Executor::KeepAlive<> executor_;
std::experimental::coroutine_handle<> continuation_; std::experimental::coroutine_handle<> continuation_;
std::shared_ptr<RequestContext> context_; std::shared_ptr<RequestContext> context_;
}; };
...@@ -134,7 +135,7 @@ class ViaCoroutine { ...@@ -134,7 +135,7 @@ class ViaCoroutine {
} }
} }
static ViaCoroutine create(folly::Executor* executor) { static ViaCoroutine create(folly::Executor::KeepAlive<> executor) {
co_return; co_return;
} }
...@@ -166,13 +167,9 @@ class ViaIfAsyncAwaiter { ...@@ -166,13 +167,9 @@ class ViaIfAsyncAwaiter {
"Awaiter type does not implement the Awaiter interface."); "Awaiter type does not implement the Awaiter interface.");
template <typename Awaitable> template <typename Awaitable>
explicit ViaIfAsyncAwaiter(folly::InlineExecutor*, Awaitable&& awaitable) explicit ViaIfAsyncAwaiter(
: viaCoroutine_(detail::ViaCoroutine::createInline()), folly::Executor::KeepAlive<> executor,
awaiter_( Awaitable&& awaitable)
folly::coro::get_awaiter(static_cast<Awaitable&&>(awaitable))) {}
template <typename Awaitable>
explicit ViaIfAsyncAwaiter(folly::Executor* executor, Awaitable&& awaitable)
: viaCoroutine_(detail::ViaCoroutine::create(executor)), : viaCoroutine_(detail::ViaCoroutine::create(executor)),
awaiter_( awaiter_(
folly::coro::get_awaiter(static_cast<Awaitable&&>(awaitable))) {} folly::coro::get_awaiter(static_cast<Awaitable&&>(awaitable))) {}
...@@ -249,11 +246,12 @@ template <typename Awaitable> ...@@ -249,11 +246,12 @@ template <typename Awaitable>
class ViaIfAsyncAwaitable { class ViaIfAsyncAwaitable {
public: public:
explicit ViaIfAsyncAwaitable( explicit ViaIfAsyncAwaitable(
folly::Executor* executor, folly::Executor::KeepAlive<> executor,
Awaitable&& Awaitable&&
awaitable) noexcept(std::is_nothrow_move_constructible<Awaitable>:: awaitable) noexcept(std::is_nothrow_move_constructible<Awaitable>::
value) value)
: executor_(executor), awaitable_(static_cast<Awaitable&&>(awaitable)) {} : executor_(std::move(executor)),
awaitable_(static_cast<Awaitable&&>(awaitable)) {}
template <typename Awaitable2> template <typename Awaitable2>
friend auto operator co_await(ViaIfAsyncAwaitable<Awaitable2>&& awaitable) friend auto operator co_await(ViaIfAsyncAwaitable<Awaitable2>&& awaitable)
...@@ -274,7 +272,7 @@ class ViaIfAsyncAwaitable { ...@@ -274,7 +272,7 @@ class ViaIfAsyncAwaitable {
-> ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<const Awaitable2&>>; -> ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<const Awaitable2&>>;
private: private:
folly::Executor* executor_; folly::Executor::KeepAlive<> executor_;
Awaitable awaitable_; Awaitable awaitable_;
}; };
...@@ -316,18 +314,20 @@ template <typename SemiAwaitable> ...@@ -316,18 +314,20 @@ template <typename SemiAwaitable>
struct HasViaIfAsyncMethod< struct HasViaIfAsyncMethod<
SemiAwaitable, SemiAwaitable,
void_t<decltype(std::declval<SemiAwaitable>().viaIfAsync( void_t<decltype(std::declval<SemiAwaitable>().viaIfAsync(
std::declval<folly::Executor*>()))>> : std::true_type {}; std::declval<folly::Executor::KeepAlive<>>()))>> : std::true_type {};
namespace adl { namespace adl {
template <typename SemiAwaitable> template <typename SemiAwaitable>
auto co_viaIfAsync( auto co_viaIfAsync(
folly::Executor* executor, folly::Executor::KeepAlive<> executor,
SemiAwaitable&& SemiAwaitable&&
awaitable) noexcept(noexcept(static_cast<SemiAwaitable&&>(awaitable) awaitable) noexcept(noexcept(static_cast<SemiAwaitable&&>(awaitable)
.viaIfAsync(executor))) .viaIfAsync(std::move(executor))))
-> decltype(static_cast<SemiAwaitable&&>(awaitable).viaIfAsync(executor)) { -> decltype(static_cast<SemiAwaitable&&>(awaitable).viaIfAsync(
return static_cast<SemiAwaitable&&>(awaitable).viaIfAsync(executor); std::move(executor))) {
return static_cast<SemiAwaitable&&>(awaitable).viaIfAsync(
std::move(executor));
} }
template < template <
...@@ -335,23 +335,26 @@ template < ...@@ -335,23 +335,26 @@ template <
std::enable_if_t< std::enable_if_t<
is_awaitable_v<Awaitable> && !HasViaIfAsyncMethod<Awaitable>::value, is_awaitable_v<Awaitable> && !HasViaIfAsyncMethod<Awaitable>::value,
int> = 0> int> = 0>
auto co_viaIfAsync(folly::Executor* executor, Awaitable&& awaitable) auto co_viaIfAsync(folly::Executor::KeepAlive<> executor, Awaitable&& awaitable)
-> ViaIfAsyncAwaitable<Awaitable> { -> ViaIfAsyncAwaitable<Awaitable> {
static_assert( static_assert(
folly::coro::is_awaitable_v<Awaitable>, folly::coro::is_awaitable_v<Awaitable>,
"co_viaIfAsync() argument 2 is not awaitable."); "co_viaIfAsync() argument 2 is not awaitable.");
return ViaIfAsyncAwaitable<Awaitable>{executor, return ViaIfAsyncAwaitable<Awaitable>{std::move(executor),
static_cast<Awaitable&&>(awaitable)}; static_cast<Awaitable&&>(awaitable)};
} }
struct ViaIfAsyncFunction { struct ViaIfAsyncFunction {
template <typename Awaitable> template <typename Awaitable>
auto operator()(folly::Executor* executor, Awaitable&& awaitable) const auto operator()(folly::Executor::KeepAlive<> executor, Awaitable&& awaitable)
noexcept(noexcept( const noexcept(noexcept(co_viaIfAsync(
co_viaIfAsync(executor, static_cast<Awaitable&&>(awaitable)))) std::move(executor),
-> decltype( static_cast<Awaitable&&>(awaitable))))
co_viaIfAsync(executor, static_cast<Awaitable&&>(awaitable))) { -> decltype(co_viaIfAsync(
return co_viaIfAsync(executor, static_cast<Awaitable&&>(awaitable)); std::move(executor),
static_cast<Awaitable&&>(awaitable))) {
return co_viaIfAsync(
std::move(executor), static_cast<Awaitable&&>(awaitable));
} }
}; };
...@@ -374,7 +377,7 @@ template <typename T> ...@@ -374,7 +377,7 @@ template <typename T>
struct is_semi_awaitable< struct is_semi_awaitable<
T, T,
void_t<decltype(folly::coro::co_viaIfAsync( void_t<decltype(folly::coro::co_viaIfAsync(
std::declval<folly::Executor*>(), std::declval<folly::Executor::KeepAlive<>>(),
std::declval<T>()))>> : std::true_type {}; std::declval<T>()))>> : std::true_type {};
template <typename T> template <typename T>
...@@ -382,7 +385,7 @@ constexpr bool is_semi_awaitable_v = is_semi_awaitable<T>::value; ...@@ -382,7 +385,7 @@ constexpr bool is_semi_awaitable_v = is_semi_awaitable<T>::value;
template <typename T> template <typename T>
using semi_await_result_t = await_result_t<decltype(folly::coro::co_viaIfAsync( using semi_await_result_t = await_result_t<decltype(folly::coro::co_viaIfAsync(
std::declval<folly::Executor*>(), std::declval<folly::Executor::KeepAlive<>>(),
std::declval<T>()))>; std::declval<T>()))>;
} // namespace coro } // namespace coro
......
...@@ -839,9 +839,9 @@ class SemiFuture : private futures::detail::FutureBase<T> { ...@@ -839,9 +839,9 @@ class SemiFuture : private futures::detail::FutureBase<T> {
// Customise the co_viaIfAsync() operator so that SemiFuture<T> can be // Customise the co_viaIfAsync() operator so that SemiFuture<T> can be
// directly awaited within a folly::coro::Task coroutine. // directly awaited within a folly::coro::Task coroutine.
friend Future<T> co_viaIfAsync( friend Future<T> co_viaIfAsync(
folly::Executor* executor, folly::Executor::KeepAlive<> executor,
SemiFuture<T>&& future) noexcept { SemiFuture<T>&& future) noexcept {
return std::move(future).via(executor); return std::move(future).via(std::move(executor));
} }
#endif #endif
...@@ -1914,9 +1914,9 @@ class Future : private futures::detail::FutureBase<T> { ...@@ -1914,9 +1914,9 @@ class Future : private futures::detail::FutureBase<T> {
// Overload needed to customise behaviour of awaiting a Future<T> // Overload needed to customise behaviour of awaiting a Future<T>
// inside a folly::coro::Task coroutine. // inside a folly::coro::Task coroutine.
friend Future<T> co_viaIfAsync( friend Future<T> co_viaIfAsync(
folly::Executor* executor, folly::Executor::KeepAlive<> executor,
Future<T>&& future) noexcept { Future<T>&& future) noexcept {
return std::move(future).via(executor); return std::move(future).via(std::move(executor));
} }
#endif #endif
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment