Commit 55d6a804 authored by Akrama Baig Mirza's avatar Akrama Baig Mirza Committed by Facebook GitHub Bot

DRY implementation of ThreadIdCollector

Summary: ThreadManager.cpp and CPUThreadPoolExecutor.cpp both define the same implemenation of a ThreadIdCollector. Let's create one implementation that can be shared.

Reviewed By: yfeldblum, amlannayak

Differential Revision: D33065836

fbshipit-source-id: b17ef7abbba29cc7bad3779f20c76a3b08773963
parent fb0be74d
...@@ -41,37 +41,6 @@ namespace { ...@@ -41,37 +41,6 @@ namespace {
using default_queue = UnboundedBlockingQueue<CPUThreadPoolExecutor::CPUTask>; using default_queue = UnboundedBlockingQueue<CPUThreadPoolExecutor::CPUTask>;
using default_queue_alloc = using default_queue_alloc =
AlignedSysAllocator<default_queue, FixedAlign<alignof(default_queue)>>; AlignedSysAllocator<default_queue, FixedAlign<alignof(default_queue)>>;
class ThreadIdCollector : public WorkerProvider {
public:
ThreadIdCollector() {}
IdsWithKeepAlive collectThreadIds() override final {
auto keepAlive = std::make_unique<WorkerKeepAlive>(
SharedMutex::ReadHolder{&threadsExitMutex_});
auto locked = osThreadIds_.rlock();
return {std::move(keepAlive), {locked->begin(), locked->end()}};
}
Synchronized<std::unordered_set<pid_t>> osThreadIds_;
SharedMutex threadsExitMutex_;
private:
class WorkerKeepAlive : public WorkerProvider::KeepAlive {
public:
explicit WorkerKeepAlive(SharedMutex::ReadHolder idsLock)
: threadsExitLock_(std::move(idsLock)) {}
~WorkerKeepAlive() override {}
private:
SharedMutex::ReadHolder threadsExitLock_;
};
};
inline ThreadIdCollector* upcast(std::unique_ptr<WorkerProvider>& wpPtr) {
return static_cast<ThreadIdCollector*>(wpPtr.get());
}
} // namespace } // namespace
const size_t CPUThreadPoolExecutor::kDefaultMaxQueueSize = 1 << 14; const size_t CPUThreadPoolExecutor::kDefaultMaxQueueSize = 1 << 14;
...@@ -303,14 +272,12 @@ void CPUThreadPoolExecutor::threadRun(ThreadPtr thread) { ...@@ -303,14 +272,12 @@ void CPUThreadPoolExecutor::threadRun(ThreadPtr thread) {
} }
thread->startupBaton.post(); thread->startupBaton.post();
auto collectorPtr = upcast(threadIdCollector_); threadIdCollector_->addTid(folly::getOSThreadID());
collectorPtr->osThreadIds_.wlock()->insert(folly::getOSThreadID());
// On thread exit, we should remove the thread ID from the tracking list. // On thread exit, we should remove the thread ID from the tracking list.
auto threadIDsGuard = folly::makeGuard([collectorPtr]() { auto threadIDsGuard = folly::makeGuard([this]() {
// The observer could be capturing a stack trace from this thread // The observer could be capturing a stack trace from this thread
// so it should block until the collection finishes to exit. // so it should block until the collection finishes to exit.
collectorPtr->osThreadIds_.wlock()->erase(folly::getOSThreadID()); threadIdCollector_->removeTid(folly::getOSThreadID());
SharedMutex::WriteHolder w{collectorPtr->threadsExitMutex_};
}); });
while (true) { while (true) {
auto task = taskQueue_->try_take_for(threadTimeout_); auto task = taskQueue_->try_take_for(threadTimeout_);
...@@ -371,8 +338,4 @@ CPUThreadPoolExecutor::createQueueObserverFactory() { ...@@ -371,8 +338,4 @@ CPUThreadPoolExecutor::createQueueObserverFactory() {
threadIdCollector_.get()); threadIdCollector_.get());
} }
std::unique_ptr<WorkerProvider> CPUThreadPoolExecutor::createWorkerProvider() {
return std::make_unique<ThreadIdCollector>();
}
} // namespace folly } // namespace folly
...@@ -175,8 +175,8 @@ class CPUThreadPoolExecutor : public ThreadPoolExecutor { ...@@ -175,8 +175,8 @@ class CPUThreadPoolExecutor : public ThreadPoolExecutor {
protected: protected:
BlockingQueue<CPUTask>* getTaskQueue(); BlockingQueue<CPUTask>* getTaskQueue();
std::unique_ptr<WorkerProvider> createWorkerProvider(); std::unique_ptr<ThreadIdWorkerProvider> threadIdCollector_{
std::unique_ptr<WorkerProvider> threadIdCollector_{createWorkerProvider()}; std::make_unique<ThreadIdWorkerProvider>()};
private: private:
void threadRun(ThreadPtr thread) override; void threadRun(ThreadPtr thread) override;
......
...@@ -24,10 +24,38 @@ make_queue_observer_factory_fallback( ...@@ -24,10 +24,38 @@ make_queue_observer_factory_fallback(
return std::unique_ptr<folly::QueueObserverFactory>(); return std::unique_ptr<folly::QueueObserverFactory>();
} }
class WorkerKeepAlive : public folly::WorkerProvider::KeepAlive {
public:
explicit WorkerKeepAlive(folly::SharedMutex::ReadHolder idsLock)
: threadsExitLock_(std::move(idsLock)) {}
~WorkerKeepAlive() override {}
private:
folly::SharedMutex::ReadHolder threadsExitLock_;
};
} // namespace } // namespace
namespace folly { namespace folly {
ThreadIdWorkerProvider::IdsWithKeepAlive
ThreadIdWorkerProvider::collectThreadIds() {
auto keepAlive = std::make_unique<WorkerKeepAlive>(
SharedMutex::ReadHolder{&threadsExitMutex_});
auto locked = osThreadIds_.rlock();
return {std::move(keepAlive), {locked->begin(), locked->end()}};
}
void ThreadIdWorkerProvider::addTid(pid_t tid) {
osThreadIds_.wlock()->insert(tid);
}
void ThreadIdWorkerProvider::removeTid(pid_t tid) {
osThreadIds_.wlock()->erase(tid);
// block until all WorkerKeepAlives have been destroyed
SharedMutex::WriteHolder w{threadsExitMutex_};
}
WorkerProvider::KeepAlive::~KeepAlive() {} WorkerProvider::KeepAlive::~KeepAlive() {}
/* static */ std::unique_ptr<QueueObserverFactory> QueueObserverFactory::make( /* static */ std::unique_ptr<QueueObserverFactory> QueueObserverFactory::make(
......
...@@ -19,9 +19,12 @@ ...@@ -19,9 +19,12 @@
#include <stdint.h> #include <stdint.h>
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_set>
#include <vector> #include <vector>
#include <folly/Portability.h> #include <folly/Portability.h>
#include <folly/Synchronized.h>
#include <folly/portability/SysTypes.h>
namespace folly { namespace folly {
...@@ -60,6 +63,19 @@ class WorkerProvider { ...@@ -60,6 +63,19 @@ class WorkerProvider {
virtual IdsWithKeepAlive collectThreadIds() = 0; virtual IdsWithKeepAlive collectThreadIds() = 0;
}; };
class ThreadIdWorkerProvider : public WorkerProvider {
public:
IdsWithKeepAlive collectThreadIds() override final;
void addTid(pid_t tid);
// Will block until all KeepAlives have been destroyed, if any exist
void removeTid(pid_t tid);
private:
Synchronized<std::unordered_set<pid_t>> osThreadIds_;
SharedMutex threadsExitMutex_;
};
class QueueObserver { class QueueObserver {
public: public:
virtual ~QueueObserver() {} virtual ~QueueObserver() {}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment