Commit 50f425bb authored by Lewis Baker's avatar Lewis Baker Committed by Facebook Github Bot

Add CancellationToken support to AsyncGenerator

Summary:
An AsyncGenerator coroutine now has an implicit associate CancellationToken that is inherited from the calling context of each `co_await gen.next()` call.

This means that generators that correspond to long-running subscriptions now have an in-built channel that allows them to cancel the subscription.

This is also useful for algorithms that introduce concurrency in processing of stream data. eg. The `buffer()` or `merge()` algorithms.

Reviewed By: kirkshoop

Differential Revision: D16833864

fbshipit-source-id: be8faf1ec438c4248091e02d6a264d3760e2d73d
parent 178fe966
...@@ -15,10 +15,12 @@ ...@@ -15,10 +15,12 @@
*/ */
#pragma once #pragma once
#include <folly/CancellationToken.h>
#include <folly/Traits.h> #include <folly/Traits.h>
#include <folly/experimental/coro/CurrentExecutor.h> #include <folly/experimental/coro/CurrentExecutor.h>
#include <folly/experimental/coro/Utils.h> #include <folly/experimental/coro/Utils.h>
#include <folly/experimental/coro/ViaIfAsync.h> #include <folly/experimental/coro/ViaIfAsync.h>
#include <folly/experimental/coro/WithCancellation.h>
#include <folly/experimental/coro/detail/ManualLifetime.h> #include <folly/experimental/coro/detail/ManualLifetime.h>
#include <glog/logging.h> #include <glog/logging.h>
...@@ -72,7 +74,7 @@ class AsyncGeneratorPromise { ...@@ -72,7 +74,7 @@ class AsyncGeneratorPromise {
DCHECK(!hasValue_); DCHECK(!hasValue_);
value_.construct(static_cast<Reference&&>(value)); value_.construct(static_cast<Reference&&>(value));
hasValue_ = true; hasValue_ = true;
executor_.reset(); clearContext();
return YieldAwaiter{}; return YieldAwaiter{};
} }
...@@ -92,7 +94,7 @@ class AsyncGeneratorPromise { ...@@ -92,7 +94,7 @@ class AsyncGeneratorPromise {
DCHECK(!hasValue_); DCHECK(!hasValue_);
value_.construct(static_cast<U&&>(value)); value_.construct(static_cast<U&&>(value));
hasValue_ = true; hasValue_ = true;
executor_.reset(); clearContext();
return {}; return {};
} }
...@@ -108,13 +110,28 @@ class AsyncGeneratorPromise { ...@@ -108,13 +110,28 @@ class AsyncGeneratorPromise {
template <typename U> template <typename U>
auto await_transform(U&& value) { auto await_transform(U&& value) {
return folly::coro::co_viaIfAsync( return folly::coro::co_viaIfAsync(
executor_.get_alias(), static_cast<U&&>(value)); executor_.get_alias(),
folly::coro::co_withCancellation(
cancelToken_, static_cast<U&&>(value)));
} }
auto await_transform(folly::coro::co_current_executor_t) noexcept { auto await_transform(folly::coro::co_current_executor_t) noexcept {
return AwaitableReady<folly::Executor*>{executor_.get()}; return AwaitableReady<folly::Executor*>{executor_.get()};
} }
auto await_transform(folly::coro::co_current_cancellation_token_t) noexcept {
return AwaitableReady<folly::CancellationToken>{cancelToken_};
}
void setCancellationToken(folly::CancellationToken cancelToken) noexcept {
// Only keep the first cancellation token.
// ie. the inner-most cancellation scope of the consumer's calling context.
if (!hasCancelTokenOverride_) {
cancelToken_ = std::move(cancelToken);
hasCancelTokenOverride_ = true;
}
}
void setExecutor(folly::Executor::KeepAlive<> executor) noexcept { void setExecutor(folly::Executor::KeepAlive<> executor) noexcept {
executor_ = std::move(executor); executor_ = std::move(executor);
} }
...@@ -148,11 +165,19 @@ class AsyncGeneratorPromise { ...@@ -148,11 +165,19 @@ class AsyncGeneratorPromise {
} }
private: private:
void clearContext() noexcept {
executor_ = {};
cancelToken_ = {};
hasCancelTokenOverride_ = false;
}
std::experimental::coroutine_handle<> continuation_; std::experimental::coroutine_handle<> continuation_;
folly::Executor::KeepAlive<> executor_; folly::Executor::KeepAlive<> executor_;
folly::CancellationToken cancelToken_;
std::exception_ptr exception_; std::exception_ptr exception_;
ManualLifetime<Reference> value_; ManualLifetime<Reference> value_;
bool hasValue_ = false; bool hasValue_ = false;
bool hasCancelTokenOverride_ = false;
}; };
} // namespace detail } // namespace detail
...@@ -422,8 +447,17 @@ class FOLLY_NODISCARD AsyncGenerator { ...@@ -422,8 +447,17 @@ class FOLLY_NODISCARD AsyncGenerator {
return NextAwaitable{coro_}; return NextAwaitable{coro_};
} }
friend NextSemiAwaitable co_withCancellation(
CancellationToken cancelToken,
NextSemiAwaitable&& awaitable) {
if (awaitable.coro_) {
awaitable.coro_.promise().setCancellationToken(std::move(cancelToken));
}
return NextSemiAwaitable{std::exchange(awaitable.coro_, {})};
}
private: private:
friend AsyncGenerator; //<Reference, Value>; friend AsyncGenerator;
explicit NextSemiAwaitable(handle_t coro) noexcept : coro_(coro) {} explicit NextSemiAwaitable(handle_t coro) noexcept : coro_(coro) {}
......
...@@ -19,11 +19,14 @@ ...@@ -19,11 +19,14 @@
#if FOLLY_HAS_COROUTINES #if FOLLY_HAS_COROUTINES
#include <folly/ScopeGuard.h> #include <folly/ScopeGuard.h>
#include <folly/Traits.h>
#include <folly/experimental/coro/AsyncGenerator.h> #include <folly/experimental/coro/AsyncGenerator.h>
#include <folly/experimental/coro/Baton.h> #include <folly/experimental/coro/Baton.h>
#include <folly/experimental/coro/BlockingWait.h> #include <folly/experimental/coro/BlockingWait.h>
#include <folly/experimental/coro/Collect.h>
#include <folly/experimental/coro/Sleep.h> #include <folly/experimental/coro/Sleep.h>
#include <folly/experimental/coro/Task.h> #include <folly/experimental/coro/Task.h>
#include <folly/experimental/coro/WithCancellation.h>
#include <folly/futures/Future.h> #include <folly/futures/Future.h>
#include <folly/portability/GTest.h> #include <folly/portability/GTest.h>
...@@ -403,4 +406,40 @@ TEST(AsyncGenerator, InvokeLambda) { ...@@ -403,4 +406,40 @@ TEST(AsyncGenerator, InvokeLambda) {
}()); }());
} }
template <typename Ref, typename Value = folly::remove_cvref_t<Ref>>
folly::coro::AsyncGenerator<Ref, Value> neverStream() {
folly::coro::Baton baton;
folly::CancellationCallback cb{
co_await folly::coro::co_current_cancellation_token,
[&] { baton.post(); }};
co_await baton;
}
TEST(AsyncGenerator, CancellationTokenPropagatesFromConsumer) {
folly::coro::blockingWait([]() -> folly::coro::Task<void> {
folly::CancellationSource cancelSource;
bool suspended = false;
bool done = false;
co_await folly::coro::collectAll(
folly::coro::co_withCancellation(
cancelSource.getToken(),
[&]() -> folly::coro::Task<void> {
auto stream = neverStream<int>();
suspended = true;
auto result = co_await stream.next();
CHECK(!result.has_value());
done = true;
}()),
[&]() -> folly::coro::Task<void> {
co_await folly::coro::co_reschedule_on_current_executor;
co_await folly::coro::co_reschedule_on_current_executor;
co_await folly::coro::co_reschedule_on_current_executor;
CHECK(suspended);
CHECK(!done);
cancelSource.requestCancellation();
}());
CHECK(done);
}());
}
#endif #endif
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment