Commit 93db3df4 authored by Andrii Grynenko's avatar Andrii Grynenko Committed by Facebook Github Bot 2

Allow adding tasks to TaskIterator dynamically

Reviewed By: yfeldblum

Differential Revision: D3244669

fb-gh-sync-id: 73fa4ecb0432a802e67ef922255a896d96f32374
fbshipit-source-id: 73fa4ecb0432a802e67ef922255a896d96f32374
parent 4598dd70
......@@ -16,20 +16,12 @@
#include <memory>
#include <vector>
#include <folly/experimental/fibers/FiberManager.h>
namespace folly {
namespace fibers {
template <typename T>
TaskIterator<T>::TaskIterator(TaskIterator&& other) noexcept
: context_(std::move(other.context_)), id_(other.id_) {}
template <typename T>
TaskIterator<T>::TaskIterator(std::shared_ptr<Context> context)
: context_(std::move(context)), id_(-1) {
assert(context_);
}
: context_(std::move(other.context_)), id_(other.id_), fm_(other.fm_) {}
template <typename T>
inline bool TaskIterator<T>::hasCompleted() const {
......@@ -92,6 +84,30 @@ inline size_t TaskIterator<T>::getTaskID() const {
return id_;
}
template <typename T>
template <typename F>
void TaskIterator<T>::addTask(F&& func) {
static_assert(
std::is_convertible<typename std::result_of<F()>::type, T>::value,
"TaskIterator<T>: T must be convertible from func()'s return type");
auto taskId = context_->totalTasks++;
fm_.addTask(
[ taskId, context = context_, func = std::forward<F>(func) ]() mutable {
context->results.emplace_back(
taskId, folly::makeTryWith(std::move(func)));
// Check for awaiting iterator.
if (context->promise.hasValue()) {
if (--context->tasksToFulfillPromise == 0) {
context->promise->setValue();
context->promise.clear();
}
}
});
}
template <class InputIterator>
TaskIterator<typename std::result_of<
typename std::iterator_traits<InputIterator>::value_type()>::type>
......@@ -101,32 +117,15 @@ addTasks(InputIterator first, InputIterator last) {
ResultType;
typedef TaskIterator<ResultType> IteratorType;
auto context = std::make_shared<typename IteratorType::Context>();
context->totalTasks = std::distance(first, last);
context->results.reserve(context->totalTasks);
for (size_t i = 0; first != last; ++i, ++first) {
#ifdef __clang__
#pragma clang diagnostic push // ignore generalized lambda capture warning
#pragma clang diagnostic ignored "-Wc++1y-extensions"
#endif
addTask([ i, context, f = std::move(*first) ]() {
context->results.emplace_back(i, folly::makeTryWith(std::move(f)));
// Check for awaiting iterator.
if (context->promise.hasValue()) {
if (--context->tasksToFulfillPromise == 0) {
context->promise->setValue();
context->promise.clear();
}
}
});
#ifdef __clang__
#pragma clang diagnostic pop
#endif
IteratorType iterator;
for (; first != last; ++first) {
iterator.addTask(std::move(*first));
}
return IteratorType(std::move(context));
iterator.context_->results.reserve(iterator.context_->totalTasks);
return std::move(iterator);
}
}
}
......@@ -19,6 +19,7 @@
#include <vector>
#include <folly/Optional.h>
#include <folly/experimental/fibers/FiberManager.h>
#include <folly/experimental/fibers/Promise.h>
#include <folly/futures/Try.h>
......@@ -49,6 +50,8 @@ class TaskIterator {
public:
typedef T value_type;
TaskIterator() : fm_(FiberManager::getFiberManager()) {}
// not copyable
TaskIterator(const TaskIterator& other) = delete;
TaskIterator& operator=(const TaskIterator& other) = delete;
......@@ -57,6 +60,14 @@ class TaskIterator {
TaskIterator(TaskIterator&& other) noexcept;
TaskIterator& operator=(TaskIterator&& other) = delete;
/**
* Add one more task to the TaskIterator.
*
* @param func task to be added, will be scheduled on current FiberManager
*/
template <typename F>
void addTask(F&& func);
/**
* @return True if there are tasks immediately available to be consumed (no
* need to await on them).
......@@ -111,10 +122,9 @@ class TaskIterator {
size_t tasksToFulfillPromise{0};
};
std::shared_ptr<Context> context_;
size_t id_;
explicit TaskIterator(std::shared_ptr<Context> context);
std::shared_ptr<Context> context_{std::make_shared<Context>()};
size_t id_{std::numeric_limits<size_t>::max()};
FiberManager& fm_;
folly::Try<T> awaitNextResult();
};
......
......@@ -463,7 +463,7 @@ TEST(FiberManager, addTasksVoidThrow) {
loopController.loop(std::move(loopFunc));
}
TEST(FiberManager, reserve) {
TEST(FiberManager, addTasksReserve) {
std::vector<Promise<int>> pendingFibers;
bool taskAdded = false;
......@@ -517,6 +517,42 @@ TEST(FiberManager, reserve) {
loopController.loop(std::move(loopFunc));
}
TEST(FiberManager, addTaskDynamic) {
folly::EventBase evb;
Baton batons[3];
auto makeTask = [&](size_t taskId) {
return [&, taskId]() -> size_t {
batons[taskId].wait();
return taskId;
};
};
getFiberManager(evb)
.addTaskFuture([&]() {
TaskIterator<size_t> iterator;
iterator.addTask(makeTask(0));
iterator.addTask(makeTask(1));
batons[1].post();
EXPECT_EQ(1, iterator.awaitNext());
iterator.addTask(makeTask(2));
batons[2].post();
EXPECT_EQ(2, iterator.awaitNext());
batons[0].post();
EXPECT_EQ(0, iterator.awaitNext());
})
.waitVia(&evb);
}
TEST(FiberManager, forEach) {
std::vector<Promise<int>> pendingFibers;
bool taskAdded = false;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment