Commit 7a111f40 authored by Andrii Grynenko's avatar Andrii Grynenko Committed by Facebook Github Bot

Fix collectAllSemiFuture to work for SemiFutures with deferred work

Summary: We have to extend DeferredExecutor to be aware of nested DeferredExecutor, so that it can propagate various signals to them.

Reviewed By: LeeHowes

Differential Revision: D8600961

fbshipit-source-id: 2b9520598cad65bd80f0346e6f76e224903780a1
parent 3bb85dae
...@@ -540,6 +540,11 @@ class DeferredExecutor final : public Executor { ...@@ -540,6 +540,11 @@ class DeferredExecutor final : public Executor {
} }
void setExecutor(folly::Executor* executor) { void setExecutor(folly::Executor* executor) {
if (nestedExecutors_) {
for (auto nestedExecutor : *nestedExecutors_) {
nestedExecutor->setExecutor(executor);
}
}
executor_ = executor; executor_ = executor;
auto state = state_.load(std::memory_order_acquire); auto state = state_.load(std::memory_order_acquire);
do { do {
...@@ -562,6 +567,11 @@ class DeferredExecutor final : public Executor { ...@@ -562,6 +567,11 @@ class DeferredExecutor final : public Executor {
} }
void detach() { void detach() {
if (nestedExecutors_) {
for (auto nestedExecutor : *nestedExecutors_) {
nestedExecutor->detach();
}
}
auto state = state_.load(std::memory_order_acquire); auto state = state_.load(std::memory_order_acquire);
do { do {
if (state == State::HAS_FUNCTION) { if (state == State::HAS_FUNCTION) {
...@@ -581,6 +591,13 @@ class DeferredExecutor final : public Executor { ...@@ -581,6 +591,13 @@ class DeferredExecutor final : public Executor {
} }
void wait() { void wait() {
if (nestedExecutors_) {
for (auto nestedExecutor : *nestedExecutors_) {
nestedExecutor->wait();
nestedExecutor->runAndDestroy();
}
return;
}
auto state = state_.load(std::memory_order_acquire); auto state = state_.load(std::memory_order_acquire);
auto baton = std::make_shared<FutureBatonType>(); auto baton = std::make_shared<FutureBatonType>();
baton_ = baton; baton_ = baton;
...@@ -600,7 +617,24 @@ class DeferredExecutor final : public Executor { ...@@ -600,7 +617,24 @@ class DeferredExecutor final : public Executor {
assert(state_.load(std::memory_order_relaxed) == State::HAS_FUNCTION); assert(state_.load(std::memory_order_relaxed) == State::HAS_FUNCTION);
} }
using Clock = std::chrono::steady_clock;
bool wait(Duration duration) { bool wait(Duration duration) {
return wait_until(Clock::now() + duration);
}
bool wait_until(Clock::time_point deadline) {
if (nestedExecutors_) {
for (auto nestedExecutor : *nestedExecutors_) {
if (!nestedExecutor->wait_until(deadline)) {
return false;
}
}
for (auto nestedExecutor : *nestedExecutors_) {
nestedExecutor->runAndDestroy();
}
}
auto state = state_.load(std::memory_order_acquire); auto state = state_.load(std::memory_order_acquire);
auto baton = std::make_shared<FutureBatonType>(); auto baton = std::make_shared<FutureBatonType>();
baton_ = baton; baton_ = baton;
...@@ -615,7 +649,7 @@ class DeferredExecutor final : public Executor { ...@@ -615,7 +649,7 @@ class DeferredExecutor final : public Executor {
std::memory_order_release, std::memory_order_release,
std::memory_order_acquire)); std::memory_order_acquire));
if (baton->try_wait_for(duration)) { if (baton->try_wait_until(deadline)) {
assert(state_.load(std::memory_order_relaxed) == State::HAS_FUNCTION); assert(state_.load(std::memory_order_relaxed) == State::HAS_FUNCTION);
return true; return true;
} }
...@@ -634,6 +668,12 @@ class DeferredExecutor final : public Executor { ...@@ -634,6 +668,12 @@ class DeferredExecutor final : public Executor {
return false; return false;
} }
void setNestedExecutors(std::vector<DeferredExecutor*> executors) {
DCHECK(!nestedExecutors_);
nestedExecutors_ =
std::make_unique<std::vector<DeferredExecutor*>>(executors);
}
private: private:
enum class State { enum class State {
EMPTY, EMPTY,
...@@ -646,6 +686,7 @@ class DeferredExecutor final : public Executor { ...@@ -646,6 +686,7 @@ class DeferredExecutor final : public Executor {
Func func_; Func func_;
Executor* executor_; Executor* executor_;
folly::Synchronized<std::shared_ptr<FutureBatonType>> baton_; folly::Synchronized<std::shared_ptr<FutureBatonType>> baton_;
std::unique_ptr<std::vector<DeferredExecutor*>> nestedExecutors_;
}; };
// Vector-like structure to play with window, // Vector-like structure to play with window,
...@@ -747,6 +788,17 @@ typename SemiFuture<T>::DeferredExecutor* SemiFuture<T>::getDeferredExecutor() ...@@ -747,6 +788,17 @@ typename SemiFuture<T>::DeferredExecutor* SemiFuture<T>::getDeferredExecutor()
return nullptr; return nullptr;
} }
template <class T>
typename SemiFuture<T>::DeferredExecutor* SemiFuture<T>::stealDeferredExecutor()
const {
if (auto executor = this->getExecutor()) {
assert(dynamic_cast<DeferredExecutor*>(executor) != nullptr);
this->core_->setExecutor(nullptr);
return static_cast<DeferredExecutor*>(executor);
}
return nullptr;
}
template <class T> template <class T>
void SemiFuture<T>::releaseDeferredExecutor(Core* core) { void SemiFuture<T>::releaseDeferredExecutor(Core* core) {
if (!core) { if (!core) {
...@@ -1312,6 +1364,45 @@ FOLLY_ALWAYS_INLINE FOLLY_ATTR_VISIBILITY_HIDDEN void foreach( ...@@ -1312,6 +1364,45 @@ FOLLY_ALWAYS_INLINE FOLLY_ATTR_VISIBILITY_HIDDEN void foreach(
foreach_(_{}, static_cast<V&&>(v), static_cast<Fs&&>(fs)...); foreach_(_{}, static_cast<V&&>(v), static_cast<Fs&&>(fs)...);
} }
template <typename T>
DeferredExecutor* getDeferredExecutor(SemiFuture<T>& future) {
return future.getDeferredExecutor();
}
template <typename T>
DeferredExecutor* stealDeferredExecutor(SemiFuture<T>& future) {
return future.stealDeferredExecutor();
}
template <typename T>
DeferredExecutor* stealDeferredExecutor(Future<T>&) {
return nullptr;
}
template <typename... Ts>
void stealDeferredExecutorsVariadic(
std::vector<DeferredExecutor*>& executors,
Ts&... ts) {
auto foreach = [&](auto& future) {
if (auto executor = stealDeferredExecutor(future)) {
executors.push_back(executor);
}
return folly::unit;
};
[](...) {}(foreach(ts)...);
}
template <class InputIterator>
void stealDeferredExecutors(
std::vector<DeferredExecutor*>& executors,
InputIterator first,
InputIterator last) {
for (auto it = first; it != last; ++it) {
if (auto executor = stealDeferredExecutor(*it)) {
executors.push_back(executor);
}
}
}
} // namespace detail } // namespace detail
} // namespace futures } // namespace futures
...@@ -1329,6 +1420,9 @@ collectAllSemiFuture(Fs&&... fs) { ...@@ -1329,6 +1420,9 @@ collectAllSemiFuture(Fs&&... fs) {
Result results; Result results;
}; };
std::vector<futures::detail::DeferredExecutor*> executors;
futures::detail::stealDeferredExecutorsVariadic(executors, fs...);
auto ctx = std::make_shared<Context>(); auto ctx = std::make_shared<Context>();
futures::detail::foreach( futures::detail::foreach(
[&](auto i, auto&& f) { [&](auto i, auto&& f) {
...@@ -1337,7 +1431,17 @@ collectAllSemiFuture(Fs&&... fs) { ...@@ -1337,7 +1431,17 @@ collectAllSemiFuture(Fs&&... fs) {
}); });
}, },
static_cast<Fs&&>(fs)...); static_cast<Fs&&>(fs)...);
return ctx->p.getSemiFuture();
auto future = ctx->p.getSemiFuture();
if (!executors.empty()) {
future = std::move(future).defer(
[](Try<typename decltype(future)::value_type>&& t) {
return std::move(t).value();
});
auto deferredExecutor = futures::detail::getDeferredExecutor(future);
deferredExecutor->setNestedExecutors(std::move(executors));
}
return future;
} }
template <typename... Fs> template <typename... Fs>
...@@ -1364,12 +1468,26 @@ collectAllSemiFuture(InputIterator first, InputIterator last) { ...@@ -1364,12 +1468,26 @@ collectAllSemiFuture(InputIterator first, InputIterator last) {
std::vector<Try<T>> results; std::vector<Try<T>> results;
}; };
std::vector<futures::detail::DeferredExecutor*> executors;
futures::detail::stealDeferredExecutors(executors, first, last);
auto ctx = std::make_shared<Context>(size_t(std::distance(first, last))); auto ctx = std::make_shared<Context>(size_t(std::distance(first, last)));
for (size_t i = 0; first != last; ++first, ++i) { for (size_t i = 0; first != last; ++first, ++i) {
first->setCallback_( first->setCallback_(
[i, ctx](Try<T>&& t) { ctx->results[i] = std::move(t); }); [i, ctx](Try<T>&& t) { ctx->results[i] = std::move(t); });
} }
return ctx->p.getSemiFuture();
auto future = ctx->p.getSemiFuture();
if (!executors.empty()) {
future = std::move(future).defer(
[](Try<typename decltype(future)::value_type>&& t) {
return std::move(t).value();
});
auto deferredExecutor = futures::detail::getDeferredExecutor(future);
deferredExecutor->setNestedExecutors(std::move(executors));
}
return future;
} }
template <class InputIterator> template <class InputIterator>
......
...@@ -434,6 +434,14 @@ class FutureBase { ...@@ -434,6 +434,14 @@ class FutureBase {
}; };
template <class T> template <class T>
void convertFuture(SemiFuture<T>&& sf, Future<T>& f); void convertFuture(SemiFuture<T>&& sf, Future<T>& f);
class DeferredExecutor;
template <typename T>
DeferredExecutor* getDeferredExecutor(SemiFuture<T>& future);
template <typename T>
DeferredExecutor* stealDeferredExecutor(SemiFuture<T>& future);
} // namespace detail } // namespace detail
} // namespace futures } // namespace futures
...@@ -878,6 +886,9 @@ class SemiFuture : private futures::detail::FutureBase<T> { ...@@ -878,6 +886,9 @@ class SemiFuture : private futures::detail::FutureBase<T> {
friend class SemiFuture; friend class SemiFuture;
template <class> template <class>
friend class Future; friend class Future;
friend DeferredExecutor* futures::detail::stealDeferredExecutor<T>(
SemiFuture&);
friend DeferredExecutor* futures::detail::getDeferredExecutor<T>(SemiFuture&);
using Base::setExecutor; using Base::setExecutor;
using Base::throwIfInvalid; using Base::throwIfInvalid;
...@@ -894,6 +905,9 @@ class SemiFuture : private futures::detail::FutureBase<T> { ...@@ -894,6 +905,9 @@ class SemiFuture : private futures::detail::FutureBase<T> {
// Throws FutureInvalid if !this->core_ // Throws FutureInvalid if !this->core_
DeferredExecutor* getDeferredExecutor() const; DeferredExecutor* getDeferredExecutor() const;
// Throws FutureInvalid if !this->core_
DeferredExecutor* stealDeferredExecutor() const;
static void releaseDeferredExecutor(Core* core); static void releaseDeferredExecutor(Core* core);
}; };
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <folly/Memory.h> #include <folly/Memory.h>
#include <folly/Unit.h> #include <folly/Unit.h>
#include <folly/dynamic.h> #include <folly/dynamic.h>
#include <folly/executors/ManualExecutor.h>
#include <folly/futures/Future.h> #include <folly/futures/Future.h>
#include <folly/io/async/EventBase.h> #include <folly/io/async/EventBase.h>
#include <folly/portability/GTest.h> #include <folly/portability/GTest.h>
...@@ -1149,3 +1150,77 @@ TEST(SemiFuture, semiFutureWithinNoValueReferenceWhenTimeOut) { ...@@ -1149,3 +1150,77 @@ TEST(SemiFuture, semiFutureWithinNoValueReferenceWhenTimeOut) {
EXPECT_EQ(0, callbackInput.value().use_count()); EXPECT_EQ(0, callbackInput.value().use_count());
}); });
} }
TEST(SemiFuture, collectAllSemiFutureDeferredWork) {
{
Promise<int> promise1;
Promise<int> promise2;
auto future = collectAllSemiFuture(
promise1.getSemiFuture().deferValue([](int x) { return x * 2; }),
promise2.getSemiFuture().deferValue([](int x) { return x * 2; }));
promise1.setValue(1);
promise2.setValue(2);
EXPECT_TRUE(future.wait(std::chrono::milliseconds{100}).isReady());
auto value = std::move(future).get();
EXPECT_EQ(2, *std::get<0>(value));
EXPECT_EQ(4, *std::get<1>(value));
}
{
Promise<int> promise1;
Promise<int> promise2;
auto future = collectAllSemiFuture(
promise1.getSemiFuture().deferValue([](int x) { return x * 2; }),
promise2.getSemiFuture().deferValue([](int x) { return x * 2; }));
promise1.setValue(1);
promise2.setValue(2);
ManualExecutor executor;
auto value = std::move(future).via(&executor).getVia(&executor);
EXPECT_EQ(2, *std::get<0>(value));
EXPECT_EQ(4, *std::get<1>(value));
}
{
Promise<int> promise1;
Promise<int> promise2;
std::vector<SemiFuture<int>> futures;
futures.push_back(
promise1.getSemiFuture().deferValue([](int x) { return x * 2; }));
futures.push_back(
promise2.getSemiFuture().deferValue([](int x) { return x * 2; }));
auto future = collectAllSemiFuture(futures);
promise1.setValue(1);
promise2.setValue(2);
EXPECT_TRUE(future.wait().isReady());
auto value = std::move(future).get();
EXPECT_EQ(2, *value[0]);
EXPECT_EQ(4, *value[1]);
}
{
bool deferredDestroyed = false;
{
Promise<int> promise;
auto guard = makeGuard([&] { deferredDestroyed = true; });
collectAllSemiFuture(promise.getSemiFuture().deferValue(
[guard = std::move(guard)](int x) { return x; }));
}
EXPECT_TRUE(deferredDestroyed);
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment