Commit d0494141 authored by Andrii Grynenko's avatar Andrii Grynenko Committed by Facebook Github Bot

Fix collectSemiFuture and collectN to work with SemiFutures with deferred work

Reviewed By: yfeldblum

Differential Revision: D14719902

fbshipit-source-id: d3bc7b61d666c796c3c19f7fa198252acb472970
parent 869662db
...@@ -1594,6 +1594,10 @@ collectSemiFuture(InputIterator first, InputIterator last) { ...@@ -1594,6 +1594,10 @@ collectSemiFuture(InputIterator first, InputIterator last) {
std::atomic<bool> threw{false}; std::atomic<bool> threw{false};
}; };
std::vector<folly::Executor::KeepAlive<futures::detail::DeferredExecutor>>
executors;
futures::detail::stealDeferredExecutors(executors, first, last);
auto ctx = std::make_shared<Context>(std::distance(first, last)); auto ctx = std::make_shared<Context>(std::distance(first, last));
for (size_t i = 0; first != last; ++first, ++i) { for (size_t i = 0; first != last; ++first, ++i) {
first->setCallback_([i, ctx](Try<T>&& t) { first->setCallback_([i, ctx](Try<T>&& t) {
...@@ -1606,7 +1610,17 @@ collectSemiFuture(InputIterator first, InputIterator last) { ...@@ -1606,7 +1610,17 @@ collectSemiFuture(InputIterator first, InputIterator last) {
} }
}); });
} }
return ctx->p.getSemiFuture();
auto future = ctx->p.getSemiFuture();
if (!executors.empty()) {
auto work = [](Try<typename decltype(future)::value_type>&& t) {
return std::move(t).value();
};
future = std::move(future).defer(work);
auto deferredExecutor = futures::detail::getDeferredExecutor(future);
deferredExecutor->setNestedExecutors(std::move(executors));
}
return future;
} }
template <class InputIterator> template <class InputIterator>
...@@ -1634,6 +1648,10 @@ collectSemiFuture(Fs&&... fs) { ...@@ -1634,6 +1648,10 @@ collectSemiFuture(Fs&&... fs) {
std::atomic<bool> threw{false}; std::atomic<bool> threw{false};
}; };
std::vector<folly::Executor::KeepAlive<futures::detail::DeferredExecutor>>
executors;
futures::detail::stealDeferredExecutorsVariadic(executors, fs...);
auto ctx = std::make_shared<Context>(); auto ctx = std::make_shared<Context>();
futures::detail::foreach( futures::detail::foreach(
[&](auto i, auto&& f) { [&](auto i, auto&& f) {
...@@ -1648,7 +1666,17 @@ collectSemiFuture(Fs&&... fs) { ...@@ -1648,7 +1666,17 @@ collectSemiFuture(Fs&&... fs) {
}); });
}, },
static_cast<Fs&&>(fs)...); static_cast<Fs&&>(fs)...);
return ctx->p.getSemiFuture();
auto future = ctx->p.getSemiFuture();
if (!executors.empty()) {
auto work = [](Try<typename decltype(future)::value_type>&& t) {
return std::move(t).value();
};
future = std::move(future).defer(work);
auto deferredExecutor = futures::detail::getDeferredExecutor(future);
deferredExecutor->setNestedExecutors(std::move(executors));
}
return future;
} }
template <typename... Fs> template <typename... Fs>
...@@ -1657,6 +1685,12 @@ Future<std::tuple<typename remove_cvref_t<Fs>::value_type...>> collect( ...@@ -1657,6 +1685,12 @@ Future<std::tuple<typename remove_cvref_t<Fs>::value_type...>> collect(
return collectSemiFuture(std::forward<Fs>(fs)...).toUnsafeFuture(); return collectSemiFuture(std::forward<Fs>(fs)...).toUnsafeFuture();
} }
template <class Collection>
auto collectSemiFuture(Collection&& c)
-> decltype(collectSemiFuture(c.begin(), c.end())) {
return collectSemiFuture(c.begin(), c.end());
}
// collectAny (iterator) // collectAny (iterator)
// TODO(T26439406): Make return SemiFuture // TODO(T26439406): Make return SemiFuture
...@@ -1762,6 +1796,10 @@ collectN(InputIterator first, InputIterator last, size_t n) { ...@@ -1762,6 +1796,10 @@ collectN(InputIterator first, InputIterator last, size_t n) {
exception_wrapper(std::runtime_error("Not enough futures"))); exception_wrapper(std::runtime_error("Not enough futures")));
} }
std::vector<folly::Executor::KeepAlive<futures::detail::DeferredExecutor>>
executors;
futures::detail::stealDeferredExecutors(executors, first, last);
// for each completed Future, increase count and add to vector, until we // for each completed Future, increase count and add to vector, until we
// have n completed futures at which point we fulfil our Promise with the // have n completed futures at which point we fulfil our Promise with the
// vector // vector
...@@ -1793,7 +1831,16 @@ collectN(InputIterator first, InputIterator last, size_t n) { ...@@ -1793,7 +1831,16 @@ collectN(InputIterator first, InputIterator last, size_t n) {
}); });
} }
return ctx->p.getSemiFuture(); auto future = ctx->p.getSemiFuture();
if (!executors.empty()) {
future = std::move(future).defer(
[](Try<typename decltype(future)::value_type>&& t) {
return std::move(t).value();
});
auto deferredExecutor = futures::detail::getDeferredExecutor(future);
deferredExecutor->setNestedExecutors(std::move(executors));
}
return future;
} }
// reduce (iterator) // reduce (iterator)
......
...@@ -1058,6 +1058,106 @@ TEST(SemiFuture, collectAllSemiFutureDeferredWork) { ...@@ -1058,6 +1058,106 @@ TEST(SemiFuture, collectAllSemiFutureDeferredWork) {
} }
} }
TEST(SemiFuture, collectSemiFutureDeferredWork) {
{
Promise<int> promise1;
Promise<int> promise2;
auto future = collectSemiFuture(
promise1.getSemiFuture().deferValue([](int x) { return x * 2; }),
promise2.getSemiFuture().deferValue([](int x) { return x * 2; }));
promise1.setValue(1);
promise2.setValue(2);
auto result = std::move(future).getTry(std::chrono::milliseconds{100});
EXPECT_TRUE(result.hasValue());
EXPECT_EQ(2, std::get<0>(*result));
EXPECT_EQ(4, std::get<1>(*result));
}
{
Promise<int> promise1;
Promise<int> promise2;
auto future = collectSemiFuture(
promise1.getSemiFuture().deferValue([](int x) { return x * 2; }),
promise2.getSemiFuture().deferValue([](int x) { return x * 2; }));
promise1.setValue(1);
promise2.setValue(2);
ManualExecutor executor;
auto value = std::move(future).via(&executor).getVia(&executor);
EXPECT_EQ(2, std::get<0>(value));
EXPECT_EQ(4, std::get<1>(value));
}
{
Promise<int> promise1;
Promise<int> promise2;
std::vector<SemiFuture<int>> futures;
futures.push_back(
promise1.getSemiFuture().deferValue([](int x) { return x * 2; }));
futures.push_back(
promise2.getSemiFuture().deferValue([](int x) { return x * 2; }));
auto future = collectSemiFuture(futures);
promise1.setValue(1);
promise2.setValue(2);
EXPECT_TRUE(future.wait().isReady());
auto value = std::move(future).get();
EXPECT_EQ(2, value[0]);
EXPECT_EQ(4, value[1]);
}
{
bool deferredDestroyed = false;
{
Promise<int> promise;
auto guard = makeGuard([&] { deferredDestroyed = true; });
collectSemiFuture(promise.getSemiFuture().deferValue(
[guard = std::move(guard)](int x) { return x; }));
}
EXPECT_TRUE(deferredDestroyed);
}
}
TEST(SemiFuture, collectNDeferredWork) {
Promise<int> promise1;
Promise<int> promise2;
Promise<int> promise3;
std::vector<SemiFuture<int>> futures;
futures.push_back(
promise1.getSemiFuture().deferValue([](int x) { return x * 2; }));
futures.push_back(
promise2.getSemiFuture().deferValue([](int x) { return x * 2; }));
futures.push_back(
promise3.getSemiFuture().deferValue([](int x) { return x * 2; }));
auto future = collectN(std::move(futures), 2);
promise1.setValue(1);
promise3.setValue(3);
EXPECT_TRUE(future.wait().isReady());
auto value = std::move(future).get();
EXPECT_EQ(2, *value[0].second);
EXPECT_EQ(6, *value[1].second);
}
TEST(SemiFuture, DeferWithNestedSemiFuture) { TEST(SemiFuture, DeferWithNestedSemiFuture) {
auto start = std::chrono::steady_clock::now(); auto start = std::chrono::steady_clock::now();
auto future = futures::sleep(std::chrono::milliseconds{100}) auto future = futures::sleep(std::chrono::milliseconds{100})
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment