Commit 5d1f7e5f authored by Shai Szulanski's avatar Shai Szulanski Committed by Facebook GitHub Bot

CancellationToken::merge()

Summary:
Combining CancellationTokens is a common pattern in user code, which requires the user to manage a CancellationSource and attach a callback per token to be merged:
```
    folly::CancellationSource source;

    folly::CancellationCallbacki cb1{
        std::move(token1), [&] { source.requestCancellation(); }};
    folly::CancellationCallback cb2{
        std::move(token2), [&] { source.requestCancellation(); }};
    ...

    co_await co_withCancellation(source.getToken(), ...);
```

This diff adds the functionality to folly:
```
    co_await co_withCancellation(folly::CancellationToken::merge(token1, token2, ...), ...);
```

This diff subclasses CancellationState for simplicity. If the extra vtable ptr is a concern, we can instead template CancellationState to combine them, but that would come with a compilation speed penalty.

Reviewed By: yfeldblum

Differential Revision: D26277772

fbshipit-source-id: 169f65b4cf8f7f10f2800fb7797bb09379ecc239
parent de428b41
......@@ -21,20 +21,26 @@
#include <glog/logging.h>
#include <folly/Utility.h>
namespace folly {
namespace detail {
struct FixedMergingCancellationStateTag {};
// Internal cancellation state object.
class CancellationState {
public:
FOLLY_NODISCARD static CancellationStateSourcePtr create();
private:
protected:
// Constructed initially with a CancellationSource reference count of 1.
CancellationState() noexcept;
// Constructed initially with a CancellationToken reference count of 1.
explicit CancellationState(FixedMergingCancellationStateTag) noexcept;
~CancellationState();
virtual ~CancellationState();
friend struct CancellationStateTokenDeleter;
friend struct CancellationStateSourceDeleter;
......@@ -78,9 +84,10 @@ class CancellationState {
static constexpr std::uint64_t kCancellationRequestedFlag = 1;
static constexpr std::uint64_t kLockedFlag = 2;
static constexpr std::uint64_t kTokenReferenceCountIncrement = 4;
static constexpr std::uint64_t kMergingFlag = 4;
static constexpr std::uint64_t kTokenReferenceCountIncrement = 8;
static constexpr std::uint64_t kSourceReferenceCountIncrement =
std::uint64_t(1) << 33u;
std::uint64_t(1) << 34u;
static constexpr std::uint64_t kTokenReferenceCountMask =
(kSourceReferenceCountIncrement - 1u) -
(kTokenReferenceCountIncrement - 1u);
......@@ -90,13 +97,27 @@ class CancellationState {
// Bit 0 - Cancellation Requested
// Bit 1 - Locked Flag
// Bits 2-32 - Token reference count (max ~2 billion)
// Bits 33-63 - Source reference count (max ~2 billion)
// Bit 2 - MergingCancellationState Flag
// Bits 3-33 - Token reference count (max ~2 billion)
// Bits 34-63 - Source reference count (max ~1 billion)
std::atomic<std::uint64_t> state_;
CancellationCallback* head_;
CancellationCallback* head_{nullptr};
std::thread::id signallingThreadId_;
};
template <size_t N>
class FixedMergingCancellationState : public CancellationState {
template <typename... Ts>
FixedMergingCancellationState(Ts&&... tokens);
public:
template <typename... Ts>
FOLLY_NODISCARD static CancellationStateTokenPtr create(Ts&&... tokens);
private:
std::array<CancellationCallback, N> callbacks_;
};
inline void CancellationStateTokenDeleter::operator()(
CancellationState* state) noexcept {
state->removeTokenReference();
......@@ -285,9 +306,10 @@ inline CancellationStateSourcePtr CancellationState::create() {
}
inline CancellationState::CancellationState() noexcept
: state_(kSourceReferenceCountIncrement),
head_(nullptr),
signallingThreadId_() {}
: state_(kSourceReferenceCountIncrement) {}
inline CancellationState::CancellationState(
FixedMergingCancellationStateTag) noexcept
: state_(kTokenReferenceCountIncrement | kMergingFlag) {}
inline CancellationStateTokenPtr
CancellationState::addTokenReference() noexcept {
......@@ -334,7 +356,7 @@ inline bool CancellationState::canBeCancelled(std::uint64_t state) noexcept {
// Can be cancelled if there is at least one CancellationSource ref-count
// or if cancellation has been requested.
return (state >= kSourceReferenceCountIncrement) ||
isCancellationRequested(state);
(state & kMergingFlag) != 0 || isCancellationRequested(state);
}
inline bool CancellationState::isCancellationRequested(
......@@ -346,6 +368,34 @@ inline bool CancellationState::isLocked(std::uint64_t state) noexcept {
return (state & kLockedFlag) != 0;
}
template <size_t N>
template <typename... Ts>
inline CancellationStateTokenPtr FixedMergingCancellationState<N>::create(
Ts&&... tokens) {
return CancellationStateTokenPtr{
new FixedMergingCancellationState<N>(std::forward<Ts>(tokens)...)};
}
template <size_t N>
template <typename... Ts>
inline FixedMergingCancellationState<N>::FixedMergingCancellationState(
Ts&&... tokens)
: CancellationState(FixedMergingCancellationStateTag{}),
callbacks_{
{{std::forward<Ts>(tokens), [this] { requestCancellation(); }}...}} {}
} // namespace detail
template <typename... Ts>
inline CancellationToken CancellationToken::merge(Ts&&... tokens) {
std::array<bool, sizeof...(Ts)> cancellable{tokens.canBeCancelled()...};
bool canBeCancelled =
std::any_of(cancellable.begin(), cancellable.end(), identity);
return canBeCancelled
? CancellationToken(
detail::FixedMergingCancellationState<sizeof...(Ts)>::create(
std::forward<Ts>(tokens)...))
: CancellationToken();
}
} // namespace folly
......@@ -95,6 +95,16 @@ class CancellationToken {
// if they know they can never be cancelled.
bool canBeCancelled() const noexcept;
// Obtain a CancellationToken linked to any number of other
// CancellationTokens.
//
// This token will have cancellation requested when any of the passed-in
// tokens do.
// This token is cancellable if any of the passed-in tokens are at the time of
// construction.
template <typename... Ts>
static CancellationToken merge(Ts&&... tokens);
void swap(CancellationToken& other) noexcept;
friend bool operator==(
......
......@@ -261,3 +261,34 @@ TEST(CancellationTokenTest, NonCancellableSource) {
CHECK(!src.isCancellationRequested());
CHECK(token == CancellationToken{});
}
TEST(CancellationTokenTest, MergedToken) {
CancellationSource src1, src2;
auto token = CancellationToken::merge(src1.getToken(), src2.getToken());
EXPECT_TRUE(token.canBeCancelled());
EXPECT_FALSE(token.isCancellationRequested());
bool callbackExecuted = false;
CancellationCallback cb{token, [&] { callbackExecuted = true; }};
EXPECT_FALSE(callbackExecuted);
EXPECT_FALSE(token.isCancellationRequested());
src1.requestCancellation();
EXPECT_TRUE(callbackExecuted);
EXPECT_TRUE(token.isCancellationRequested());
src2.requestCancellation();
EXPECT_TRUE(callbackExecuted);
EXPECT_TRUE(token.isCancellationRequested());
token = CancellationToken::merge();
EXPECT_FALSE(token.canBeCancelled());
token = CancellationToken::merge(CancellationToken());
EXPECT_FALSE(token.canBeCancelled());
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment