Commit 9fc5d328 authored by Aaryaman Sagar's avatar Aaryaman Sagar Committed by Facebook Github Bot

Use folly::ThreadLocal in DeterministicSchedule

Summary:
DeterministicSchedule does not work on mobile because __thread or thread_local
support is not present.  folly::ThreadLocal helps because it uses
pthread_getspecific for mobile and platforms where support for this is not there

This does mean that we can't add DSched tests for folly::ThreadLocal, but that
does not have DSched tests anyway.  When we add DSched tests for that, we can
move to using a lock + map keyed by std::this_thread::get_id

Reviewed By: yfeldblum

Differential Revision: D15770455

fbshipit-source-id: 2a3cc6b3b1c116469cce6295a641d784e5bdfd50
parent 70d7eafa
...@@ -30,11 +30,6 @@ ...@@ -30,11 +30,6 @@
namespace folly { namespace folly {
namespace test { namespace test {
FOLLY_TLS sem_t* DeterministicSchedule::tls_sem;
FOLLY_TLS DeterministicSchedule* DeterministicSchedule::tls_sched;
FOLLY_TLS bool DeterministicSchedule::tls_exiting;
FOLLY_TLS DSchedThreadId DeterministicSchedule::tls_threadId;
thread_local AuxAct DeterministicSchedule::tls_aux_act;
AuxChk DeterministicSchedule::aux_chk; AuxChk DeterministicSchedule::aux_chk;
// access is protected by futexLock // access is protected by futexLock
...@@ -122,24 +117,27 @@ void ThreadSyncVar::acq_rel() { ...@@ -122,24 +117,27 @@ void ThreadSyncVar::acq_rel() {
DeterministicSchedule::DeterministicSchedule( DeterministicSchedule::DeterministicSchedule(
const std::function<size_t(size_t)>& scheduler) const std::function<size_t(size_t)>& scheduler)
: scheduler_(scheduler), nextThreadId_(0), step_(0) { : scheduler_(scheduler), nextThreadId_(0), step_(0) {
assert(tls_sem == nullptr); auto& tls = TLState::get();
assert(tls_sched == nullptr); assert(tls.sem == nullptr);
assert(tls_aux_act == nullptr); assert(tls.sched == nullptr);
assert(tls.aux_act == nullptr);
tls_exiting = false; tls.exiting = false;
tls_sem = new sem_t; tls.sem = new sem_t;
sem_init(tls_sem, 0, 1); sem_init(tls.sem, 0, 1);
sems_.push_back(tls_sem); sems_.push_back(tls.sem);
tls_threadId = nextThreadId_++; tls.threadId = nextThreadId_++;
threadInfoMap_.emplace_back(tls_threadId); threadInfoMap_.emplace_back(tls.threadId);
tls_sched = this; tls.sched = this;
} }
DeterministicSchedule::~DeterministicSchedule() { DeterministicSchedule::~DeterministicSchedule() {
assert(tls_sched == this); auto& tls = TLState::get();
static_cast<void>(tls);
assert(tls.sched == this);
assert(sems_.size() == 1); assert(sems_.size() == 1);
assert(sems_[0] == tls_sem); assert(sems_[0] == tls.sem);
beforeThreadExit(); beforeThreadExit();
} }
...@@ -207,13 +205,15 @@ DeterministicSchedule::uniformSubset(uint64_t seed, size_t n, size_t m) { ...@@ -207,13 +205,15 @@ DeterministicSchedule::uniformSubset(uint64_t seed, size_t n, size_t m) {
} }
void DeterministicSchedule::beforeSharedAccess() { void DeterministicSchedule::beforeSharedAccess() {
if (tls_sem) { auto& tls = TLState::get();
sem_wait(tls_sem); if (tls.sem) {
sem_wait(tls.sem);
} }
} }
void DeterministicSchedule::afterSharedAccess() { void DeterministicSchedule::afterSharedAccess() {
auto sched = tls_sched; auto& tls = TLState::get();
auto sched = tls.sched;
if (!sched) { if (!sched) {
return; return;
} }
...@@ -221,7 +221,8 @@ void DeterministicSchedule::afterSharedAccess() { ...@@ -221,7 +221,8 @@ void DeterministicSchedule::afterSharedAccess() {
} }
void DeterministicSchedule::afterSharedAccess(bool success) { void DeterministicSchedule::afterSharedAccess(bool success) {
auto sched = tls_sched; auto& tls = TLState::get();
auto sched = tls.sched;
if (!sched) { if (!sched) {
return; return;
} }
...@@ -230,8 +231,9 @@ void DeterministicSchedule::afterSharedAccess(bool success) { ...@@ -230,8 +231,9 @@ void DeterministicSchedule::afterSharedAccess(bool success) {
} }
size_t DeterministicSchedule::getRandNumber(size_t n) { size_t DeterministicSchedule::getRandNumber(size_t n) {
if (tls_sched) { auto& tls = TLState::get();
return tls_sched->scheduler_(n); if (tls.sched) {
return tls.sched->scheduler_(n);
} }
return Random::rand32() % n; return Random::rand32() % n;
} }
...@@ -240,17 +242,19 @@ int DeterministicSchedule::getcpu( ...@@ -240,17 +242,19 @@ int DeterministicSchedule::getcpu(
unsigned* cpu, unsigned* cpu,
unsigned* node, unsigned* node,
void* /* unused */) { void* /* unused */) {
auto& tls = TLState::get();
if (cpu) { if (cpu) {
*cpu = tls_threadId.val; *cpu = tls.threadId.val;
} }
if (node) { if (node) {
*node = tls_threadId.val; *node = tls.threadId.val;
} }
return 0; return 0;
} }
void DeterministicSchedule::setAuxAct(AuxAct& aux) { void DeterministicSchedule::setAuxAct(AuxAct& aux) {
tls_aux_act = aux; auto& tls = TLState::get();
tls.aux_act = aux;
} }
void DeterministicSchedule::setAuxChk(AuxChk& aux) { void DeterministicSchedule::setAuxChk(AuxChk& aux) {
...@@ -262,19 +266,21 @@ void DeterministicSchedule::clearAuxChk() { ...@@ -262,19 +266,21 @@ void DeterministicSchedule::clearAuxChk() {
} }
void DeterministicSchedule::reschedule(sem_t* sem) { void DeterministicSchedule::reschedule(sem_t* sem) {
auto sched = tls_sched; auto& tls = TLState::get();
auto sched = tls.sched;
if (sched) { if (sched) {
sched->sems_.push_back(sem); sched->sems_.push_back(sem);
} }
} }
sem_t* DeterministicSchedule::descheduleCurrentThread() { sem_t* DeterministicSchedule::descheduleCurrentThread() {
auto sched = tls_sched; auto& tls = TLState::get();
auto sched = tls.sched;
if (sched) { if (sched) {
sched->sems_.erase( sched->sems_.erase(
std::find(sched->sems_.begin(), sched->sems_.end(), tls_sem)); std::find(sched->sems_.begin(), sched->sems_.end(), tls.sem));
} }
return tls_sem; return tls.sem;
} }
sem_t* DeterministicSchedule::beforeThreadCreate() { sem_t* DeterministicSchedule::beforeThreadCreate() {
...@@ -287,19 +293,20 @@ sem_t* DeterministicSchedule::beforeThreadCreate() { ...@@ -287,19 +293,20 @@ sem_t* DeterministicSchedule::beforeThreadCreate() {
} }
void DeterministicSchedule::afterThreadCreate(sem_t* sem) { void DeterministicSchedule::afterThreadCreate(sem_t* sem) {
assert(tls_sem == nullptr); auto& tls = TLState::get();
assert(tls_sched == nullptr); assert(tls.sem == nullptr);
tls_exiting = false; assert(tls.sched == nullptr);
tls_sem = sem; tls.exiting = false;
tls_sched = this; tls.sem = sem;
tls.sched = this;
bool started = false; bool started = false;
while (!started) { while (!started) {
beforeSharedAccess(); beforeSharedAccess();
if (active_.count(std::this_thread::get_id()) == 1) { if (active_.count(std::this_thread::get_id()) == 1) {
started = true; started = true;
tls_threadId = nextThreadId_++; tls.threadId = nextThreadId_++;
assert(tls_threadId.val == threadInfoMap_.size()); assert(tls.threadId.val == threadInfoMap_.size());
threadInfoMap_.emplace_back(tls_threadId); threadInfoMap_.emplace_back(tls.threadId);
} }
afterSharedAccess(); afterSharedAccess();
} }
...@@ -307,7 +314,8 @@ void DeterministicSchedule::afterThreadCreate(sem_t* sem) { ...@@ -307,7 +314,8 @@ void DeterministicSchedule::afterThreadCreate(sem_t* sem) {
} }
void DeterministicSchedule::beforeThreadExit() { void DeterministicSchedule::beforeThreadExit() {
assert(tls_sched == this); auto& tls = TLState::get();
assert(tls.sched == this);
atomic_thread_fence(std::memory_order_seq_cst); atomic_thread_fence(std::memory_order_seq_cst);
beforeSharedAccess(); beforeSharedAccess();
...@@ -316,41 +324,43 @@ void DeterministicSchedule::beforeThreadExit() { ...@@ -316,41 +324,43 @@ void DeterministicSchedule::beforeThreadExit() {
reschedule(parent->second); reschedule(parent->second);
joins_.erase(parent); joins_.erase(parent);
} }
sems_.erase(std::find(sems_.begin(), sems_.end(), tls_sem)); sems_.erase(std::find(sems_.begin(), sems_.end(), tls.sem));
active_.erase(std::this_thread::get_id()); active_.erase(std::this_thread::get_id());
if (sems_.size() > 0) { if (sems_.size() > 0) {
FOLLY_TEST_DSCHED_VLOG("exiting"); FOLLY_TEST_DSCHED_VLOG("exiting");
/* Wait here so that parent thread can control when the thread /* Wait here so that parent thread can control when the thread
* enters the thread local destructors. */ * enters the thread local destructors. */
exitingSems_[std::this_thread::get_id()] = tls_sem; exitingSems_[std::this_thread::get_id()] = tls.sem;
afterSharedAccess(); afterSharedAccess();
sem_wait(tls_sem); sem_wait(tls.sem);
} }
tls_sched = nullptr; tls.sched = nullptr;
tls_aux_act = nullptr; tls.aux_act = nullptr;
tls_exiting = true; tls.exiting = true;
sem_destroy(tls_sem); sem_destroy(tls.sem);
delete tls_sem; delete tls.sem;
tls_sem = nullptr; tls.sem = nullptr;
} }
void DeterministicSchedule::waitForBeforeThreadExit(std::thread& child) { void DeterministicSchedule::waitForBeforeThreadExit(std::thread& child) {
assert(tls_sched == this); auto& tls = TLState::get();
assert(tls.sched == this);
beforeSharedAccess(); beforeSharedAccess();
assert(tls_sched->joins_.count(child.get_id()) == 0); assert(tls.sched->joins_.count(child.get_id()) == 0);
if (tls_sched->active_.count(child.get_id())) { if (tls.sched->active_.count(child.get_id())) {
sem_t* sem = descheduleCurrentThread(); sem_t* sem = descheduleCurrentThread();
tls_sched->joins_.insert({child.get_id(), sem}); tls.sched->joins_.insert({child.get_id(), sem});
afterSharedAccess(); afterSharedAccess();
// Wait to be scheduled by exiting child thread // Wait to be scheduled by exiting child thread
beforeSharedAccess(); beforeSharedAccess();
assert(!tls_sched->active_.count(child.get_id())); assert(!tls.sched->active_.count(child.get_id()));
} }
afterSharedAccess(); afterSharedAccess();
} }
void DeterministicSchedule::joinAll(std::vector<std::thread>& children) { void DeterministicSchedule::joinAll(std::vector<std::thread>& children) {
auto sched = tls_sched; auto& tls = TLState::get();
auto sched = tls.sched;
if (sched) { if (sched) {
// Wait until all children are about to exit // Wait until all children are about to exit
for (auto& child : children) { for (auto& child : children) {
...@@ -369,7 +379,8 @@ void DeterministicSchedule::joinAll(std::vector<std::thread>& children) { ...@@ -369,7 +379,8 @@ void DeterministicSchedule::joinAll(std::vector<std::thread>& children) {
} }
void DeterministicSchedule::join(std::thread& child) { void DeterministicSchedule::join(std::thread& child) {
auto sched = tls_sched; auto& tls = TLState::get();
auto sched = tls.sched;
if (sched) { if (sched) {
sched->waitForBeforeThreadExit(child); sched->waitForBeforeThreadExit(child);
} }
...@@ -382,10 +393,11 @@ void DeterministicSchedule::join(std::thread& child) { ...@@ -382,10 +393,11 @@ void DeterministicSchedule::join(std::thread& child) {
} }
void DeterministicSchedule::callAux(bool success) { void DeterministicSchedule::callAux(bool success) {
auto& tls = TLState::get();
++step_; ++step_;
if (tls_aux_act) { if (tls.aux_act) {
tls_aux_act(success); tls.aux_act(success);
tls_aux_act = nullptr; tls.aux_act = nullptr;
} }
if (aux_chk) { if (aux_chk) {
aux_chk(step_); aux_chk(step_);
...@@ -437,14 +449,16 @@ void DeterministicSchedule::wait(sem_t* sem) { ...@@ -437,14 +449,16 @@ void DeterministicSchedule::wait(sem_t* sem) {
} }
ThreadInfo& DeterministicSchedule::getCurrentThreadInfo() { ThreadInfo& DeterministicSchedule::getCurrentThreadInfo() {
auto sched = tls_sched; auto& tls = TLState::get();
auto sched = tls.sched;
assert(sched); assert(sched);
assert(tls_threadId.val < sched->threadInfoMap_.size()); assert(tls.threadId.val < sched->threadInfoMap_.size());
return sched->threadInfoMap_[tls_threadId.val]; return sched->threadInfoMap_[tls.threadId.val];
} }
void DeterministicSchedule::atomic_thread_fence(std::memory_order mo) { void DeterministicSchedule::atomic_thread_fence(std::memory_order mo) {
if (!tls_sched) { auto& tls = TLState::get();
if (!tls.sched) {
std::atomic_thread_fence(mo); std::atomic_thread_fence(mo);
return; return;
} }
...@@ -467,8 +481,8 @@ void DeterministicSchedule::atomic_thread_fence(std::memory_order mo) { ...@@ -467,8 +481,8 @@ void DeterministicSchedule::atomic_thread_fence(std::memory_order mo) {
break; break;
case std::memory_order_seq_cst: case std::memory_order_seq_cst:
threadInfo.acqRelOrder_.sync(threadInfo.acqFenceOrder_); threadInfo.acqRelOrder_.sync(threadInfo.acqFenceOrder_);
threadInfo.acqRelOrder_.sync(tls_sched->seqCstFenceOrder_); threadInfo.acqRelOrder_.sync(tls.sched->seqCstFenceOrder_);
tls_sched->seqCstFenceOrder_ = threadInfo.acqRelOrder_; tls.sched->seqCstFenceOrder_ = threadInfo.acqRelOrder_;
threadInfo.relFenceOrder_.sync(threadInfo.acqRelOrder_); threadInfo.relFenceOrder_.sync(threadInfo.acqRelOrder_);
break; break;
} }
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <vector> #include <vector>
#include <folly/ScopeGuard.h> #include <folly/ScopeGuard.h>
#include <folly/SingletonThreadLocal.h>
#include <folly/concurrency/CacheLocality.h> #include <folly/concurrency/CacheLocality.h>
#include <folly/detail/Futex.h> #include <folly/detail/Futex.h>
#include <folly/portability/Semaphore.h> #include <folly/portability/Semaphore.h>
...@@ -197,7 +198,8 @@ class DeterministicSchedule { ...@@ -197,7 +198,8 @@ class DeterministicSchedule {
static inline std::thread thread(Func&& func, Args&&... args) { static inline std::thread thread(Func&& func, Args&&... args) {
// TODO: maybe future versions of gcc will allow forwarding to thread // TODO: maybe future versions of gcc will allow forwarding to thread
atomic_thread_fence(std::memory_order_seq_cst); atomic_thread_fence(std::memory_order_seq_cst);
auto sched = tls_sched; auto& tls = TLState::get();
auto sched = tls.sched;
auto sem = sched ? sched->beforeThreadCreate() : nullptr; auto sem = sched ? sched->beforeThreadCreate() : nullptr;
auto child = std::thread( auto child = std::thread(
[=](Args... a) { [=](Args... a) {
...@@ -276,19 +278,22 @@ class DeterministicSchedule { ...@@ -276,19 +278,22 @@ class DeterministicSchedule {
* the thread function, for example if the thread is executing * the thread function, for example if the thread is executing
* thread local destructors. */ * thread local destructors. */
static bool isCurrentThreadExiting() { static bool isCurrentThreadExiting() {
return tls_exiting; auto& tls = TLState::get();
return tls.exiting;
} }
/** Add sem back into sems_ */ /** Add sem back into sems_ */
static void reschedule(sem_t* sem); static void reschedule(sem_t* sem);
static bool isActive() { static bool isActive() {
return tls_sched != nullptr; auto& tls = TLState::get();
return tls.sched != nullptr;
} }
static DSchedThreadId getThreadId() { static DSchedThreadId getThreadId() {
assert(tls_sched != nullptr); auto& tls = TLState::get();
return tls_threadId; assert(tls.sched != nullptr);
return tls.threadId;
} }
static ThreadInfo& getCurrentThreadInfo(); static ThreadInfo& getCurrentThreadInfo();
...@@ -296,11 +301,25 @@ class DeterministicSchedule { ...@@ -296,11 +301,25 @@ class DeterministicSchedule {
static void atomic_thread_fence(std::memory_order mo); static void atomic_thread_fence(std::memory_order mo);
private: private:
static FOLLY_TLS sem_t* tls_sem; struct PerThreadState {
static FOLLY_TLS DeterministicSchedule* tls_sched; // delete the constructors and assignment operators for sanity
static FOLLY_TLS bool tls_exiting; //
static FOLLY_TLS DSchedThreadId tls_threadId; // but... we can't delete the move constructor and assignment operators
static thread_local AuxAct tls_aux_act; // because those are required before C++17 in the implementation of
// SingletonThreadLocal
PerThreadState(const PerThreadState&) = delete;
PerThreadState& operator=(const PerThreadState&) = delete;
PerThreadState(PerThreadState&&) = default;
PerThreadState& operator=(PerThreadState&&) = default;
PerThreadState() = default;
sem_t* sem{nullptr};
DeterministicSchedule* sched{nullptr};
bool exiting{false};
DSchedThreadId threadId{};
AuxAct aux_act{};
};
using TLState = SingletonThreadLocal<PerThreadState>;
static AuxChk aux_chk; static AuxChk aux_chk;
std::function<size_t(size_t)> scheduler_; std::function<size_t(size_t)> scheduler_;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment