Commit 5d1f7e5f authored by Shai Szulanski's avatar Shai Szulanski Committed by Facebook GitHub Bot

CancellationToken::merge()

Summary:
Combining CancellationTokens is a common pattern in user code, which requires the user to manage a CancellationSource and attach a callback per token to be merged:
```
    folly::CancellationSource source;

    folly::CancellationCallbacki cb1{
        std::move(token1), [&] { source.requestCancellation(); }};
    folly::CancellationCallback cb2{
        std::move(token2), [&] { source.requestCancellation(); }};
    ...

    co_await co_withCancellation(source.getToken(), ...);
```

This diff adds the functionality to folly:
```
    co_await co_withCancellation(folly::CancellationToken::merge(token1, token2, ...), ...);
```

This diff subclasses CancellationState for simplicity. If the extra vtable ptr is a concern, we can instead template CancellationState to combine them, but that would come with a compilation speed penalty.

Reviewed By: yfeldblum

Differential Revision: D26277772

fbshipit-source-id: 169f65b4cf8f7f10f2800fb7797bb09379ecc239
parent de428b41
...@@ -21,20 +21,26 @@ ...@@ -21,20 +21,26 @@
#include <glog/logging.h> #include <glog/logging.h>
#include <folly/Utility.h>
namespace folly { namespace folly {
namespace detail { namespace detail {
struct FixedMergingCancellationStateTag {};
// Internal cancellation state object. // Internal cancellation state object.
class CancellationState { class CancellationState {
public: public:
FOLLY_NODISCARD static CancellationStateSourcePtr create(); FOLLY_NODISCARD static CancellationStateSourcePtr create();
private: protected:
// Constructed initially with a CancellationSource reference count of 1. // Constructed initially with a CancellationSource reference count of 1.
CancellationState() noexcept; CancellationState() noexcept;
// Constructed initially with a CancellationToken reference count of 1.
explicit CancellationState(FixedMergingCancellationStateTag) noexcept;
~CancellationState(); virtual ~CancellationState();
friend struct CancellationStateTokenDeleter; friend struct CancellationStateTokenDeleter;
friend struct CancellationStateSourceDeleter; friend struct CancellationStateSourceDeleter;
...@@ -78,9 +84,10 @@ class CancellationState { ...@@ -78,9 +84,10 @@ class CancellationState {
static constexpr std::uint64_t kCancellationRequestedFlag = 1; static constexpr std::uint64_t kCancellationRequestedFlag = 1;
static constexpr std::uint64_t kLockedFlag = 2; static constexpr std::uint64_t kLockedFlag = 2;
static constexpr std::uint64_t kTokenReferenceCountIncrement = 4; static constexpr std::uint64_t kMergingFlag = 4;
static constexpr std::uint64_t kTokenReferenceCountIncrement = 8;
static constexpr std::uint64_t kSourceReferenceCountIncrement = static constexpr std::uint64_t kSourceReferenceCountIncrement =
std::uint64_t(1) << 33u; std::uint64_t(1) << 34u;
static constexpr std::uint64_t kTokenReferenceCountMask = static constexpr std::uint64_t kTokenReferenceCountMask =
(kSourceReferenceCountIncrement - 1u) - (kSourceReferenceCountIncrement - 1u) -
(kTokenReferenceCountIncrement - 1u); (kTokenReferenceCountIncrement - 1u);
...@@ -90,13 +97,27 @@ class CancellationState { ...@@ -90,13 +97,27 @@ class CancellationState {
// Bit 0 - Cancellation Requested // Bit 0 - Cancellation Requested
// Bit 1 - Locked Flag // Bit 1 - Locked Flag
// Bits 2-32 - Token reference count (max ~2 billion) // Bit 2 - MergingCancellationState Flag
// Bits 33-63 - Source reference count (max ~2 billion) // Bits 3-33 - Token reference count (max ~2 billion)
// Bits 34-63 - Source reference count (max ~1 billion)
std::atomic<std::uint64_t> state_; std::atomic<std::uint64_t> state_;
CancellationCallback* head_; CancellationCallback* head_{nullptr};
std::thread::id signallingThreadId_; std::thread::id signallingThreadId_;
}; };
template <size_t N>
class FixedMergingCancellationState : public CancellationState {
template <typename... Ts>
FixedMergingCancellationState(Ts&&... tokens);
public:
template <typename... Ts>
FOLLY_NODISCARD static CancellationStateTokenPtr create(Ts&&... tokens);
private:
std::array<CancellationCallback, N> callbacks_;
};
inline void CancellationStateTokenDeleter::operator()( inline void CancellationStateTokenDeleter::operator()(
CancellationState* state) noexcept { CancellationState* state) noexcept {
state->removeTokenReference(); state->removeTokenReference();
...@@ -285,9 +306,10 @@ inline CancellationStateSourcePtr CancellationState::create() { ...@@ -285,9 +306,10 @@ inline CancellationStateSourcePtr CancellationState::create() {
} }
inline CancellationState::CancellationState() noexcept inline CancellationState::CancellationState() noexcept
: state_(kSourceReferenceCountIncrement), : state_(kSourceReferenceCountIncrement) {}
head_(nullptr), inline CancellationState::CancellationState(
signallingThreadId_() {} FixedMergingCancellationStateTag) noexcept
: state_(kTokenReferenceCountIncrement | kMergingFlag) {}
inline CancellationStateTokenPtr inline CancellationStateTokenPtr
CancellationState::addTokenReference() noexcept { CancellationState::addTokenReference() noexcept {
...@@ -334,7 +356,7 @@ inline bool CancellationState::canBeCancelled(std::uint64_t state) noexcept { ...@@ -334,7 +356,7 @@ inline bool CancellationState::canBeCancelled(std::uint64_t state) noexcept {
// Can be cancelled if there is at least one CancellationSource ref-count // Can be cancelled if there is at least one CancellationSource ref-count
// or if cancellation has been requested. // or if cancellation has been requested.
return (state >= kSourceReferenceCountIncrement) || return (state >= kSourceReferenceCountIncrement) ||
isCancellationRequested(state); (state & kMergingFlag) != 0 || isCancellationRequested(state);
} }
inline bool CancellationState::isCancellationRequested( inline bool CancellationState::isCancellationRequested(
...@@ -346,6 +368,34 @@ inline bool CancellationState::isLocked(std::uint64_t state) noexcept { ...@@ -346,6 +368,34 @@ inline bool CancellationState::isLocked(std::uint64_t state) noexcept {
return (state & kLockedFlag) != 0; return (state & kLockedFlag) != 0;
} }
template <size_t N>
template <typename... Ts>
inline CancellationStateTokenPtr FixedMergingCancellationState<N>::create(
Ts&&... tokens) {
return CancellationStateTokenPtr{
new FixedMergingCancellationState<N>(std::forward<Ts>(tokens)...)};
}
template <size_t N>
template <typename... Ts>
inline FixedMergingCancellationState<N>::FixedMergingCancellationState(
Ts&&... tokens)
: CancellationState(FixedMergingCancellationStateTag{}),
callbacks_{
{{std::forward<Ts>(tokens), [this] { requestCancellation(); }}...}} {}
} // namespace detail } // namespace detail
template <typename... Ts>
inline CancellationToken CancellationToken::merge(Ts&&... tokens) {
std::array<bool, sizeof...(Ts)> cancellable{tokens.canBeCancelled()...};
bool canBeCancelled =
std::any_of(cancellable.begin(), cancellable.end(), identity);
return canBeCancelled
? CancellationToken(
detail::FixedMergingCancellationState<sizeof...(Ts)>::create(
std::forward<Ts>(tokens)...))
: CancellationToken();
}
} // namespace folly } // namespace folly
...@@ -95,6 +95,16 @@ class CancellationToken { ...@@ -95,6 +95,16 @@ class CancellationToken {
// if they know they can never be cancelled. // if they know they can never be cancelled.
bool canBeCancelled() const noexcept; bool canBeCancelled() const noexcept;
// Obtain a CancellationToken linked to any number of other
// CancellationTokens.
//
// This token will have cancellation requested when any of the passed-in
// tokens do.
// This token is cancellable if any of the passed-in tokens are at the time of
// construction.
template <typename... Ts>
static CancellationToken merge(Ts&&... tokens);
void swap(CancellationToken& other) noexcept; void swap(CancellationToken& other) noexcept;
friend bool operator==( friend bool operator==(
......
...@@ -261,3 +261,34 @@ TEST(CancellationTokenTest, NonCancellableSource) { ...@@ -261,3 +261,34 @@ TEST(CancellationTokenTest, NonCancellableSource) {
CHECK(!src.isCancellationRequested()); CHECK(!src.isCancellationRequested());
CHECK(token == CancellationToken{}); CHECK(token == CancellationToken{});
} }
TEST(CancellationTokenTest, MergedToken) {
CancellationSource src1, src2;
auto token = CancellationToken::merge(src1.getToken(), src2.getToken());
EXPECT_TRUE(token.canBeCancelled());
EXPECT_FALSE(token.isCancellationRequested());
bool callbackExecuted = false;
CancellationCallback cb{token, [&] { callbackExecuted = true; }};
EXPECT_FALSE(callbackExecuted);
EXPECT_FALSE(token.isCancellationRequested());
src1.requestCancellation();
EXPECT_TRUE(callbackExecuted);
EXPECT_TRUE(token.isCancellationRequested());
src2.requestCancellation();
EXPECT_TRUE(callbackExecuted);
EXPECT_TRUE(token.isCancellationRequested());
token = CancellationToken::merge();
EXPECT_FALSE(token.canBeCancelled());
token = CancellationToken::merge(CancellationToken());
EXPECT_FALSE(token.canBeCancelled());
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment