Commit 8d2d67e3 authored by Yedidya Feldblum's avatar Yedidya Feldblum Committed by Facebook Github Bot

Split KeepAlive concepts of dummy and alias

Summary: [Folly] Split KeepAlive concepts of dummy and alias. A dummy KeepAlive is one for an executor which does not actually support keep-alive semantics. An alias KeepAlive is one for which there is another KeepAlive, with a surrounding lifetime, to the same executor.

Reviewed By: andrewcox

Differential Revision: D15683241

fbshipit-source-id: a5809b06c90ed4a655a6973fac67137b5e1981dc
parent 86b2ff29
...@@ -73,8 +73,7 @@ class Executor { ...@@ -73,8 +73,7 @@ class Executor {
} }
KeepAlive(KeepAlive&& other) noexcept KeepAlive(KeepAlive&& other) noexcept
: executorAndDummyFlag_(std::exchange(other.executorAndDummyFlag_, 0)) { : storage_(std::exchange(other.storage_, 0)) {}
}
KeepAlive(const KeepAlive& other) noexcept KeepAlive(const KeepAlive& other) noexcept
: KeepAlive(getKeepAliveToken(other.get())) {} : KeepAlive(getKeepAliveToken(other.get())) {}
...@@ -84,8 +83,8 @@ class Executor { ...@@ -84,8 +83,8 @@ class Executor {
typename = typename std::enable_if< typename = typename std::enable_if<
std::is_convertible<OtherExecutor*, ExecutorT*>::value>::type> std::is_convertible<OtherExecutor*, ExecutorT*>::value>::type>
/* implicit */ KeepAlive(KeepAlive<OtherExecutor>&& other) noexcept /* implicit */ KeepAlive(KeepAlive<OtherExecutor>&& other) noexcept
: KeepAlive(other.get(), other.executorAndDummyFlag_ & kDummyFlag) { : KeepAlive(other.get(), other.storage_ & kFlagMask) {
other.executorAndDummyFlag_ = 0; other.storage_ = 0;
} }
template < template <
...@@ -101,7 +100,7 @@ class Executor { ...@@ -101,7 +100,7 @@ class Executor {
KeepAlive& operator=(KeepAlive&& other) { KeepAlive& operator=(KeepAlive&& other) {
reset(); reset();
executorAndDummyFlag_ = std::exchange(other.executorAndDummyFlag_, 0); storage_ = std::exchange(other.storage_, 0);
return *this; return *this;
} }
...@@ -123,20 +122,19 @@ class Executor { ...@@ -123,20 +122,19 @@ class Executor {
void reset() { void reset() {
if (Executor* executor = get()) { if (Executor* executor = get()) {
if (std::exchange(executorAndDummyFlag_, 0) & kDummyFlag) { auto const flags = std::exchange(storage_, 0) & kFlagMask;
return; if (!(flags & (kDummyFlag | kAliasFlag))) {
executor->keepAliveRelease();
} }
executor->keepAliveRelease();
} }
} }
explicit operator bool() const { explicit operator bool() const {
return executorAndDummyFlag_; return storage_;
} }
ExecutorT* get() const { ExecutorT* get() const {
return reinterpret_cast<ExecutorT*>( return reinterpret_cast<ExecutorT*>(storage_ & kExecutorMask);
executorAndDummyFlag_ & kExecutorMask);
} }
ExecutorT& operator*() const { ExecutorT& operator*() const {
...@@ -151,31 +149,38 @@ class Executor { ...@@ -151,31 +149,38 @@ class Executor {
return getKeepAliveToken(get()); return getKeepAliveToken(get());
} }
// Creates a dummy copy of this KeepAlive token, which doesn't increment KeepAlive get_alias() const {
// the ref-count. Should only be used if this KeepAlive token is known to return KeepAlive(storage_ | kAliasFlag);
// outlive such dummy copy.
KeepAlive copyDummy() const {
return KeepAlive(get(), true);
} }
private: private:
static constexpr intptr_t kDummyFlag = 1; // A dummy keep-alive is a keep-alive to an executor which does not support
static constexpr intptr_t kExecutorMask = ~kDummyFlag; // the keep-alive mechanism.
static constexpr uintptr_t kDummyFlag = uintptr_t(1) << 0;
// An alias keep-alive is a keep-alive to an executor to which there is
// known to be another keep-alive whose lifetime surrounds the lifetime of
// the alias.
static constexpr uintptr_t kAliasFlag = uintptr_t(1) << 1;
static constexpr uintptr_t kFlagMask = kDummyFlag | kAliasFlag;
static constexpr uintptr_t kExecutorMask = ~kFlagMask;
friend class Executor; friend class Executor;
template <typename OtherExecutor> template <typename OtherExecutor>
friend class KeepAlive; friend class KeepAlive;
KeepAlive(ExecutorT* executor, bool dummy) KeepAlive(ExecutorT* executor, uintptr_t flags) noexcept
: executorAndDummyFlag_( : storage_(reinterpret_cast<uintptr_t>(executor) | flags) {
reinterpret_cast<intptr_t>(executor) | (dummy ? kDummyFlag : 0)) {
assert(executor); assert(executor);
assert( assert(!(reinterpret_cast<uintptr_t>(executor) & ~kExecutorMask));
(reinterpret_cast<intptr_t>(executor) & kExecutorMask) == assert(!(flags & kExecutorMask));
reinterpret_cast<intptr_t>(executor));
} }
intptr_t executorAndDummyFlag_{reinterpret_cast<intptr_t>(nullptr)}; explicit KeepAlive(uintptr_t storage) noexcept : storage_(storage) {}
// Combined storage for the executor pointer and for all flags.
uintptr_t storage_{reinterpret_cast<uintptr_t>(nullptr)};
}; };
template <typename ExecutorT> template <typename ExecutorT>
...@@ -208,7 +213,7 @@ class Executor { ...@@ -208,7 +213,7 @@ class Executor {
*/ */
template <typename ExecutorT> template <typename ExecutorT>
static bool isKeepAliveDummy(const KeepAlive<ExecutorT>& keepAlive) { static bool isKeepAliveDummy(const KeepAlive<ExecutorT>& keepAlive) {
return reinterpret_cast<intptr_t>(keepAlive.executorAndDummyFlag_) & return reinterpret_cast<uintptr_t>(keepAlive.storage_) &
KeepAlive<ExecutorT>::kDummyFlag; KeepAlive<ExecutorT>::kDummyFlag;
} }
...@@ -233,7 +238,7 @@ class Executor { ...@@ -233,7 +238,7 @@ class Executor {
static_assert( static_assert(
std::is_base_of<Executor, ExecutorT>::value, std::is_base_of<Executor, ExecutorT>::value,
"makeKeepAliveDummy only works for folly::Executor implementations."); "makeKeepAliveDummy only works for folly::Executor implementations.");
return KeepAlive<ExecutorT>{executor, true}; return KeepAlive<ExecutorT>{executor, KeepAlive<ExecutorT>::kDummyFlag};
} }
}; };
......
...@@ -108,7 +108,7 @@ class AsyncGeneratorPromise { ...@@ -108,7 +108,7 @@ class AsyncGeneratorPromise {
template <typename U> template <typename U>
auto await_transform(U&& value) { auto await_transform(U&& value) {
return folly::coro::co_viaIfAsync( return folly::coro::co_viaIfAsync(
executor_.copyDummy(), static_cast<U&&>(value)); executor_.get_alias(), static_cast<U&&>(value));
} }
auto await_transform(folly::coro::co_current_executor_t) noexcept { auto await_transform(folly::coro::co_current_executor_t) noexcept {
......
...@@ -194,7 +194,7 @@ auto collectAllRange(InputRange awaitables) -> folly::coro::Task<void> { ...@@ -194,7 +194,7 @@ auto collectAllRange(InputRange awaitables) -> folly::coro::Task<void> {
auto makeTask = [&](awaitable_type semiAwaitable) -> detail::BarrierTask { auto makeTask = [&](awaitable_type semiAwaitable) -> detail::BarrierTask {
try { try {
co_await coro::co_viaIfAsync( co_await coro::co_viaIfAsync(
executor.copyDummy(), std::move(semiAwaitable)); executor.get_alias(), std::move(semiAwaitable));
} catch (const std::exception& ex) { } catch (const std::exception& ex) {
if (!anyFailures.exchange(true, std::memory_order_relaxed)) { if (!anyFailures.exchange(true, std::memory_order_relaxed)) {
firstException = exception_wrapper{std::current_exception(), ex}; firstException = exception_wrapper{std::current_exception(), ex};
...@@ -253,11 +253,11 @@ auto collectAllTryRange(InputRange awaitables) ...@@ -253,11 +253,11 @@ auto collectAllTryRange(InputRange awaitables)
semi_await_result_t<detail::range_reference_t<InputRange>>; semi_await_result_t<detail::range_reference_t<InputRange>>;
if constexpr (std::is_void_v<await_result>) { if constexpr (std::is_void_v<await_result>) {
co_await coro::co_viaIfAsync( co_await coro::co_viaIfAsync(
executor.copyDummy(), std::move(semiAwaitable)); executor.get_alias(), std::move(semiAwaitable));
result.emplace(); result.emplace();
} else { } else {
result.emplace(co_await coro::co_viaIfAsync( result.emplace(co_await coro::co_viaIfAsync(
executor.copyDummy(), std::move(semiAwaitable))); executor.get_alias(), std::move(semiAwaitable)));
} }
} catch (const std::exception& ex) { } catch (const std::exception& ex) {
result.emplaceException(std::current_exception(), ex); result.emplaceException(std::current_exception(), ex);
...@@ -324,7 +324,7 @@ auto collectAllWindowed(InputRange awaitables, std::size_t maxConcurrency) ...@@ -324,7 +324,7 @@ auto collectAllWindowed(InputRange awaitables, std::size_t maxConcurrency)
auto makeWorker = [&]() -> detail::BarrierTask { auto makeWorker = [&]() -> detail::BarrierTask {
auto lock = auto lock =
co_await co_viaIfAsync(executor.copyDummy(), mutex.co_scoped_lock()); co_await co_viaIfAsync(executor.get_alias(), mutex.co_scoped_lock());
while (iter != iterEnd) { while (iter != iterEnd) {
awaitable_t awaitable = *iter; awaitable_t awaitable = *iter;
...@@ -341,13 +341,13 @@ auto collectAllWindowed(InputRange awaitables, std::size_t maxConcurrency) ...@@ -341,13 +341,13 @@ auto collectAllWindowed(InputRange awaitables, std::size_t maxConcurrency)
std::exception_ptr ex; std::exception_ptr ex;
try { try {
co_await co_viaIfAsync( co_await co_viaIfAsync(
executor.copyDummy(), static_cast<awaitable_t&&>(awaitable)); executor.get_alias(), static_cast<awaitable_t&&>(awaitable));
} catch (...) { } catch (...) {
ex = std::current_exception(); ex = std::current_exception();
} }
lock = lock =
co_await co_viaIfAsync(executor.copyDummy(), mutex.co_scoped_lock()); co_await co_viaIfAsync(executor.get_alias(), mutex.co_scoped_lock());
if (ex && !firstException) { if (ex && !firstException) {
firstException = std::move(ex); firstException = std::move(ex);
...@@ -444,7 +444,7 @@ auto collectAllTryWindowed(InputRange awaitables, std::size_t maxConcurrency) ...@@ -444,7 +444,7 @@ auto collectAllTryWindowed(InputRange awaitables, std::size_t maxConcurrency)
auto makeWorker = [&]() -> detail::BarrierTask { auto makeWorker = [&]() -> detail::BarrierTask {
auto lock = auto lock =
co_await co_viaIfAsync(executor.copyDummy(), mutex.co_scoped_lock()); co_await co_viaIfAsync(executor.get_alias(), mutex.co_scoped_lock());
while (!iterationException && iter != iterEnd) { while (!iterationException && iter != iterEnd) {
try { try {
...@@ -477,11 +477,11 @@ auto collectAllTryWindowed(InputRange awaitables, std::size_t maxConcurrency) ...@@ -477,11 +477,11 @@ auto collectAllTryWindowed(InputRange awaitables, std::size_t maxConcurrency)
try { try {
if constexpr (std::is_void_v<result_t>) { if constexpr (std::is_void_v<result_t>) {
co_await co_viaIfAsync( co_await co_viaIfAsync(
executor.copyDummy(), static_cast<awaitable_t&&>(awaitable)); executor.get_alias(), static_cast<awaitable_t&&>(awaitable));
result.emplace(); result.emplace();
} else { } else {
result.emplace(co_await co_viaIfAsync( result.emplace(co_await co_viaIfAsync(
executor.copyDummy(), static_cast<awaitable_t&&>(awaitable))); executor.get_alias(), static_cast<awaitable_t&&>(awaitable)));
} }
} catch (const std::exception& ex) { } catch (const std::exception& ex) {
result.emplaceException(std::current_exception(), ex); result.emplaceException(std::current_exception(), ex);
...@@ -490,7 +490,7 @@ auto collectAllTryWindowed(InputRange awaitables, std::size_t maxConcurrency) ...@@ -490,7 +490,7 @@ auto collectAllTryWindowed(InputRange awaitables, std::size_t maxConcurrency)
} }
lock = co_await co_viaIfAsync( lock = co_await co_viaIfAsync(
executor.copyDummy(), mutex.co_scoped_lock()); executor.get_alias(), mutex.co_scoped_lock());
try { try {
results[thisIndex] = std::move(result); results[thisIndex] = std::move(result);
......
...@@ -78,7 +78,7 @@ class TaskPromiseBase { ...@@ -78,7 +78,7 @@ class TaskPromiseBase {
template <typename Awaitable> template <typename Awaitable>
auto await_transform(Awaitable&& awaitable) noexcept { auto await_transform(Awaitable&& awaitable) noexcept {
return folly::coro::co_viaIfAsync( return folly::coro::co_viaIfAsync(
executor_.copyDummy(), static_cast<Awaitable&&>(awaitable)); executor_.get_alias(), static_cast<Awaitable&&>(awaitable));
} }
auto await_transform(co_current_executor_t) noexcept { auto await_transform(co_current_executor_t) noexcept {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment