Commit d5f9c13d authored by Lewis Baker's avatar Lewis Baker Committed by Facebook Github Bot

folly::coro::Task now preserves RequestContext across suspend/resume points

Summary:
The folly::coro::Task coroutine type now captures the current RequestContext when the coroutine suspends and restores it when it later resumes.

This means that folly::coro::Task can now be used safely with RequestContext and RequestContextScopeGuard.

Reviewed By: andriigrynenko

Differential Revision: D9973428

fbshipit-source-id: 41ea54baf334f0af3dd46ceb32465580f06fb37e
parent 73044d85
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <folly/experimental/coro/ViaIfAsync.h> #include <folly/experimental/coro/ViaIfAsync.h>
#include <folly/experimental/coro/detail/InlineTask.h> #include <folly/experimental/coro/detail/InlineTask.h>
#include <folly/futures/Future.h> #include <folly/futures/Future.h>
#include <folly/io/async/Request.h>
namespace folly { namespace folly {
namespace coro { namespace coro {
...@@ -239,7 +240,11 @@ class FOLLY_NODISCARD TaskWithExecutor { ...@@ -239,7 +240,11 @@ class FOLLY_NODISCARD TaskWithExecutor {
DCHECK(promise.executor_ != nullptr); DCHECK(promise.executor_ != nullptr);
promise.continuation_ = continuation; promise.continuation_ = continuation;
promise.executor_->add(coro_); promise.executor_->add(
[coro = coro_, ctx = RequestContext::saveContext()]() mutable {
RequestContextScopeGuard contextScope{std::move(ctx)};
coro.resume();
});
} }
decltype(auto) await_resume() { decltype(auto) await_resume() {
...@@ -300,6 +305,11 @@ class FOLLY_NODISCARD TaskWithExecutor { ...@@ -300,6 +305,11 @@ class FOLLY_NODISCARD TaskWithExecutor {
/// 'co_await expr' expression into /// 'co_await expr' expression into
/// `co_await co_viaIfAsync(boundExecutor, expr)' to ensure that the coroutine /// `co_await co_viaIfAsync(boundExecutor, expr)' to ensure that the coroutine
/// always resumes on the executor. /// always resumes on the executor.
///
/// The Task coroutine is RequestContext-aware and will capture the
/// current RequestContext at the time the coroutine function is called and
/// will save/restore the current RequestContext whenever the coroutine
/// suspends and resumes at a co_await expression.
template <typename T> template <typename T>
class FOLLY_NODISCARD Task { class FOLLY_NODISCARD Task {
public: public:
...@@ -393,7 +403,9 @@ auto detail::TaskPromiseBase::await_transform(Task<T>&& t) noexcept { ...@@ -393,7 +403,9 @@ auto detail::TaskPromiseBase::await_transform(Task<T>&& t) noexcept {
handle_t coro_; handle_t coro_;
}; };
// Child task inherits the awaiting task's executor
t.coro_.promise().executor_ = executor_; t.coro_.promise().executor_ = executor_;
return Awaiter{std::exchange(t.coro_, {})}; return Awaiter{std::exchange(t.coro_, {})};
} }
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <folly/Executor.h> #include <folly/Executor.h>
#include <folly/experimental/coro/Traits.h> #include <folly/experimental/coro/Traits.h>
#include <folly/io/async/Request.h>
#include <glog/logging.h> #include <glog/logging.h>
...@@ -56,7 +57,14 @@ class ViaCoroutine { ...@@ -56,7 +57,14 @@ class ViaCoroutine {
std::experimental::coroutine_handle<promise_type> coro) noexcept { std::experimental::coroutine_handle<promise_type> coro) noexcept {
// Schedule resumption of the coroutine on the executor. // Schedule resumption of the coroutine on the executor.
auto& promise = coro.promise(); auto& promise = coro.promise();
promise.executor_->add(promise.continuation_); if (!promise.context_) {
promise.context_ = RequestContext::saveContext();
}
promise.executor_->add([&promise]() noexcept {
RequestContextScopeGuard contextScope{std::move(promise.context_)};
promise.continuation_.resume();
});
} }
void await_resume() noexcept {} void await_resume() noexcept {}
}; };
...@@ -76,9 +84,14 @@ class ViaCoroutine { ...@@ -76,9 +84,14 @@ class ViaCoroutine {
continuation_ = continuation; continuation_ = continuation;
} }
void setContext(std::shared_ptr<RequestContext> context) noexcept {
context_ = std::move(context);
}
private: private:
folly::Executor* executor_; folly::Executor* executor_;
std::experimental::coroutine_handle<> continuation_; std::experimental::coroutine_handle<> continuation_;
std::shared_ptr<RequestContext> context_;
}; };
ViaCoroutine(ViaCoroutine&& other) noexcept ViaCoroutine(ViaCoroutine&& other) noexcept
...@@ -107,6 +120,12 @@ class ViaCoroutine { ...@@ -107,6 +120,12 @@ class ViaCoroutine {
} }
} }
std::experimental::coroutine_handle<> getWrappedCoroutineWithSavedContext(
std::experimental::coroutine_handle<> continuation) noexcept {
coro_.promise().setContext(RequestContext::saveContext());
return getWrappedCoroutine(continuation);
}
void destroy() { void destroy() {
if (coro_) { if (coro_) {
std::exchange(coro_, {}).destroy(); std::exchange(coro_, {}).destroy();
...@@ -135,6 +154,10 @@ class ViaCoroutine { ...@@ -135,6 +154,10 @@ class ViaCoroutine {
template <typename Awaiter> template <typename Awaiter>
class ViaIfAsyncAwaiter { class ViaIfAsyncAwaiter {
using await_suspend_result_t =
decltype(std::declval<Awaiter&>().await_suspend(
std::declval<std::experimental::coroutine_handle<>>()));
public: public:
static_assert( static_assert(
folly::coro::is_awaiter_v<Awaiter>, folly::coro::is_awaiter_v<Awaiter>,
...@@ -154,16 +177,62 @@ class ViaIfAsyncAwaiter { ...@@ -154,16 +177,62 @@ class ViaIfAsyncAwaiter {
bool await_ready() noexcept( bool await_ready() noexcept(
noexcept(std::declval<Awaiter&>().await_ready())) { noexcept(std::declval<Awaiter&>().await_ready())) {
DCHECK(true);
return awaiter_.await_ready(); return awaiter_.await_ready();
} }
// NOTE: We are using a heuristic here to determine when is the correct
// time to capture the RequestContext. We want to capture the context just
// before the coroutine suspends and execution is returned to the executor.
//
// In cases where we are awaiting another coroutine and symmetrically
// transferring execution to another coroutine we are not yet returning
// execution to the executor so we want to defer capturing the context until
// the ViaCoroutine is resumed and suspends in final_suspend() before
// scheduling the resumption on the executor.
//
// In cases where the awaitable may suspend without transferring execution
// to another coroutine and will therefore return back to the executor we
// want to capture the execution context before calling into the wrapped
// awaitable's await_suspend() method (since it's await_suspend() method
// might schedule resumption on another thread and could resume and destroy
// the ViaCoroutine before the await_suspend() method returns).
//
// The heuristic is that if await_suspend() returns a coroutine_handle
// then we assume it's the first case. Otherwise if await_suspend() returns
// void/bool then we assume it's the second case.
//
// This heuristic isn't perfect since a coroutine_handle-returning
// await_suspend() method could return noop_coroutine() in which case we
// could fail to capture the current context. Awaitable types that do this
// would need to provide a custom implementation of co_viaIfAsync() that
// correctly captures the RequestContext to get correct behaviour in this
// case.
template <
typename Result = await_suspend_result_t,
std::enable_if_t<
folly::coro::detail::_is_coroutine_handle<Result>::value,
int> = 0>
auto auto
await_suspend(std::experimental::coroutine_handle<> continuation) noexcept( await_suspend(std::experimental::coroutine_handle<> continuation) noexcept(
noexcept(std::declval<Awaiter&>().await_suspend(continuation))) { noexcept(awaiter_.await_suspend(continuation))) -> Result {
return awaiter_.await_suspend( return awaiter_.await_suspend(
viaCoroutine_.getWrappedCoroutine(continuation)); viaCoroutine_.getWrappedCoroutine(continuation));
} }
template <
typename Result = await_suspend_result_t,
std::enable_if_t<
!folly::coro::detail::_is_coroutine_handle<Result>::value,
int> = 0>
auto
await_suspend(std::experimental::coroutine_handle<> continuation) noexcept(
noexcept(awaiter_.await_suspend(continuation))) -> Result {
return awaiter_.await_suspend(
viaCoroutine_.getWrappedCoroutineWithSavedContext(continuation));
}
decltype(auto) await_resume() noexcept( decltype(auto) await_resume() noexcept(
noexcept(std::declval<Awaiter&>().await_resume())) { noexcept(std::declval<Awaiter&>().await_resume())) {
viaCoroutine_.destroy(); viaCoroutine_.destroy();
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
#if FOLLY_HAS_COROUTINES #if FOLLY_HAS_COROUTINES
#include <folly/experimental/coro/Baton.h>
#include <folly/experimental/coro/BlockingWait.h> #include <folly/experimental/coro/BlockingWait.h>
#include <folly/experimental/coro/detail/InlineTask.h> #include <folly/experimental/coro/detail/InlineTask.h>
#include <folly/portability/GTest.h> #include <folly/portability/GTest.h>
...@@ -206,7 +205,7 @@ struct MyException : std::exception {}; ...@@ -206,7 +205,7 @@ struct MyException : std::exception {};
TEST(InlineTask, ExceptionsPropagateFromVoidTask) { TEST(InlineTask, ExceptionsPropagateFromVoidTask) {
auto f = []() -> InlineTask<void> { auto f = []() -> InlineTask<void> {
co_await folly::coro::Baton{true}; co_await std::experimental::suspend_never{};
throw MyException{}; throw MyException{};
}; };
EXPECT_THROW(folly::coro::blockingWait(f()), MyException); EXPECT_THROW(folly::coro::blockingWait(f()), MyException);
...@@ -214,7 +213,7 @@ TEST(InlineTask, ExceptionsPropagateFromVoidTask) { ...@@ -214,7 +213,7 @@ TEST(InlineTask, ExceptionsPropagateFromVoidTask) {
TEST(InlineTask, ExceptionsPropagateFromValueTask) { TEST(InlineTask, ExceptionsPropagateFromValueTask) {
auto f = []() -> InlineTask<int> { auto f = []() -> InlineTask<int> {
co_await folly::coro::Baton{true}; co_await std::experimental::suspend_never{};
throw MyException{}; throw MyException{};
}; };
EXPECT_THROW(folly::coro::blockingWait(f()), MyException); EXPECT_THROW(folly::coro::blockingWait(f()), MyException);
...@@ -222,7 +221,7 @@ TEST(InlineTask, ExceptionsPropagateFromValueTask) { ...@@ -222,7 +221,7 @@ TEST(InlineTask, ExceptionsPropagateFromValueTask) {
TEST(InlineTask, ExceptionsPropagateFromRefTask) { TEST(InlineTask, ExceptionsPropagateFromRefTask) {
auto f = []() -> InlineTask<int&> { auto f = []() -> InlineTask<int&> {
co_await folly::coro::Baton{true}; co_await std::experimental::suspend_never{};
throw MyException{}; throw MyException{};
}; };
EXPECT_THROW(folly::coro::blockingWait(f()), MyException); EXPECT_THROW(folly::coro::blockingWait(f()), MyException);
......
/*
* Copyright 2017-present Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <folly/Portability.h>
#if FOLLY_HAS_COROUTINES
#include <folly/executors/InlineExecutor.h>
#include <folly/executors/ManualExecutor.h>
#include <folly/experimental/coro/Baton.h>
#include <folly/experimental/coro/Mutex.h>
#include <folly/experimental/coro/Task.h>
#include <folly/experimental/coro/detail/InlineTask.h>
#include <folly/portability/GTest.h>
using namespace folly;
namespace {
const RequestToken testToken1("corotest1");
const RequestToken testToken2("corotest2");
class TestRequestData : public RequestData {
public:
explicit TestRequestData(std::string key) noexcept : key_(std::move(key)) {}
bool hasCallback() override {
return false;
}
const std::string& key() const noexcept {
return key_;
}
private:
std::string key_;
};
} // namespace
static coro::Task<void> childRequest(coro::Mutex& m, coro::Baton& b) {
ShallowCopyRequestContextScopeGuard requestScope;
auto* parentContext = dynamic_cast<TestRequestData*>(
RequestContext::get()->getContextData(testToken1));
EXPECT_TRUE(parentContext != nullptr);
auto childKey = parentContext->key() + ".child";
RequestContext::get()->setContextData(
testToken2, std::make_unique<TestRequestData>(childKey));
auto* childContext = dynamic_cast<TestRequestData*>(
RequestContext::get()->getContextData(testToken2));
CHECK(childContext != nullptr);
{
auto lock = co_await m.co_scoped_lock();
CHECK_EQ(
parentContext,
dynamic_cast<TestRequestData*>(
RequestContext::get()->getContextData(testToken1)));
CHECK_EQ(
childContext,
dynamic_cast<TestRequestData*>(
RequestContext::get()->getContextData(testToken2)));
co_await b;
CHECK_EQ(
parentContext,
dynamic_cast<TestRequestData*>(
RequestContext::get()->getContextData(testToken1)));
CHECK_EQ(
childContext,
dynamic_cast<TestRequestData*>(
RequestContext::get()->getContextData(testToken2)));
}
CHECK_EQ(
parentContext,
dynamic_cast<TestRequestData*>(
RequestContext::get()->getContextData(testToken1)));
CHECK_EQ(
childContext,
dynamic_cast<TestRequestData*>(
RequestContext::get()->getContextData(testToken2)));
}
static coro::Task<void> parentRequest(int id) {
ShallowCopyRequestContextScopeGuard requestScope;
// Should have captured the value at the time the coroutine was co_awaited
// rather than at the time the coroutine was called.
auto* globalData = dynamic_cast<TestRequestData*>(
RequestContext::get()->getContextData("global"));
CHECK(globalData != nullptr);
CHECK_EQ("other value", globalData->key());
std::string key = folly::to<std::string>("request", id);
RequestContext::get()->setContextData(
testToken1, std::make_unique<TestRequestData>(key));
auto* contextData = RequestContext::get()->getContextData(testToken1);
CHECK(contextData != nullptr);
coro::Mutex mutex;
coro::Baton baton1;
coro::Baton baton2;
auto fut1 = childRequest(mutex, baton1)
.scheduleOn(co_await coro::co_current_executor)
.start();
auto fut2 = childRequest(mutex, baton1)
.scheduleOn(co_await coro::co_current_executor)
.start();
CHECK_EQ(contextData, RequestContext::get()->getContextData(testToken1));
baton1.post();
baton2.post();
(void)co_await std::move(fut1);
CHECK_EQ(contextData, RequestContext::get()->getContextData(testToken1));
// Check that context from child operation doesn't leak into this coroutine.
CHECK(RequestContext::get()->getContextData(testToken2) == nullptr);
(void)co_await std::move(fut2);
// Check that context from child operation doesn't leak into this coroutine.
CHECK(RequestContext::get()->getContextData(testToken2) == nullptr);
}
TEST(Task, RequestContextIsPreservedAcrossSuspendResume) {
ManualExecutor executor;
RequestContextScopeGuard requestScope;
RequestContext::get()->setContextData(
"global", std::make_unique<TestRequestData>("global value"));
// Context should be captured at coroutine co_await time and not at
// call time.
auto task1 = parentRequest(1).scheduleOn(&executor);
auto task2 = parentRequest(2).scheduleOn(&executor);
{
RequestContextScopeGuard nestedRequestScope;
RequestContext::get()->setContextData(
"global", std::make_unique<TestRequestData>("other value"));
// Start execution of the tasks.
auto fut1 = std::move(task1).start();
auto fut2 = std::move(task2).start();
// Check that the contexts set by starting the tasks don't bleed out
// to the caller.
CHECK(RequestContext::get()->getContextData(testToken1) == nullptr);
CHECK(RequestContext::get()->getContextData(testToken2) == nullptr);
CHECK_EQ(
"other value",
dynamic_cast<TestRequestData*>(
RequestContext::get()->getContextData("global"))
->key());
executor.drain();
CHECK(fut1.isReady());
CHECK(fut2.isReady());
// Check that the contexts set by the coroutines executing on the executor
// do not leak out to the caller.
CHECK(RequestContext::get()->getContextData(testToken1) == nullptr);
CHECK(RequestContext::get()->getContextData(testToken2) == nullptr);
CHECK_EQ(
"other value",
dynamic_cast<TestRequestData*>(
RequestContext::get()->getContextData("global"))
->key());
}
}
TEST(Task, ContextPreservedAcrossMutexLock) {
folly::coro::Mutex mutex;
auto handleRequest =
[&](folly::coro::Baton& event) -> folly::coro::Task<void> {
RequestContextScopeGuard requestScope;
RequestData* contextDataPtr = nullptr;
{
auto contextData = std::make_unique<TestRequestData>("some value");
contextDataPtr = contextData.get();
RequestContext::get()->setContextData(
"mutex_test", std::move(contextData));
}
auto lock = co_await mutex.co_scoped_lock();
// Check that the request context was preserved across mutex lock
// acquisition.
CHECK_EQ(
RequestContext::get()->getContextData("mutex_test"), contextDataPtr);
co_await event;
// Check that request context was preserved across baton wait.
CHECK_EQ(
RequestContext::get()->getContextData("mutex_test"), contextDataPtr);
};
folly::ManualExecutor manualExecutor;
folly::coro::Baton event1;
folly::coro::Baton event2;
auto t1 = handleRequest(event1).scheduleOn(&manualExecutor).start();
auto t2 = handleRequest(event2).scheduleOn(&manualExecutor).start();
manualExecutor.drain();
event1.post();
manualExecutor.drain();
event2.post();
manualExecutor.drain();
EXPECT_TRUE(t1.isReady());
EXPECT_TRUE(t2.isReady());
EXPECT_FALSE(t1.hasException());
EXPECT_FALSE(t2.hasException());
}
TEST(Task, RequestContextSideEffectsArePreserved) {
auto f =
[&](folly::coro::Baton& baton) -> folly::coro::detail::InlineTask<void> {
RequestContext::create();
RequestContext::get()->setContextData(
testToken1, std::make_unique<TestRequestData>("test"));
EXPECT_NE(RequestContext::get()->getContextData(testToken1), nullptr);
// HACK: Need to use co_viaIfAsync() to ensure request context is preserved
// across suspend-point.
co_await co_viaIfAsync(&folly::InlineExecutor::instance(), baton);
EXPECT_NE(RequestContext::get()->getContextData(testToken1), nullptr);
co_return;
};
auto g = [&](folly::coro::Baton& baton) -> folly::coro::Task<void> {
EXPECT_EQ(RequestContext::get()->getContextData(testToken1), nullptr);
co_await f(baton);
EXPECT_NE(RequestContext::get()->getContextData(testToken1), nullptr);
EXPECT_EQ(
dynamic_cast<TestRequestData*>(
RequestContext::get()->getContextData(testToken1))
->key(),
"test");
};
folly::ManualExecutor executor;
folly::coro::Baton baton;
auto t = g(baton).scheduleOn(&executor).start();
executor.drain();
baton.post();
executor.drain();
EXPECT_TRUE(t.isReady());
EXPECT_FALSE(t.hasException());
}
#endif
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment