Commit 13560466 authored by Lewis Baker's avatar Lewis Baker Committed by Facebook Github Bot

Add support for folly::coro::Task<T&>

Summary:
Extended `folly::coro::Task` to support an lvalue-reference value type.

A `Task<T&>` stores the value internally in a `folly::Try<std::reference_wrapper<T>>` as `folly::Try` does not support reference value types. This means that when you use `co_await task.co_awaitTry()` on a `Task<T&>` that you get back a `folly::Try<std::reference_wrapper<T>>` instead of a `folly::Try<T&>`.

Also added support for using `.co_awaitTry()` on `Task<T>`.
Previously the `.co_awaitTry()` method was only supported on the `TaskWithExecutor<T>` type.

Reviewed By: andriigrynenko

Differential Revision: D14657997

fbshipit-source-id: a168991b5a9c278dcc2387a6ff4fba9e47ccac0d
parent f5977149
...@@ -312,11 +312,12 @@ auto blockingWait(Awaitable&& awaitable) ...@@ -312,11 +312,12 @@ auto blockingWait(Awaitable&& awaitable)
template <typename T> template <typename T>
T blockingWait(Task<T>&& task) { T blockingWait(Task<T>&& task) {
using StorageType = detail::lift_lvalue_reference_t<T>;
folly::ManualExecutor executor; folly::ManualExecutor executor;
Try<T> ret; Try<StorageType> ret;
bool done{false}; bool done{false};
std::move(task).scheduleOn(&executor).start([&](Try<T>&& result) { std::move(task).scheduleOn(&executor).start([&](Try<StorageType>&& result) {
ret = std::move(result); ret = std::move(result);
done = true; done = true;
}); });
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <folly/experimental/coro/Utils.h> #include <folly/experimental/coro/Utils.h>
#include <folly/experimental/coro/ViaIfAsync.h> #include <folly/experimental/coro/ViaIfAsync.h>
#include <folly/experimental/coro/detail/InlineTask.h> #include <folly/experimental/coro/detail/InlineTask.h>
#include <folly/experimental/coro/detail/Traits.h>
#include <folly/futures/Future.h> #include <folly/futures/Future.h>
#include <folly/io/async/Request.h> #include <folly/io/async/Request.h>
...@@ -98,6 +99,13 @@ class TaskPromiseBase { ...@@ -98,6 +99,13 @@ class TaskPromiseBase {
template <typename T> template <typename T>
class TaskPromise : public TaskPromiseBase { class TaskPromise : public TaskPromiseBase {
public: public:
static_assert(
!std::is_rvalue_reference_v<T>,
"Task<T&&> is not supported. "
"Consider using Task<T> or Task<std::unique_ptr<T>> instead.");
using StorageType = detail::lift_lvalue_reference_t<T>;
TaskPromise() noexcept = default; TaskPromise() noexcept = default;
Task<T> get_return_object() noexcept; Task<T> get_return_object() noexcept;
...@@ -110,27 +118,24 @@ class TaskPromise : public TaskPromiseBase { ...@@ -110,27 +118,24 @@ class TaskPromise : public TaskPromiseBase {
template <typename U> template <typename U>
void return_value(U&& value) { void return_value(U&& value) {
static_assert( static_assert(
std::is_convertible<U&&, T>::value, std::is_convertible<U&&, StorageType>::value,
"cannot convert return value to type T"); "cannot convert return value to type T");
result_.emplace(static_cast<U&&>(value)); result_.emplace(static_cast<U&&>(value));
} }
Try<T>& result() { Try<StorageType>& result() {
return result_; return result_;
} }
private: private:
using StorageType = std::conditional_t<
std::is_reference<T>::value,
std::reference_wrapper<std::remove_reference_t<T>>,
T>;
Try<StorageType> result_; Try<StorageType> result_;
}; };
template <> template <>
class TaskPromise<void> : public TaskPromiseBase { class TaskPromise<void> : public TaskPromiseBase {
public: public:
using StorageType = void;
TaskPromise() noexcept = default; TaskPromise() noexcept = default;
Task<void> get_return_object() noexcept; Task<void> get_return_object() noexcept;
...@@ -140,7 +145,9 @@ class TaskPromise<void> : public TaskPromiseBase { ...@@ -140,7 +145,9 @@ class TaskPromise<void> : public TaskPromiseBase {
exception_wrapper::from_exception_ptr(std::current_exception())); exception_wrapper::from_exception_ptr(std::current_exception()));
} }
void return_void() noexcept {} void return_void() noexcept {
result_.emplace();
}
Try<void>& result() { Try<void>& result() {
return result_; return result_;
...@@ -161,6 +168,7 @@ class TaskPromise<void> : public TaskPromiseBase { ...@@ -161,6 +168,7 @@ class TaskPromise<void> : public TaskPromiseBase {
template <typename T> template <typename T>
class FOLLY_NODISCARD TaskWithExecutor { class FOLLY_NODISCARD TaskWithExecutor {
using handle_t = std::experimental::coroutine_handle<detail::TaskPromise<T>>; using handle_t = std::experimental::coroutine_handle<detail::TaskPromise<T>>;
using StorageType = typename detail::TaskPromise<T>::StorageType;
public: public:
~TaskWithExecutor() { ~TaskWithExecutor() {
...@@ -188,12 +196,13 @@ class FOLLY_NODISCARD TaskWithExecutor { ...@@ -188,12 +196,13 @@ class FOLLY_NODISCARD TaskWithExecutor {
// Start execution of this task eagerly and return a folly::SemiFuture<T> // Start execution of this task eagerly and return a folly::SemiFuture<T>
// that will complete with the result. // that will complete with the result.
auto start() && { auto start() && {
Promise<lift_unit_t<T>> p; Promise<lift_unit_t<StorageType>> p;
auto sf = p.getSemiFuture(); auto sf = p.getSemiFuture();
std::move(*this).start([promise = std::move(p)](Try<T>&& result) mutable { std::move(*this).start(
promise.setTry(std::move(result)); [promise = std::move(p)](Try<StorageType>&& result) mutable {
}); promise.setTry(std::move(result));
});
return sf; return sf;
} }
...@@ -206,9 +215,9 @@ class FOLLY_NODISCARD TaskWithExecutor { ...@@ -206,9 +215,9 @@ class FOLLY_NODISCARD TaskWithExecutor {
try { try {
cb(co_await std::move(task).co_awaitTry()); cb(co_await std::move(task).co_awaitTry());
} catch (const std::exception& e) { } catch (const std::exception& e) {
cb(Try<T>(exception_wrapper(std::current_exception(), e))); cb(Try<StorageType>(exception_wrapper(std::current_exception(), e)));
} catch (...) { } catch (...) {
cb(Try<T>(exception_wrapper(std::current_exception()))); cb(Try<StorageType>(exception_wrapper(std::current_exception())));
} }
}(std::move(*this), std::forward<F>(tryCallback)); }(std::move(*this), std::forward<F>(tryCallback));
} }
...@@ -256,13 +265,13 @@ class FOLLY_NODISCARD TaskWithExecutor { ...@@ -256,13 +265,13 @@ class FOLLY_NODISCARD TaskWithExecutor {
}; };
struct ValueCreator { struct ValueCreator {
T operator()(Try<T>&& t) { T operator()(Try<StorageType>&& t) const {
return std::move(t).value(); return std::move(t).value();
} }
}; };
struct TryCreator { struct TryCreator {
Try<T> operator()(Try<T>&& t) { Try<StorageType> operator()(Try<StorageType>&& t) const {
return std::move(t); return std::move(t);
} }
}; };
...@@ -309,9 +318,12 @@ template <typename T> ...@@ -309,9 +318,12 @@ template <typename T>
class FOLLY_NODISCARD Task { class FOLLY_NODISCARD Task {
public: public:
using promise_type = detail::TaskPromise<T>; using promise_type = detail::TaskPromise<T>;
using StorageType = typename promise_type::StorageType;
private: private:
template <typename ResultCreator>
class Awaiter; class Awaiter;
class TrySemiAwaitable;
using handle_t = std::experimental::coroutine_handle<promise_type>; using handle_t = std::experimental::coroutine_handle<promise_type>;
public: public:
...@@ -344,25 +356,39 @@ class FOLLY_NODISCARD Task { ...@@ -344,25 +356,39 @@ class FOLLY_NODISCARD Task {
return TaskWithExecutor<T>{std::exchange(coro_, {})}; return TaskWithExecutor<T>{std::exchange(coro_, {})};
} }
SemiFuture<folly::lift_unit_t<T>> semi() && { SemiFuture<folly::lift_unit_t<StorageType>> semi() && {
return makeSemiFuture().defer( return makeSemiFuture().defer(
[task = std::move(*this)](Executor* executor, Try<Unit>&&) mutable { [task = std::move(*this)](Executor* executor, Try<Unit>&&) mutable {
return std::move(task).scheduleOn(executor).start(); return std::move(task).scheduleOn(executor).start();
}); });
} }
// Returns a SemiAwaitable<folly::Try<T>> type that when co_awaited will
// produce a Try<T> instead of the value T.
//
// eg.
// auto result = co_await std::move(someTask).co_awaitTry();
// if (result.hasValue()) {
// use(result.value());
// }
auto co_awaitTry() && noexcept {
return TrySemiAwaitable{std::exchange(coro_, {})};
}
friend auto co_viaIfAsync( friend auto co_viaIfAsync(
Executor::KeepAlive<> executor, Executor::KeepAlive<> executor,
Task<T>&& t) noexcept { Task<T>&& t) noexcept {
// Child task inherits the awaiting task's executor // Child task inherits the awaiting task's executor
t.coro_.promise().executor_ = std::move(executor); t.coro_.promise().executor_ = std::move(executor);
return Awaiter{std::exchange(t.coro_, {})}; return Awaiter<typename TaskWithExecutor<T>::ValueCreator>{
std::exchange(t.coro_, {})};
} }
private: private:
friend class detail::TaskPromiseBase; friend class detail::TaskPromiseBase;
friend class detail::TaskPromise<T>; friend class detail::TaskPromise<T>;
template <typename ResultCreator>
class Awaiter { class Awaiter {
public: public:
explicit Awaiter(handle_t coro) noexcept : coro_(coro) {} explicit Awaiter(handle_t coro) noexcept : coro_(coro) {}
...@@ -387,14 +413,49 @@ class FOLLY_NODISCARD Task { ...@@ -387,14 +413,49 @@ class FOLLY_NODISCARD Task {
return coro_; return coro_;
} }
T await_resume() { decltype(auto) await_resume() {
SCOPE_EXIT { SCOPE_EXIT {
std::exchange(coro_, {}).destroy(); std::exchange(coro_, {}).destroy();
}; };
return std::move(coro_.promise().result()).value(); return ResultCreator{}(std::move(coro_.promise().result()));
}
private:
handle_t coro_;
};
class TrySemiAwaitable {
public:
TrySemiAwaitable(TrySemiAwaitable&& other) noexcept
: coro_(std::exchange(other.coro_, {})) {}
~TrySemiAwaitable() {
if (coro_) {
coro_.destroy();
}
}
TrySemiAwaitable& operator=(TrySemiAwaitable&& other) noexcept {
auto oldCoro = std::exchange(coro_, std::exchange(other.coro_, {}));
if (oldCoro) {
oldCoro.destroy();
}
return *this;
}
friend auto co_viaIfAsync(
Executor::KeepAlive<> executor,
TrySemiAwaitable&& awaitable) noexcept {
awaitable.coro_.promise().executor_ = std::move(executor);
return Awaiter<typename TaskWithExecutor<T>::TryCreator>{
std::exchange(awaitable.coro_, {})};
} }
private: private:
friend Task<T>;
explicit TrySemiAwaitable(handle_t coro) noexcept : coro_(coro) {}
handle_t coro_; handle_t coro_;
}; };
......
...@@ -356,4 +356,27 @@ checkAwaitingFutureOfUnitDoesntWarnAboutDiscardedResult() { ...@@ -356,4 +356,27 @@ checkAwaitingFutureOfUnitDoesntWarnAboutDiscardedResult() {
co_await folly::futures::sleep(1ms); co_await folly::futures::sleep(1ms);
} }
folly::coro::Task<int&> returnIntRef(int& value) {
co_return value;
}
TEST(Task, TaskOfLvalueReference) {
int value = 123;
auto&& result = folly::coro::blockingWait(returnIntRef(value));
static_assert(std::is_same_v<decltype(result), int&>);
CHECK_EQ(&value, &result);
}
TEST(Task, TaskOfLvalueReferenceAsTry) {
folly::coro::blockingWait([]() -> folly::coro::Task<void> {
int value = 123;
auto&& result = co_await returnIntRef(value).co_awaitTry();
CHECK(result.hasValue());
CHECK_EQ(&value, &result.value().get());
int& valueRef = co_await returnIntRef(value);
CHECK_EQ(&value, &valueRef);
}());
}
#endif #endif
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment