Commit 4a92da5b authored by Hannes Roth's avatar Hannes Roth Committed by Praveen Kumar Ramakrishnan

(Wangle) Implement collect* using mapSetCallback and shared_ptrs

Summary:
I figured it would make sense to implement all the collect* functions using a shared_ptr<Context>, instead of doing our manual reference counting and all that. Fulfilling the promise in the destructor seemed like the icing on the cake. Also saves some line of code.

Test Plan: Run all the tests.

Reviewed By: hans@fb.com

Subscribers: folly-diffs@, jsedgwick, yfeldblum, chalfant

FB internal diff: D2015320

Signature: t1:2015320:1431106133:ac3001b3696fc75230afe70908ed349102b02a45
parent 5cc2f994
......@@ -531,22 +531,31 @@ inline Future<void> via(Executor* executor) {
return makeFuture().via(executor);
}
// when (variadic)
// mapSetCallback calls func(i, Try<T>) when every future completes
template <class T, class InputIterator, class F>
void mapSetCallback(InputIterator first, InputIterator last, F func) {
for (size_t i = 0; first != last; ++first, ++i) {
first->setCallback_([func, i](Try<T>&& t) {
func(i, std::move(t));
});
}
}
// collectAll (variadic)
template <typename... Fs>
typename detail::VariadicContext<
typename std::decay<Fs>::type::value_type...>::type
collectAll(Fs&&... fs) {
auto ctx =
new detail::VariadicContext<typename std::decay<Fs>::type::value_type...>();
ctx->total = sizeof...(fs);
auto f_saved = ctx->p.getFuture();
auto ctx = std::make_shared<detail::VariadicContext<
typename std::decay<Fs>::type::value_type...>>();
detail::collectAllVariadicHelper(ctx,
std::forward<typename std::decay<Fs>::type>(fs)...);
return f_saved;
return ctx->p.getFuture();
}
// when (iterator)
// collectAll (iterator)
template <class InputIterator>
Future<
......@@ -556,155 +565,87 @@ collectAll(InputIterator first, InputIterator last) {
typedef
typename std::iterator_traits<InputIterator>::value_type::value_type T;
if (first >= last) {
return makeFuture(std::vector<Try<T>>());
}
size_t n = std::distance(first, last);
auto ctx = new detail::WhenAllContext<T>();
ctx->results.resize(n);
auto f_saved = ctx->p.getFuture();
for (size_t i = 0; first != last; ++first, ++i) {
assert(i < n);
auto& f = *first;
f.setCallback_([ctx, i, n](Try<T> t) {
ctx->results[i] = std::move(t);
if (++ctx->count == n) {
ctx->p.setValue(std::move(ctx->results));
delete ctx;
}
});
}
struct CollectAllContext {
CollectAllContext(int n) : results(n) {}
~CollectAllContext() {
p.setValue(std::move(results));
}
Promise<std::vector<Try<T>>> p;
std::vector<Try<T>> results;
};
return f_saved;
auto ctx = std::make_shared<CollectAllContext>(std::distance(first, last));
mapSetCallback<T>(first, last, [ctx](size_t i, Try<T>&& t) {
ctx->results[i] = std::move(t);
});
return ctx->p.getFuture();
}
namespace detail {
template <class, class, typename = void> struct CollectContextHelper;
template <class T, class VecT>
struct CollectContextHelper<T, VecT,
typename std::enable_if<std::is_same<T, VecT>::value>::type> {
static inline std::vector<T>&& getResults(std::vector<VecT>& results) {
return std::move(results);
}
};
template <class T, class VecT>
struct CollectContextHelper<T, VecT,
typename std::enable_if<!std::is_same<T, VecT>::value>::type> {
static inline std::vector<T> getResults(std::vector<VecT>& results) {
std::vector<T> finalResults;
finalResults.reserve(results.size());
for (auto& opt : results) {
finalResults.push_back(std::move(opt.value()));
}
return finalResults;
}
};
template <typename T>
struct CollectContext {
typedef typename std::conditional<
std::is_default_constructible<T>::value,
T,
Optional<T>
>::type VecT;
explicit CollectContext(int n) : count(0), success_count(0), threw(false) {
results.resize(n);
}
Promise<std::vector<T>> p;
std::vector<VecT> results;
std::atomic<size_t> count, success_count;
std::atomic_bool threw;
typedef std::vector<T> result_type;
static inline Future<std::vector<T>> makeEmptyFuture() {
return makeFuture(std::vector<T>());
}
inline void setValue() {
p.setValue(CollectContextHelper<T, VecT>::getResults(results));
struct Nothing { explicit Nothing(int n) {} };
using Result = typename std::conditional<
std::is_void<T>::value,
void,
std::vector<T>>::type;
using InternalResult = typename std::conditional<
std::is_void<T>::value,
Nothing,
std::vector<Optional<T>>>::type;
explicit CollectContext(int n) : result(n) {}
~CollectContext() {
if (!threw.exchange(true)) {
// map Optional<T> -> T
std::vector<T> finalResult;
finalResult.reserve(result.size());
std::transform(result.begin(), result.end(),
std::back_inserter(finalResult),
[](Optional<T>& o) { return std::move(o.value()); });
p.setValue(std::move(finalResult));
}
}
inline void addResult(int i, Try<T>& t) {
results[i] = std::move(t.value());
inline void setPartialResult(size_t i, Try<T>& t) {
result[i] = std::move(t.value());
}
Promise<Result> p;
InternalResult result;
std::atomic<bool> threw;
};
template <>
struct CollectContext<void> {
explicit CollectContext(int n) : count(0), success_count(0), threw(false) {}
// Specialize for void (implementations in Future.cpp)
Promise<void> p;
std::atomic<size_t> count, success_count;
std::atomic_bool threw;
typedef void result_type;
static inline Future<void> makeEmptyFuture() {
return makeFuture();
}
inline void setValue() {
p.setValue();
}
template <>
CollectContext<void>::~CollectContext();
inline void addResult(int i, Try<void>& t) {
// do nothing
}
};
template <>
void CollectContext<void>::setPartialResult(size_t i, Try<void>& t);
} // detail
}
template <class InputIterator>
Future<typename detail::CollectContext<
typename std::iterator_traits<InputIterator>::value_type::value_type
>::result_type>
typename std::iterator_traits<InputIterator>::value_type::value_type>::Result>
collect(InputIterator first, InputIterator last) {
typedef
typename std::iterator_traits<InputIterator>::value_type::value_type T;
if (first >= last) {
return detail::CollectContext<T>::makeEmptyFuture();
}
size_t n = std::distance(first, last);
auto ctx = new detail::CollectContext<T>(n);
auto f_saved = ctx->p.getFuture();
for (size_t i = 0; first != last; ++first, ++i) {
assert(i < n);
auto& f = *first;
f.setCallback_([ctx, i, n](Try<T> t) {
if (t.hasException()) {
if (!ctx->threw.exchange(true)) {
ctx->p.setException(std::move(t.exception()));
}
} else if (!ctx->threw) {
ctx->addResult(i, t);
if (++ctx->success_count == n) {
ctx->setValue();
}
}
if (++ctx->count == n) {
delete ctx;
auto ctx = std::make_shared<detail::CollectContext<T>>(
std::distance(first, last));
mapSetCallback<T>(first, last, [ctx](size_t i, Try<T>&& t) {
if (t.hasException()) {
if (!ctx->threw.exchange(true)) {
ctx->p.setException(std::move(t.exception()));
}
});
}
return f_saved;
} else if (!ctx->threw) {
ctx->setPartialResult(i, t);
}
});
return ctx->p.getFuture();
}
template <class InputIterator>
......@@ -712,25 +653,24 @@ Future<
std::pair<size_t,
Try<
typename
std::iterator_traits<InputIterator>::value_type::value_type> > >
std::iterator_traits<InputIterator>::value_type::value_type>>>
collectAny(InputIterator first, InputIterator last) {
typedef
typename std::iterator_traits<InputIterator>::value_type::value_type T;
auto ctx = new detail::WhenAnyContext<T>(std::distance(first, last));
auto f_saved = ctx->p.getFuture();
for (size_t i = 0; first != last; first++, i++) {
auto& f = *first;
f.setCallback_([i, ctx](Try<T>&& t) {
if (!ctx->done.exchange(true)) {
ctx->p.setValue(std::make_pair(i, std::move(t)));
}
ctx->decref();
});
}
struct CollectAnyContext {
CollectAnyContext(size_t n) : done(false) {};
Promise<std::pair<size_t, Try<T>>> p;
std::atomic<bool> done;
};
return f_saved;
auto ctx = std::make_shared<CollectAnyContext>(std::distance(first, last));
mapSetCallback<T>(first, last, [ctx](size_t i, Try<T>&& t) {
if (!ctx->done.exchange(true)) {
ctx->p.setValue(std::make_pair(i, std::move(t)));
}
});
return ctx->p.getFuture();
}
template <class InputIterator>
......@@ -741,38 +681,29 @@ collectN(InputIterator first, InputIterator last, size_t n) {
std::iterator_traits<InputIterator>::value_type::value_type T;
typedef std::vector<std::pair<size_t, Try<T>>> V;
struct ctx_t {
struct CollectNContext {
V v;
size_t completed;
std::atomic<size_t> completed = {0};
Promise<V> p;
};
auto ctx = std::make_shared<ctx_t>();
ctx->completed = 0;
// for each completed Future, increase count and add to vector, until we
// have n completed futures at which point we fulfill our Promise with the
// vector
auto it = first;
size_t i = 0;
while (it != last) {
it->then([ctx, n, i](Try<T>&& t) {
auto& v = ctx->v;
auto ctx = std::make_shared<CollectNContext>();
if (std::distance(first, last) < n) {
ctx->p.setException(std::runtime_error("Not enough futures"));
} else {
// for each completed Future, increase count and add to vector, until we
// have n completed futures at which point we fulfil our Promise with the
// vector
mapSetCallback<T>(first, last, [ctx, n](size_t i, Try<T>&& t) {
auto c = ++ctx->completed;
if (c <= n) {
assert(ctx->v.size() < n);
v.push_back(std::make_pair(i, std::move(t)));
ctx->v.push_back(std::make_pair(i, std::move(t)));
if (c == n) {
ctx->p.setTry(Try<V>(std::move(v)));
ctx->p.setTry(Try<V>(std::move(ctx->v)));
}
}
});
it++;
i++;
}
if (i < n) {
ctx->p.setException(std::runtime_error("Not enough futures"));
}
return ctx->p.getFuture();
......
......@@ -39,3 +39,19 @@ Future<void> sleep(Duration dur, Timekeeper* tk) {
}
}}
namespace folly { namespace detail {
template <>
CollectContext<void>::~CollectContext() {
if (!threw.exchange(true)) {
p.setValue();
}
}
template <>
void CollectContext<void>::setPartialResult(size_t i, Try<void>& t) {
// Nothing to do for void
}
}}
......@@ -319,59 +319,33 @@ class Core {
template <typename... Ts>
struct VariadicContext {
VariadicContext() : total(0), count(0) {}
Promise<std::tuple<Try<Ts>... > > p;
VariadicContext() {}
~VariadicContext() {
p.setValue(std::move(results));
}
Promise<std::tuple<Try<Ts>... >> p;
std::tuple<Try<Ts>... > results;
size_t total;
std::atomic<size_t> count;
typedef Future<std::tuple<Try<Ts>...>> type;
};
template <typename... Ts, typename THead, typename... Fs>
typename std::enable_if<sizeof...(Fs) == 0, void>::type
collectAllVariadicHelper(VariadicContext<Ts...> *ctx, THead&& head, Fs&&... tail) {
collectAllVariadicHelper(std::shared_ptr<VariadicContext<Ts...>> ctx,
THead&& head, Fs&&... tail) {
head.setCallback_([ctx](Try<typename THead::value_type>&& t) {
std::get<sizeof...(Ts) - sizeof...(Fs) - 1>(ctx->results) = std::move(t);
if (++ctx->count == ctx->total) {
ctx->p.setValue(std::move(ctx->results));
delete ctx;
}
});
}
template <typename... Ts, typename THead, typename... Fs>
typename std::enable_if<sizeof...(Fs) != 0, void>::type
collectAllVariadicHelper(VariadicContext<Ts...> *ctx, THead&& head, Fs&&... tail) {
collectAllVariadicHelper(std::shared_ptr<VariadicContext<Ts...>> ctx,
THead&& head, Fs&&... tail) {
head.setCallback_([ctx](Try<typename THead::value_type>&& t) {
std::get<sizeof...(Ts) - sizeof...(Fs) - 1>(ctx->results) = std::move(t);
if (++ctx->count == ctx->total) {
ctx->p.setValue(std::move(ctx->results));
delete ctx;
}
});
// template tail-recursion
collectAllVariadicHelper(ctx, std::forward<Fs>(tail)...);
}
template <typename T>
struct WhenAllContext {
WhenAllContext() : count(0) {}
Promise<std::vector<Try<T> > > p;
std::vector<Try<T> > results;
std::atomic<size_t> count;
};
template <typename T>
struct WhenAnyContext {
explicit WhenAnyContext(size_t n) : done(false), ref_count(n) {};
Promise<std::pair<size_t, Try<T>>> p;
std::atomic<bool> done;
std::atomic<size_t> ref_count;
void decref() {
if (--ref_count == 0) {
delete this;
}
}
};
}} // folly::detail
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment