Commit 2ff21656 authored by Andrii Grynenko's avatar Andrii Grynenko Committed by Facebook GitHub Bot

Make getWeakRef preserve SequencedExecutor tag

Reviewed By: yfeldblum

Differential Revision: D27253836

fbshipit-source-id: a72531e0deca9f0f6582339d483ed56f2eb37d87
parent 355fec90
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <glog/logging.h> #include <glog/logging.h>
#include <folly/Executor.h> #include <folly/Executor.h>
#include <folly/executors/SequencedExecutor.h>
#include <folly/synchronization/Baton.h> #include <folly/synchronization/Baton.h>
namespace folly { namespace folly {
...@@ -31,10 +32,21 @@ class DefaultKeepAliveExecutor : public virtual Executor { ...@@ -31,10 +32,21 @@ class DefaultKeepAliveExecutor : public virtual Executor {
public: public:
virtual ~DefaultKeepAliveExecutor() { DCHECK(!keepAlive_); } virtual ~DefaultKeepAliveExecutor() { DCHECK(!keepAlive_); }
folly::Executor::KeepAlive<> weakRef() { template <typename ExecutorT>
return WeakRef::create(controlBlock_, this); static auto getWeakRef(ExecutorT& executor) {
static_assert(
std::is_base_of<DefaultKeepAliveExecutor, ExecutorT>::value,
"getWeakRef only works for folly::DefaultKeepAliveExecutor implementations.");
using WeakRefExecutorType = std::conditional_t<
std::is_base_of<SequencedExecutor, ExecutorT>::value,
SequencedExecutor,
Executor>;
return WeakRef<WeakRefExecutorType>::create(
executor.controlBlock_, &executor);
} }
folly::Executor::KeepAlive<> weakRef() { return getWeakRef(*this); }
protected: protected:
void joinKeepAlive() { void joinKeepAlive() {
DCHECK(keepAlive_); DCHECK(keepAlive_);
...@@ -56,10 +68,11 @@ class DefaultKeepAliveExecutor : public virtual Executor { ...@@ -56,10 +68,11 @@ class DefaultKeepAliveExecutor : public virtual Executor {
std::atomic<ssize_t> keepAliveCount_{1}; std::atomic<ssize_t> keepAliveCount_{1};
}; };
class WeakRef : public Executor { template <typename ExecutorT = Executor>
class WeakRef : public ExecutorT {
public: public:
static folly::Executor::KeepAlive<> create( static folly::Executor::KeepAlive<ExecutorT> create(
std::shared_ptr<ControlBlock> controlBlock, Executor* executor) { std::shared_ptr<ControlBlock> controlBlock, ExecutorT* executor) {
return makeKeepAlive(new WeakRef(std::move(controlBlock), executor)); return makeKeepAlive(new WeakRef(std::move(controlBlock), executor));
} }
...@@ -78,7 +91,7 @@ class DefaultKeepAliveExecutor : public virtual Executor { ...@@ -78,7 +91,7 @@ class DefaultKeepAliveExecutor : public virtual Executor {
virtual uint8_t getNumPriorities() const override { return numPriorities_; } virtual uint8_t getNumPriorities() const override { return numPriorities_; }
private: private:
WeakRef(std::shared_ptr<ControlBlock> controlBlock, Executor* executor) WeakRef(std::shared_ptr<ControlBlock> controlBlock, ExecutorT* executor)
: controlBlock_(std::move(controlBlock)), : controlBlock_(std::move(controlBlock)),
executor_(executor), executor_(executor),
numPriorities_(executor->getNumPriorities()) {} numPriorities_(executor->getNumPriorities()) {}
...@@ -101,7 +114,7 @@ class DefaultKeepAliveExecutor : public virtual Executor { ...@@ -101,7 +114,7 @@ class DefaultKeepAliveExecutor : public virtual Executor {
} }
} }
folly::Executor::KeepAlive<> lock() { folly::Executor::KeepAlive<ExecutorT> lock() {
auto controlBlock = auto controlBlock =
controlBlock_->keepAliveCount_.load(std::memory_order_relaxed); controlBlock_->keepAliveCount_.load(std::memory_order_relaxed);
do { do {
...@@ -114,13 +127,13 @@ class DefaultKeepAliveExecutor : public virtual Executor { ...@@ -114,13 +127,13 @@ class DefaultKeepAliveExecutor : public virtual Executor {
std::memory_order_release, std::memory_order_release,
std::memory_order_relaxed)); std::memory_order_relaxed));
return makeKeepAlive(executor_); return makeKeepAlive<ExecutorT>(executor_);
} }
std::atomic<size_t> keepAliveCount_{1}; std::atomic<size_t> keepAliveCount_{1};
std::shared_ptr<ControlBlock> controlBlock_; std::shared_ptr<ControlBlock> controlBlock_;
Executor* executor_; ExecutorT* executor_;
uint8_t numPriorities_; uint8_t numPriorities_;
}; };
...@@ -148,4 +161,12 @@ class DefaultKeepAliveExecutor : public virtual Executor { ...@@ -148,4 +161,12 @@ class DefaultKeepAliveExecutor : public virtual Executor {
KeepAlive<DefaultKeepAliveExecutor> keepAlive_{makeKeepAlive(this)}; KeepAlive<DefaultKeepAliveExecutor> keepAlive_{makeKeepAlive(this)};
}; };
template <typename ExecutorT>
auto getWeakRef(ExecutorT& executor) {
static_assert(
std::is_base_of<DefaultKeepAliveExecutor, ExecutorT>::value,
"getWeakRef only works for folly::DefaultKeepAliveExecutor implementations.");
return DefaultKeepAliveExecutor::getWeakRef(executor);
}
} // namespace folly } // namespace folly
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#include <folly/DefaultKeepAliveExecutor.h>
#include <folly/executors/CPUThreadPoolExecutor.h>
#include <folly/executors/ThreadPoolExecutor.h> #include <folly/executors/ThreadPoolExecutor.h>
#include <atomic> #include <atomic>
...@@ -934,7 +936,7 @@ static void WeakRefTest() { ...@@ -934,7 +936,7 @@ static void WeakRefTest() {
.via(&fe) .via(&fe)
.thenValue([](auto&&) { burnMs(100)(); }) .thenValue([](auto&&) { burnMs(100)(); })
.thenValue([&](auto&&) { ++counter; }) .thenValue([&](auto&&) { ++counter; })
.via(fe.weakRef()) .via(getWeakRef(fe))
.thenValue([](auto&&) { burnMs(100)(); }) .thenValue([](auto&&) { burnMs(100)(); })
.thenValue([&](auto&&) { ++counter; }); .thenValue([&](auto&&) { ++counter; });
} }
...@@ -984,6 +986,13 @@ static void virtualExecutorTest() { ...@@ -984,6 +986,13 @@ static void virtualExecutorTest() {
EXPECT_EQ(2, counter); EXPECT_EQ(2, counter);
} }
class SingleThreadedCPUThreadPoolExecutor : public CPUThreadPoolExecutor,
public SequencedExecutor {
public:
explicit SingleThreadedCPUThreadPoolExecutor(size_t)
: CPUThreadPoolExecutor(1) {}
};
TEST(ThreadPoolExecutorTest, WeakRefTestIO) { TEST(ThreadPoolExecutorTest, WeakRefTestIO) {
WeakRefTest<IOThreadPoolExecutor>(); WeakRefTest<IOThreadPoolExecutor>();
} }
...@@ -996,6 +1005,18 @@ TEST(ThreadPoolExecutorTest, WeakRefTestEDF) { ...@@ -996,6 +1005,18 @@ TEST(ThreadPoolExecutorTest, WeakRefTestEDF) {
WeakRefTest<EDFThreadPoolExecutor>(); WeakRefTest<EDFThreadPoolExecutor>();
} }
TEST(ThreadPoolExecutorTest, WeakRefTestSingleThreadedCPU) {
WeakRefTest<SingleThreadedCPUThreadPoolExecutor>();
}
TEST(ThreadPoolExecutorTest, WeakRefTestSequential) {
SingleThreadedCPUThreadPoolExecutor ex(1);
auto weakRef = getWeakRef(ex);
EXPECT_TRUE((std::is_same_v<
decltype(weakRef),
Executor::KeepAlive<SequencedExecutor>>));
}
TEST(ThreadPoolExecutorTest, VirtualExecutorTestIO) { TEST(ThreadPoolExecutorTest, VirtualExecutorTestIO) {
virtualExecutorTest<IOThreadPoolExecutor>(); virtualExecutorTest<IOThreadPoolExecutor>();
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment