Commit 8044c14f authored by Andrew Smith's avatar Andrew Smith Committed by Facebook Github Bot

Adding coroutine-based generator to folly

Summary: This diff adds the generator and recursive_generator classes from the cppcoro library to folly.

Reviewed By: lewissbaker

Differential Revision: D13834297

fbshipit-source-id: d27d11c39a35749c168a7c8e53c2819c36083467
parent 340d437c
/*
* Copyright 2019-present Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cassert>
#include <exception>
#include <experimental/coroutine>
#include <type_traits>
#include <utility>
namespace folly {
namespace coro {
template <typename T>
class Generator {
public:
class promise_type final {
public:
promise_type() noexcept
: m_value(nullptr),
m_exception(nullptr),
m_root(this),
m_parentOrLeaf(this) {}
promise_type(const promise_type&) = delete;
promise_type(promise_type&&) = delete;
auto get_return_object() noexcept {
return Generator<T>{*this};
}
std::experimental::suspend_always initial_suspend() noexcept {
return {};
}
std::experimental::suspend_always final_suspend() noexcept {
return {};
}
void unhandled_exception() noexcept {
m_exception = std::current_exception();
}
void return_void() noexcept {}
std::experimental::suspend_always yield_value(T& value) noexcept {
m_value = std::addressof(value);
return {};
}
std::experimental::suspend_always yield_value(T&& value) noexcept {
m_value = std::addressof(value);
return {};
}
auto yield_value(Generator&& generator) noexcept {
return yield_value(generator);
}
auto yield_value(Generator& generator) noexcept {
struct awaitable {
awaitable(promise_type* childPromise) : m_childPromise(childPromise) {}
bool await_ready() noexcept {
return this->m_childPromise == nullptr;
}
void await_suspend(
std::experimental::coroutine_handle<promise_type>) noexcept {}
void await_resume() {
if (this->m_childPromise != nullptr) {
this->m_childPromise->throw_if_exception();
}
}
private:
promise_type* m_childPromise;
};
if (generator.m_promise != nullptr) {
m_root->m_parentOrLeaf = generator.m_promise;
generator.m_promise->m_root = m_root;
generator.m_promise->m_parentOrLeaf = this;
generator.m_promise->resume();
if (!generator.m_promise->is_complete()) {
return awaitable{generator.m_promise};
}
m_root->m_parentOrLeaf = this;
}
return awaitable{nullptr};
}
// Don't allow any use of 'co_await' inside the Generator
// coroutine.
template <typename U>
std::experimental::suspend_never await_transform(U&& value) = delete;
void destroy() noexcept {
std::experimental::coroutine_handle<promise_type>::from_promise(*this)
.destroy();
}
void throw_if_exception() {
if (m_exception != nullptr) {
std::rethrow_exception(std::move(m_exception));
}
}
bool is_complete() noexcept {
return std::experimental::coroutine_handle<promise_type>::from_promise(
*this)
.done();
}
T& value() noexcept {
assert(this == m_root);
assert(!is_complete());
return *(m_parentOrLeaf->m_value);
}
void pull() noexcept {
assert(this == m_root);
assert(!m_parentOrLeaf->is_complete());
m_parentOrLeaf->resume();
while (m_parentOrLeaf != this && m_parentOrLeaf->is_complete()) {
m_parentOrLeaf = m_parentOrLeaf->m_parentOrLeaf;
m_parentOrLeaf->resume();
}
}
private:
void resume() noexcept {
std::experimental::coroutine_handle<promise_type>::from_promise(*this)
.resume();
}
std::add_pointer_t<T> m_value;
std::exception_ptr m_exception;
promise_type* m_root;
// If this is the promise of the root generator then this field
// is a pointer to the leaf promise.
// For non-root generators this is a pointer to the parent promise.
promise_type* m_parentOrLeaf;
};
Generator() noexcept : m_promise(nullptr) {}
Generator(promise_type& promise) noexcept : m_promise(&promise) {}
Generator(Generator&& other) noexcept : m_promise(other.m_promise) {
other.m_promise = nullptr;
}
Generator(const Generator& other) = delete;
Generator& operator=(const Generator& other) = delete;
~Generator() {
if (m_promise != nullptr) {
m_promise->destroy();
}
}
Generator& operator=(Generator&& other) noexcept {
if (this != &other) {
if (m_promise != nullptr) {
m_promise->destroy();
}
m_promise = other.m_promise;
other.m_promise = nullptr;
}
return *this;
}
class iterator {
public:
using iterator_category = std::input_iterator_tag;
// What type should we use for counting elements of a potentially infinite
// sequence?
using difference_type = std::ptrdiff_t;
using value_type = std::remove_reference_t<T>;
using reference = std::conditional_t<std::is_reference_v<T>, T, T&>;
using pointer = std::add_pointer_t<T>;
iterator() noexcept : m_promise(nullptr) {}
explicit iterator(promise_type* promise) noexcept : m_promise(promise) {}
bool operator==(const iterator& other) const noexcept {
return m_promise == other.m_promise;
}
bool operator!=(const iterator& other) const noexcept {
return m_promise != other.m_promise;
}
iterator& operator++() {
assert(m_promise != nullptr);
assert(!m_promise->is_complete());
m_promise->pull();
if (m_promise->is_complete()) {
auto* temp = m_promise;
m_promise = nullptr;
temp->throw_if_exception();
}
return *this;
}
void operator++(int) {
(void)operator++();
}
reference operator*() const noexcept {
assert(m_promise != nullptr);
return static_cast<reference>(m_promise->value());
}
pointer operator->() const noexcept {
return std::addressof(operator*());
}
private:
promise_type* m_promise;
};
iterator begin() {
if (m_promise != nullptr) {
m_promise->pull();
if (!m_promise->is_complete()) {
return iterator(m_promise);
}
m_promise->throw_if_exception();
}
return iterator(nullptr);
}
iterator end() noexcept {
return iterator(nullptr);
}
void swap(Generator& other) noexcept {
std::swap(m_promise, other.m_promise);
}
private:
friend class promise_type;
promise_type* m_promise;
};
template <typename T>
void swap(Generator<T>& a, Generator<T>& b) noexcept {
a.swap(b);
}
} // namespace coro
} // namespace folly
/*
* Copyright 2019-present Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <folly/Portability.h>
#if FOLLY_HAS_COROUTINES
#include <folly/ScopeGuard.h>
#include <folly/experimental/coro/Generator.h>
#include <folly/portability/GTest.h>
#include <algorithm>
namespace folly {
namespace coro {
TEST(GeneratorTest, DefaultConstructed_EmptySequence) {
Generator<std::uint32_t> ints;
EXPECT_EQ(ints.begin(), ints.end());
}
TEST(GeneratorTest, NonRecursiveUse) {
auto f = []() -> Generator<float> {
co_yield 1.0f;
co_yield 2.0f;
};
auto gen = f();
auto iter = gen.begin();
EXPECT_EQ(*iter, 1.0f);
++iter;
EXPECT_EQ(*iter, 2.0f);
++iter;
EXPECT_EQ(iter, gen.end());
}
TEST(GeneratorTest, ThrowsBeforeYieldingFirstElement_RethrowsFromBegin) {
class MyException : public std::exception {};
auto f = []() -> Generator<std::uint32_t> {
throw MyException{};
co_return;
};
auto gen = f();
EXPECT_THROW(gen.begin(), MyException);
}
TEST(GeneratorTest, ThrowsAfterYieldingFirstElement_RethrowsFromIncrement) {
class MyException : public std::exception {};
auto f = []() -> Generator<std::uint32_t> {
co_yield 1;
throw MyException{};
};
auto gen = f();
auto iter = gen.begin();
EXPECT_EQ(*iter, 1u);
EXPECT_THROW(++iter, MyException);
}
TEST(GeneratorTest, NotStartedUntilCalled) {
bool reachedA = false;
bool reachedB = false;
bool reachedC = false;
auto f = [&]() -> Generator<std::uint32_t> {
reachedA = true;
co_yield 1;
reachedB = true;
co_yield 2;
reachedC = true;
};
auto gen = f();
EXPECT_FALSE(reachedA);
auto iter = gen.begin();
EXPECT_TRUE(reachedA);
EXPECT_FALSE(reachedB);
EXPECT_EQ(*iter, 1u);
++iter;
EXPECT_TRUE(reachedB);
EXPECT_FALSE(reachedC);
EXPECT_EQ(*iter, 2u);
++iter;
EXPECT_TRUE(reachedC);
EXPECT_EQ(iter, gen.end());
}
TEST(GeneratorTest, DestroyedBeforeCompletion_DestructsObjectsOnStack) {
bool destructed = false;
bool completed = false;
auto f = [&]() -> Generator<std::uint32_t> {
SCOPE_EXIT {
destructed = true;
};
co_yield 1;
co_yield 2;
completed = true;
};
{
auto g = f();
auto it = g.begin();
auto itEnd = g.end();
EXPECT_NE(it, itEnd);
EXPECT_EQ(*it, 1u);
EXPECT_FALSE(destructed);
}
EXPECT_FALSE(completed);
EXPECT_TRUE(destructed);
}
TEST(GeneratorTest, SimpleRecursiveYield) {
auto f = [](int n, auto& f) -> Generator<const std::uint32_t> {
co_yield n;
if (n > 0) {
co_yield f(n - 1, f);
co_yield n;
}
};
auto f2 = [&f](int n) { return f(n, f); };
{
auto gen = f2(1);
auto iter = gen.begin();
EXPECT_EQ(*iter, 1u);
++iter;
EXPECT_EQ(*iter, 0u);
++iter;
EXPECT_EQ(*iter, 1u);
++iter;
EXPECT_EQ(iter, gen.end());
}
{
auto gen = f2(2);
auto iter = gen.begin();
EXPECT_EQ(*iter, 2u);
++iter;
EXPECT_EQ(*iter, 1u);
++iter;
EXPECT_EQ(*iter, 0u);
++iter;
EXPECT_EQ(*iter, 1u);
++iter;
EXPECT_EQ(*iter, 2u);
++iter;
EXPECT_EQ(iter, gen.end());
}
}
TEST(GeneratorTest, NestedEmptyYield) {
auto f = []() -> Generator<std::uint32_t> { co_return; };
auto g = [&f]() -> Generator<std::uint32_t> {
co_yield 1;
co_yield f();
co_yield 2;
};
auto gen = g();
auto iter = gen.begin();
EXPECT_EQ(*iter, 1u);
++iter;
EXPECT_EQ(*iter, 2u);
++iter;
EXPECT_EQ(iter, gen.end());
}
TEST(GeneratorTest, ExceptionThrownFromRecursiveCall_CanBeCaughtByCaller) {
class SomeException : public std::exception {};
auto f = [](std::uint32_t depth, auto&& f) -> Generator<std::uint32_t> {
if (depth == 1u) {
throw SomeException{};
}
co_yield 1;
try {
co_yield f(1, f);
} catch (const SomeException&) {
}
co_yield 2;
};
auto gen = f(0, f);
auto iter = gen.begin();
EXPECT_EQ(*iter, 1u);
++iter;
EXPECT_EQ(*iter, 2u);
++iter;
EXPECT_EQ(iter, gen.end());
}
TEST(GeneratorTest, ExceptionThrownFromNestedCall_CanBeCaughtByCaller) {
class SomeException : public std::exception {};
auto f = [](std::uint32_t depth, auto&& f) -> Generator<std::uint32_t> {
if (depth == 4u) {
throw SomeException{};
} else if (depth == 3u) {
co_yield 3;
try {
co_yield f(4, f);
} catch (const SomeException&) {
}
co_yield 33;
throw SomeException{};
} else if (depth == 2u) {
bool caught = false;
try {
co_yield f(3, f);
} catch (const SomeException&) {
caught = true;
}
if (caught) {
co_yield 2;
}
} else {
co_yield 1;
co_yield f(2, f);
co_yield f(3, f);
}
};
auto gen = f(1, f);
auto iter = gen.begin();
EXPECT_EQ(*iter, 1u);
++iter;
EXPECT_EQ(*iter, 3u);
++iter;
EXPECT_EQ(*iter, 33u);
++iter;
EXPECT_EQ(*iter, 2u);
++iter;
EXPECT_EQ(*iter, 3u);
++iter;
EXPECT_EQ(*iter, 33u);
EXPECT_THROW(++iter, SomeException);
EXPECT_EQ(iter, gen.end());
}
namespace {
Generator<std::uint32_t> iterate_range(std::uint32_t begin, std::uint32_t end) {
if ((end - begin) <= 10u) {
for (std::uint32_t i = begin; i < end; ++i) {
co_yield i;
}
} else {
std::uint32_t mid = begin + (end - begin) / 2;
co_yield iterate_range(begin, mid);
co_yield iterate_range(mid, end);
}
}
} // namespace
TEST(GeneratorTest, UsageInStandardAlgorithms) {
{
auto a = iterate_range(5, 30);
auto b = iterate_range(5, 30);
EXPECT_TRUE(std::equal(a.begin(), a.end(), b.begin(), b.end()));
}
{
auto a = iterate_range(5, 30);
auto b = iterate_range(5, 300);
EXPECT_FALSE(std::equal(a.begin(), a.end(), b.begin(), b.end()));
}
}
} // namespace coro
} // namespace folly
#endif
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment