Commit a574e130, authored by Dave Watson, committed by Facebook Github Bot

Fix stop() race

Summary:
There is a race between stop() and idle-thread timeouts: stop() may race with a thread that is timing out and then
block indefinitely in joinStoppedThreads(), waiting to join a timed-out thread that has already been joined.

To fix, do all of the thread-count bookkeeping under the threadListLock_ write lock, which also simplifies the code.

Reviewed By: davidtgoldblatt

Differential Revision: D7971779

fbshipit-source-id: eb4a898ca266ddda0f95033950ac041d1cc812ba
parent c0a9ed05
...@@ -137,46 +137,40 @@ CPUThreadPoolExecutor::getTaskQueue() { ...@@ -137,46 +137,40 @@ CPUThreadPoolExecutor::getTaskQueue() {
return taskQueue_.get(); return taskQueue_.get();
} }
// NOTE(review): side-by-side diff rendering; where a line carries two
// fragments the first is pre-commit and the second is post-commit text.
// Consumes one pending stop request: returns false when threadsToStop_ is
// <= 0, otherwise decrements it by one and returns true.
// Pre-commit this was a compare_exchange_strong retry loop; post-commit it is
// a plain relaxed load followed by a relaxed store, which is not an atomic
// read-modify-write and therefore relies on the locking requirement below to
// serialize callers.
// threadListLock_ must be writelocked.
bool CPUThreadPoolExecutor::tryDecrToStop() { bool CPUThreadPoolExecutor::tryDecrToStop() {
while (true) { auto toStop = threadsToStop_.load(std::memory_order_relaxed);
auto toStop = threadsToStop_.load(std::memory_order_relaxed); if (toStop <= 0) {
if (toStop <= 0) { return false;
return false;
}
if (threadsToStop_.compare_exchange_strong(
toStop, toStop - 1, std::memory_order_relaxed)) {
return true;
}
} }
threadsToStop_.store(toStop - 1, std::memory_order_relaxed);
return true;
} }
// NOTE(review): side-by-side diff rendering; where a line carries two
// fragments the first is pre-commit and the second is post-commit text.
// Decides whether the worker thread that received `task` should exit.
// Post-commit order of checks:
//   1. tryDecrToStop() succeeds -> stop (a pending stop request was consumed).
//   2. Otherwise, if `task` holds a value -> keep running.
//   3. Otherwise (no task; the idle-timeout case per the inline comments) ->
//      stop only when minActive() is true and getPendingTaskCountImpl() is 0;
//      in that case activeThreads_ is decremented and threadsToJoin_ is
//      incremented.
// Pre-commit, the no-task branch took the threadListLock_ write lock itself;
// post-commit it does not, and the threadRun() hunk acquires that lock
// before calling this function.
bool CPUThreadPoolExecutor::taskShouldStop(folly::Optional<CPUTask>& task) { bool CPUThreadPoolExecutor::taskShouldStop(folly::Optional<CPUTask>& task) {
if (tryDecrToStop()) {
return true;
}
if (task) { if (task) {
if (!tryDecrToStop()) { return false;
// Some other thread beat us to it. } else {
// Try to stop based on idle thread timeout (try_take_for),
// if there are at least minThreads running.
if (!minActive()) {
return false; return false;
} }
} else { // If this is based on idle thread timeout, then
{ // adjust vars appropriately (otherwise stop() or join()
SharedMutex::WriteHolder w{&threadListLock_}; // does this).
// Try to stop based on idle thread timeout (try_take_for), if (getPendingTaskCountImpl() > 0) {
// if there are at least minThreads running. return false;
if (!minActive()) {
return false;
}
// If this is based on idle thread timeout, then
// adjust vars appropriately (otherwise stop() or join()
// does this).
if (getPendingTaskCountImpl() > 0) {
return false;
}
activeThreads_.store(
activeThreads_.load(std::memory_order_relaxed) - 1,
std::memory_order_relaxed);
threadsToJoin_.store(
threadsToJoin_.load(std::memory_order_relaxed) + 1,
std::memory_order_relaxed);
} }
activeThreads_.store(
activeThreads_.load(std::memory_order_relaxed) - 1,
std::memory_order_relaxed);
threadsToJoin_.store(
threadsToJoin_.load(std::memory_order_relaxed) + 1,
std::memory_order_relaxed);
} }
return true; return true;
} }
...@@ -190,12 +184,12 @@ void CPUThreadPoolExecutor::threadRun(ThreadPtr thread) { ...@@ -190,12 +184,12 @@ void CPUThreadPoolExecutor::threadRun(ThreadPtr thread) {
// Handle thread stopping, either by task timeout, or // Handle thread stopping, either by task timeout, or
// by 'poison' task added in join() or stop(). // by 'poison' task added in join() or stop().
if (UNLIKELY(!task || task.value().poison)) { if (UNLIKELY(!task || task.value().poison)) {
// Actually remove the thread from the list.
SharedMutex::WriteHolder w{&threadListLock_};
if (taskShouldStop(task)) { if (taskShouldStop(task)) {
for (auto& o : observers_) { for (auto& o : observers_) {
o->threadStopped(thread.get()); o->threadStopped(thread.get());
} }
// Actually remove the thread from the list.
SharedMutex::WriteHolder w{&threadListLock_};
threadList_.remove(thread); threadList_.remove(thread);
stoppedThreads_.add(thread); stoppedThreads_.add(thread);
return; return;
...@@ -207,8 +201,8 @@ void CPUThreadPoolExecutor::threadRun(ThreadPtr thread) { ...@@ -207,8 +201,8 @@ void CPUThreadPoolExecutor::threadRun(ThreadPtr thread) {
runTask(thread, std::move(task.value())); runTask(thread, std::move(task.value()));
if (UNLIKELY(threadsToStop_ > 0 && !isJoin_)) { if (UNLIKELY(threadsToStop_ > 0 && !isJoin_)) {
SharedMutex::WriteHolder w{&threadListLock_};
if (tryDecrToStop()) { if (tryDecrToStop()) {
SharedMutex::WriteHolder w{&threadListLock_};
threadList_.remove(thread); threadList_.remove(thread);
stoppedThreads_.add(thread); stoppedThreads_.add(thread);
return; return;
......
...@@ -207,18 +207,15 @@ void ThreadPoolExecutor::joinStoppedThreads(size_t n) { ...@@ -207,18 +207,15 @@ void ThreadPoolExecutor::joinStoppedThreads(size_t n) {
} }
// NOTE(review): side-by-side diff rendering; where a line carries two
// fragments the first is pre-commit and the second is post-commit text, and
// the hunk ends before the function's closing brace.
// Post-commit, stop() does its bookkeeping in a single threadListLock_
// write-locked section: it zeroes maxThreads_ and activeThreads_, calls
// removeThreads(n, false) for every thread currently in threadList_, and adds
// the pending threadsToJoin_ count to n before resetting that counter to 0,
// so the following joinStoppedThreads(n) also covers the threads counted in
// threadsToJoin_.
// Pre-commit, the same work was split across two locked sections with an
// ensureJoined() call in between.
void ThreadPoolExecutor::stop() { void ThreadPoolExecutor::stop() {
{
folly::SharedMutex::WriteHolder w{&threadListLock_};
maxThreads_.store(0, std::memory_order_release);
activeThreads_.store(0, std::memory_order_release);
}
ensureJoined();
size_t n = 0; size_t n = 0;
{ {
SharedMutex::WriteHolder w{&threadListLock_}; SharedMutex::WriteHolder w{&threadListLock_};
maxThreads_.store(0, std::memory_order_release);
activeThreads_.store(0, std::memory_order_release);
n = threadList_.get().size(); n = threadList_.get().size();
removeThreads(n, false); removeThreads(n, false);
n += threadsToJoin_.load(std::memory_order_relaxed);
threadsToJoin_.store(0, std::memory_order_relaxed);
} }
joinStoppedThreads(n); joinStoppedThreads(n);
CHECK_EQ(0, threadList_.get().size()); CHECK_EQ(0, threadList_.get().size());
...@@ -226,18 +223,15 @@ void ThreadPoolExecutor::stop() { ...@@ -226,18 +223,15 @@ void ThreadPoolExecutor::stop() {
} }
// NOTE(review): side-by-side diff rendering; where a line carries two
// fragments the first is pre-commit and the second is post-commit text, and
// the hunk ends before the function's closing brace.
// Post-commit, join() does its bookkeeping in a single threadListLock_
// write-locked section: it zeroes maxThreads_ and activeThreads_, calls
// removeThreads(n, true) for every thread currently in threadList_, and adds
// the pending threadsToJoin_ count to n before resetting that counter to 0,
// so the following joinStoppedThreads(n) also covers the threads counted in
// threadsToJoin_.
// Pre-commit, the same work was split across two locked sections with an
// ensureJoined() call in between.
void ThreadPoolExecutor::join() { void ThreadPoolExecutor::join() {
{
folly::SharedMutex::WriteHolder w{&threadListLock_};
maxThreads_.store(0, std::memory_order_release);
activeThreads_.store(0, std::memory_order_release);
}
ensureJoined();
size_t n = 0; size_t n = 0;
{ {
SharedMutex::WriteHolder w{&threadListLock_}; SharedMutex::WriteHolder w{&threadListLock_};
maxThreads_.store(0, std::memory_order_release);
activeThreads_.store(0, std::memory_order_release);
n = threadList_.get().size(); n = threadList_.get().size();
removeThreads(n, true); removeThreads(n, true);
n += threadsToJoin_.load(std::memory_order_relaxed);
// Pass the value 0 explicitly, as stop() does. Calling store() with only
// std::memory_order_relaxed hands the memory-order enumerator over as the
// value to store; that compiles only while memory_order converts implicitly
// to an integer (it is a scoped enum from C++20 on).
threadsToJoin_.store(0, std::memory_order_relaxed);
} }
joinStoppedThreads(n); joinStoppedThreads(n);
CHECK_EQ(0, threadList_.get().size()); CHECK_EQ(0, threadList_.get().size());
...@@ -378,7 +372,11 @@ void ThreadPoolExecutor::removeObserver(std::shared_ptr<Observer> o) { ...@@ -378,7 +372,11 @@ void ThreadPoolExecutor::removeObserver(std::shared_ptr<Observer> o) {
// NOTE(review): side-by-side diff rendering; where a line carries two
// fragments the first is pre-commit and the second is post-commit text.
// Joins the threads counted in threadsToJoin_. An unlocked relaxed read acts
// as a cheap check that there is anything to do. When it is non-zero,
// post-commit re-reads the counter and resets it to 0 while holding the
// threadListLock_ write lock (pre-commit used a single atomic exchange
// without the lock), then calls joinStoppedThreads() after the lock scope
// has ended.
// NOTE(review): the value re-read under the lock may already be 0 if another
// caller claimed the count in between, in which case joinStoppedThreads(0)
// is called -- presumably a no-op; confirm.
void ThreadPoolExecutor::ensureJoined() { void ThreadPoolExecutor::ensureJoined() {
auto tojoin = threadsToJoin_.load(std::memory_order_relaxed); auto tojoin = threadsToJoin_.load(std::memory_order_relaxed);
if (tojoin) { if (tojoin) {
tojoin = threadsToJoin_.exchange(0, std::memory_order_relaxed); {
SharedMutex::WriteHolder w{&threadListLock_};
tojoin = threadsToJoin_.load(std::memory_order_relaxed);
threadsToJoin_.store(0, std::memory_order_relaxed);
}
joinStoppedThreads(tojoin); joinStoppedThreads(tojoin);
} }
} }
......
...@@ -737,3 +737,15 @@ TEST(ThreadPoolExecutorTest, DynamicThreadAddRemoveRace) { ...@@ -737,3 +737,15 @@ TEST(ThreadPoolExecutorTest, DynamicThreadAddRemoveRace) {
e.join(); e.join();
EXPECT_EQ(count, 10000); EXPECT_EQ(count, 10000);
} }
// Builds an executor (first constructor argument 1000, backed by an
// UnboundedBlockingQueue) with a 1ms thread-death timeout, submits 10000
// tasks that each enqueue a further task sleeping for 1000 microseconds, and
// then calls stop(). The test has no assertions; it passes when stop()
// returns.
TEST(ThreadPoolExecutorTest, AddPerf) {
  using TaskQueue = UnboundedBlockingQueue<CPUThreadPoolExecutor::CPUTask>;

  CPUThreadPoolExecutor e(1000, std::make_unique<TaskQueue>());
  e.setThreadDeathTimeout(std::chrono::milliseconds(1));

  for (int remaining = 10000; remaining > 0; --remaining) {
    // Each outer task re-enters the executor to schedule the sleeping task.
    e.add([&e]() {
      e.add([]() {
        /* sleep override */ usleep(1000);
      });
    });
  }

  e.stop();
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment