Commit fefc6887 authored by Shai Szulanski's avatar Shai Szulanski Committed by Facebook GitHub Bot

folly::coro::Promise

Summary:
A lightweight replacement for folly::Promise for use with coroutines.
The main selling point is cancellation-awareness. It also has a smaller memory footprint (similar CPU cost) and can be `blockingWait`ed.

Differential Revision: D31561034

fbshipit-source-id: d803e02cfd7753ee1f532c8f696cf7fbf672938b
parent 8cf4f930
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <atomic>
#include <utility>
#include <folly/CancellationToken.h>
#include <folly/Try.h>
#include <folly/experimental/coro/Baton.h>
#include <folly/experimental/coro/Coroutine.h>
#include <folly/futures/Promise.h>
#if FOLLY_HAS_COROUTINES
namespace folly::coro {
template <typename T>
class Promise;
template <typename T>
class Future;
// Creates promise and associated unfulfilled future
template <typename T>
std::pair<Promise<T>, Future<T>> makePromiseContract();
// Creates fulfilled future
template <typename T>
Future<remove_cvref_t<T>> makeFuture(T&&);
template <typename T>
Future<T> makeFuture(exception_wrapper&&);
Future<void> makeFuture();
namespace detail {
template <typename T>
struct PromiseState {
PromiseState() = default;
Try<T> result;
// Must be exchanged to true before setting result
std::atomic<bool> fulfilled{false};
// Must be posted after setting result
coro::Baton ready;
};
} // namespace detail
template <typename T>
class Promise {
public:
Promise(Promise&& other) noexcept
: ct_(std::move(other.ct_)),
state_(std::exchange(other.state_, nullptr)) {}
Promise& operator=(Promise&& other) noexcept {
if (this != &other && state_ && !state_->fulfilled) {
setException(BrokenPromise{pretty_name<T>()});
}
ct_ = std::move(other.ct_);
state_ = std::exchange(other.state_, nullptr);
return *this;
}
Promise(const Promise&) = delete;
Promise& operator=(const Promise&) = delete;
~Promise() {
if (state_ && !state_->fulfilled) {
setException(BrokenPromise{pretty_name<T>()});
}
}
template <typename U = T, typename = std::enable_if_t<!std::is_void_v<U>>>
void setValue(U&& value) {
DCHECK(state_);
if (!state_->fulfilled.exchange(true, std::memory_order_relaxed)) {
state_->result.emplace(std::forward<U>(value));
state_->ready.post();
}
}
template <typename U = T, typename = std::enable_if_t<std::is_void_v<U>>>
void setValue() {
DCHECK(state_);
if (!state_->fulfilled.exchange(true, std::memory_order_relaxed)) {
state_->ready.post();
}
}
void setException(exception_wrapper&& ex) {
DCHECK(state_);
if (!state_->fulfilled.exchange(true, std::memory_order_relaxed)) {
state_->result.emplaceException(std::move(ex));
state_->ready.post();
}
}
void setResult(Try<T>&& result) {
DCHECK(state_);
if (!state_->fulfilled.exchange(true, std::memory_order_relaxed)) {
state_->result = std::move(result);
state_->ready.post();
}
}
const CancellationToken& getCancellationToken() const { return ct_; }
private:
Promise(CancellationToken ct, detail::PromiseState<T>& state)
: ct_(std::move(ct)), state_(&state) {}
CancellationToken ct_;
detail::PromiseState<T>* state_;
friend std::pair<Promise<T>, Future<T>> makePromiseContract<T>();
};
template <typename T>
class Future {
public:
Future(Future&&) noexcept = default;
Future& operator=(Future&&) noexcept = default;
Future(const Future&) = delete;
Future& operator=(const Future&) = delete;
class WaitOperation : private Baton::WaitOperation {
public:
explicit WaitOperation(Future& future) noexcept
: Baton::WaitOperation(future.state_.ready),
future_(future),
cb_(std::move(future.ct_), [&] { future_.cancel(); }) {}
using Baton::WaitOperation::await_ready;
using Baton::WaitOperation::await_suspend;
T await_resume() {
if constexpr (!std::is_void_v<T>) {
return std::move(future_.state_.result.value());
}
}
folly::Try<T> await_resume_try() {
return std::move(future_.state_.result);
}
private:
Future& future_;
CancellationCallback cb_;
};
[[nodiscard]] WaitOperation operator co_await() && noexcept {
return WaitOperation{*this};
}
bool isReady() const noexcept { return state_.ready.ready(); }
friend Future co_withCancellation(
folly::CancellationToken ct, Future&& future) noexcept {
if (!std::exchange(future.hasCancelTokenOverride_, true)) {
future.ct_ = std::move(ct);
}
return std::move(future);
}
private:
Future(CancellationSource cs, detail::PromiseState<T>& state)
: cs_(std::move(cs)), state_(state) {}
void cancel() {
if (!state_.fulfilled.exchange(true, std::memory_order_relaxed)) {
cs_.requestCancellation();
state_.result.emplaceException(OperationCancelled{});
state_.ready.post();
}
}
CancellationSource cs_;
detail::PromiseState<T>& state_;
// The token inherited when the future is awaited
CancellationToken ct_;
bool hasCancelTokenOverride_{false};
friend std::pair<Promise<T>, Future<T>> makePromiseContract<T>();
};
template <typename T>
std::pair<Promise<T>, Future<T>> makePromiseContract() {
auto [cs, data] = CancellationSource::create(
folly::detail::WithDataTag<detail::PromiseState<T>>{});
return {
Promise<T>{cs.getToken(), std::get<0>(*data)},
Future<T>{std::move(cs), std::get<0>(*data)}};
}
template <typename T>
Future<remove_cvref_t<T>> makeFuture(T&& t) {
auto [promise, future] = makePromiseContract<remove_cvref_t<T>>();
promise.setValue(std::forward<T>(t));
return std::move(future);
}
template <typename T>
Future<T> makeFuture(exception_wrapper&& ex) {
auto [promise, future] = makePromiseContract<T>();
promise.setException(std::move(ex));
return std::move(future);
}
inline Future<void> makeFuture() {
auto [promise, future] = makePromiseContract<void>();
promise.setValue();
return std::move(future);
}
} // namespace folly::coro
#endif
...@@ -303,7 +303,7 @@ class FOLLY_NODISCARD TaskWithExecutor { ...@@ -303,7 +303,7 @@ class FOLLY_NODISCARD TaskWithExecutor {
// Start execution of this task eagerly and return a folly::SemiFuture<T> // Start execution of this task eagerly and return a folly::SemiFuture<T>
// that will complete with the result. // that will complete with the result.
FOLLY_NOINLINE SemiFuture<lift_unit_t<StorageType>> start() && { FOLLY_NOINLINE SemiFuture<lift_unit_t<StorageType>> start() && {
Promise<lift_unit_t<StorageType>> p; folly::Promise<lift_unit_t<StorageType>> p;
auto sf = p.getSemiFuture(); auto sf = p.getSemiFuture();
...@@ -343,7 +343,7 @@ class FOLLY_NODISCARD TaskWithExecutor { ...@@ -343,7 +343,7 @@ class FOLLY_NODISCARD TaskWithExecutor {
// assuming the current thread is already on the associated executor, // assuming the current thread is already on the associated executor,
// and return a folly::SemiFuture<T> that will complete with the result. // and return a folly::SemiFuture<T> that will complete with the result.
FOLLY_NOINLINE SemiFuture<lift_unit_t<StorageType>> startInlineUnsafe() && { FOLLY_NOINLINE SemiFuture<lift_unit_t<StorageType>> startInlineUnsafe() && {
Promise<lift_unit_t<StorageType>> p; folly::Promise<lift_unit_t<StorageType>> p;
auto sf = p.getSemiFuture(); auto sf = p.getSemiFuture();
...@@ -626,7 +626,7 @@ class FOLLY_NODISCARD Task { ...@@ -626,7 +626,7 @@ class FOLLY_NODISCARD Task {
[task = std::move(*this), [task = std::move(*this),
returnAddress = FOLLY_ASYNC_STACK_RETURN_ADDRESS()]( returnAddress = FOLLY_ASYNC_STACK_RETURN_ADDRESS()](
const Executor::KeepAlive<>& executor, Try<Unit>&&) mutable { const Executor::KeepAlive<>& executor, Try<Unit>&&) mutable {
Promise<lift_unit_t<StorageType>> p; folly::Promise<lift_unit_t<StorageType>> p;
auto sf = p.getSemiFuture(); auto sf = p.getSemiFuture();
......
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <folly/Benchmark.h>
#include <folly/Portability.h>
#include <folly/experimental/coro/BlockingWait.h>
#include <folly/experimental/coro/Collect.h>
#include <folly/experimental/coro/FutureUtil.h>
#include <folly/experimental/coro/Promise.h>
#include <folly/experimental/coro/Task.h>
#include <folly/futures/Future.h>
#if FOLLY_HAS_COROUTINES
void resetMallocStats() {
BENCHMARK_SUSPEND {
static uint64_t epoch = 0;
++epoch;
size_t sz = sizeof(epoch);
mallctl("epoch", &epoch, &sz, &epoch, sz);
};
}
void setMallocStats(folly::UserCounters& counters) {
BENCHMARK_SUSPEND {
size_t allocated = 0;
size_t sz = sizeof(allocated);
mallctl("stats.allocated", &allocated, &sz, nullptr, 0);
counters["allocated"] = allocated;
};
}
BENCHMARK_COUNTERS(CoroFutureImmediateUnwrapped, counters, iters) {
resetMallocStats();
for (std::size_t i = 0; i < iters; ++i) {
auto [promise, future] = folly::coro::makePromiseContract<int>();
promise.setValue(42);
folly::coro::blockingWait(std::move(future));
}
setMallocStats(counters);
}
// You can't directly blockingWait a SemiFuture (it deadlocks) so there's no
// comparison for this one
BENCHMARK_COUNTERS(CoroFutureImmediate, counters, iters) {
resetMallocStats();
for (std::size_t i = 0; i < iters; ++i) {
auto [promise, future] = folly::coro::makePromiseContract<int>();
promise.setValue(42);
folly::coro::blockingWait(folly::coro::toTask(std::move(future)));
}
setMallocStats(counters);
}
BENCHMARK_COUNTERS(FuturesFutureImmediate, counters, iters) {
resetMallocStats();
for (std::size_t i = 0; i < iters; ++i) {
auto [promise, future] = folly::makePromiseContract<int>();
promise.setValue(42);
folly::coro::blockingWait(folly::coro::toTask(std::move(future)));
}
setMallocStats(counters);
}
BENCHMARK_COUNTERS(CoroFutureSuspend, counters, iters) {
resetMallocStats();
for (std::size_t i = 0; i < iters; ++i) {
auto [promise, future] = folly::coro::makePromiseContract<int>();
auto waiter = [](auto future) -> folly::coro::Task<int> {
co_return co_await std::move(future);
}(std::move(future));
auto fulfiller = [](auto promise) -> folly::coro::Task<> {
promise.setValue(42);
co_return;
}(std::move(promise));
folly::coro::blockingWait(folly::coro::collectAll(
co_awaitTry(std::move(waiter)), std::move(fulfiller)));
}
setMallocStats(counters);
}
BENCHMARK_COUNTERS(FuturesFutureSuspend, counters, iters) {
resetMallocStats();
for (std::size_t i = 0; i < iters; ++i) {
auto [promise, future] = folly::makePromiseContract<int>();
auto waiter = [](auto future) -> folly::coro::Task<int> {
co_return co_await std::move(future);
}(std::move(future));
auto fulfiller = [](auto promise) -> folly::coro::Task<> {
promise.setValue(42);
co_return;
}(std::move(promise));
folly::coro::blockingWait(folly::coro::collectAll(
co_awaitTry(std::move(waiter)), std::move(fulfiller)));
}
setMallocStats(counters);
}
#endif
int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
folly::runBenchmarks();
return 0;
}
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <folly/experimental/coro/Promise.h>
#include <folly/Portability.h>
#include <folly/experimental/coro/BlockingWait.h>
#include <folly/experimental/coro/Collect.h>
#include <folly/experimental/coro/GtestHelpers.h>
#include <folly/experimental/coro/Task.h>
#include <folly/experimental/coro/WithCancellation.h>
#include <folly/portability/GTest.h>
#if FOLLY_HAS_COROUTINES
using namespace folly;
using namespace ::testing;
CO_TEST(PromiseTest, ImmediateValue) {
auto [promise, future] = coro::makePromiseContract<int>();
promise.setValue(42);
EXPECT_EQ(co_await std::move(future), 42);
}
CO_TEST(PromiseTest, ImmediateTry) {
auto [promise, future] = coro::makePromiseContract<int>();
promise.setResult(folly::Try(42));
auto res = co_await co_awaitTry(std::move(future));
EXPECT_EQ(res.value(), 42);
}
CO_TEST(PromiseTest, ImmediateException) {
auto [promise, future] = coro::makePromiseContract<int>();
promise.setException(std::runtime_error(""));
auto res = co_await co_awaitTry(std::move(future));
EXPECT_TRUE(res.hasException<std::runtime_error>());
}
CO_TEST(PromiseTest, SuspendValue) {
auto [promise, future] = coro::makePromiseContract<int>();
auto waiter = [](auto future) -> coro::Task<int> {
co_return co_await std::move(future);
}(std::move(future));
auto fulfiller = [](auto promise) -> coro::Task<> {
promise.setValue(42);
co_return;
}(std::move(promise));
auto [res, _] = co_await coro::collectAll(
co_awaitTry(std::move(waiter)), std::move(fulfiller));
EXPECT_EQ(res.value(), 42);
}
CO_TEST(PromiseTest, SuspendException) {
auto [promise, future] = coro::makePromiseContract<int>();
auto waiter = [](auto future) -> coro::Task<int> {
co_return co_await std::move(future);
}(std::move(future));
auto fulfiller = [](auto promise) -> coro::Task<> {
promise.setException(std::logic_error(""));
co_return;
}(std::move(promise));
auto [res, _] = co_await coro::collectAll(
co_awaitTry(std::move(waiter)), std::move(fulfiller));
EXPECT_TRUE(res.hasException<std::logic_error>());
}
CO_TEST(PromiseTest, ImmediateCancel) {
auto [promise, future] = coro::makePromiseContract<int>();
CancellationSource cs;
cs.requestCancellation();
bool cancelled = false;
CancellationCallback cb{
promise.getCancellationToken(), [&] { cancelled = true; }};
EXPECT_FALSE(cancelled);
auto res = co_await co_awaitTry(
co_withCancellation(cs.getToken(), std::move(future)));
EXPECT_TRUE(cancelled);
EXPECT_TRUE(res.hasException<OperationCancelled>());
promise.setValue(42);
}
CO_TEST(PromiseTest, CancelFulfilled) {
auto [promise, future] = coro::makePromiseContract<int>();
promise.setValue(42);
CancellationSource cs;
cs.requestCancellation();
bool cancelled = false;
CancellationCallback cb{
promise.getCancellationToken(), [&] { cancelled = true; }};
auto res = co_await co_awaitTry(
co_withCancellation(cs.getToken(), std::move(future)));
EXPECT_FALSE(cancelled); // not signalled if already fulfilled
EXPECT_EQ(res.value(), 42);
}
CO_TEST(PromiseTest, SuspendCancel) {
auto [promise, future] = coro::makePromiseContract<int>();
CancellationSource cs;
bool cancelled = false;
CancellationCallback cb{
promise.getCancellationToken(), [&] { cancelled = true; }};
auto waiter = [](auto future) -> coro::Task<int> {
co_return co_await std::move(future);
}(co_withCancellation(cs.getToken(), std::move(future)));
auto fulfiller = [](auto cs) -> coro::Task<> {
cs.requestCancellation();
co_return;
}(cs);
auto [res, _] = co_await coro::collectAll(
co_awaitTry(std::move(waiter)), std::move(fulfiller));
EXPECT_TRUE(cancelled);
EXPECT_TRUE(res.hasException<OperationCancelled>());
}
CO_TEST(PromiseTest, ImmediateBreakPromise) {
auto [promise, future] = coro::makePromiseContract<int>();
{ auto p2 = std::move(promise); }
auto res = co_await co_awaitTry(std::move(future));
EXPECT_TRUE(res.hasException<BrokenPromise>());
}
CO_TEST(PromiseTest, SuspendBreakPromise) {
auto [promise, future] = coro::makePromiseContract<int>();
auto waiter = [](auto future) -> coro::Task<int> {
co_return co_await std::move(future);
}(std::move(future));
auto fulfiller = [](auto promise) -> coro::Task<> {
(void)promise;
co_return;
}(std::move(promise));
auto [res, _] = co_await coro::collectAll(
co_awaitTry(std::move(waiter)), std::move(fulfiller));
EXPECT_TRUE(res.hasException<BrokenPromise>());
}
CO_TEST(PromiseTest, Lifetime) {
struct Guard {
int& destroyed;
explicit Guard(int& d) : destroyed(d) {}
Guard(Guard&&) = default;
~Guard() { destroyed++; }
};
int destroyed = 0;
{
auto [promise, future] = coro::makePromiseContract<Guard>();
promise.setValue(Guard(destroyed));
EXPECT_EQ(destroyed, 1); // the temporary
co_await std::move(future);
EXPECT_EQ(destroyed, 2); // the return value
}
EXPECT_EQ(destroyed, 3); // the slot in shared state
}
TEST(PromiseTest, DropFuture) {
struct Guard {
int& destroyed;
explicit Guard(int& d) : destroyed(d) {}
Guard(Guard&&) = default;
~Guard() { destroyed++; }
};
int destroyed = 0;
{
auto [promise, future] = coro::makePromiseContract<Guard>();
promise.setValue(Guard(destroyed));
EXPECT_EQ(destroyed, 1); // the temporary
}
EXPECT_EQ(destroyed, 2); // the slot in shared state
}
CO_TEST(PromiseTest, MoveOnly) {
auto [promise, future] = coro::makePromiseContract<std::unique_ptr<int>>();
promise.setValue(std::make_unique<int>(42));
auto val = co_await std::move(future);
EXPECT_EQ(*val, 42);
}
CO_TEST(PromiseTest, Void) {
auto [promise, future] = coro::makePromiseContract<void>();
promise.setValue();
co_await std::move(future);
}
TEST(PromiseTest, IsReady) {
auto [promise, future] = coro::makePromiseContract<int>();
EXPECT_FALSE(future.isReady());
promise.setValue(42);
EXPECT_TRUE(future.isReady());
}
CO_TEST(PromiseTest, MakeFuture) {
auto future = coro::makeFuture(42);
EXPECT_TRUE(future.isReady());
auto val = co_await std::move(future);
EXPECT_EQ(val, 42);
auto future2 = coro::makeFuture<int>(std::runtime_error(""));
EXPECT_TRUE(future2.isReady());
auto res = co_await co_awaitTry(std::move(future2));
EXPECT_TRUE(res.hasException<std::runtime_error>());
auto future3 = coro::makeFuture();
EXPECT_TRUE(future3.isReady());
auto res3 = co_await co_awaitTry(std::move(future3));
EXPECT_TRUE(res3.hasValue());
}
#endif
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment