Commit a4ece81e authored by Dan Melnic's avatar Dan Melnic Committed by Facebook Github Bot

Let ThreadLocalPtr reset() synchronize with Accessor

Summary: Let `ThreadLocalPtr` member `reset()` synchronize with member type `Accessor` lifetime, preventing data races that could happen if `reset()` is called while another thread is within the `Accessor` critical section.

Reviewed By: yfeldblum

Differential Revision: D14878801

fbshipit-source-id: 4090c5749d08bee3fd9b50b23df6e716bc78c9c7
parent 6eb8dace
...@@ -136,6 +136,8 @@ class ThreadLocalPtr { ...@@ -136,6 +136,8 @@ class ThreadLocalPtr {
private: private:
typedef threadlocal_detail::StaticMeta<Tag, AccessMode> StaticMeta; typedef threadlocal_detail::StaticMeta<Tag, AccessMode> StaticMeta;
using AccessAllThreadsEnabled = Negation<std::is_same<Tag, void>>;
public: public:
constexpr ThreadLocalPtr() : id_() {} constexpr ThreadLocalPtr() : id_() {}
...@@ -166,12 +168,16 @@ class ThreadLocalPtr { ...@@ -166,12 +168,16 @@ class ThreadLocalPtr {
} }
T* release() { T* release() {
auto rlock = getAccessAllThreadsLockReadHolderIfEnabled();
threadlocal_detail::ElementWrapper& w = StaticMeta::get(&id_); threadlocal_detail::ElementWrapper& w = StaticMeta::get(&id_);
return static_cast<T*>(w.release()); return static_cast<T*>(w.release());
} }
void reset(T* newPtr = nullptr) { void reset(T* newPtr = nullptr) {
auto rlock = getAccessAllThreadsLockReadHolderIfEnabled();
auto guard = makeGuard([&] { delete newPtr; }); auto guard = makeGuard([&] { delete newPtr; });
threadlocal_detail::ElementWrapper* w = &StaticMeta::get(&id_); threadlocal_detail::ElementWrapper* w = &StaticMeta::get(&id_);
...@@ -224,6 +230,8 @@ class ThreadLocalPtr { ...@@ -224,6 +230,8 @@ class ThreadLocalPtr {
*/ */
template <class Deleter> template <class Deleter>
void reset(T* newPtr, const Deleter& deleter) { void reset(T* newPtr, const Deleter& deleter) {
auto rlock = getAccessAllThreadsLockReadHolderIfEnabled();
auto guard = makeGuard([&] { auto guard = makeGuard([&] {
if (newPtr) { if (newPtr) {
deleter(newPtr, TLPDestructionMode::THIS_THREAD); deleter(newPtr, TLPDestructionMode::THIS_THREAD);
...@@ -434,7 +442,7 @@ class ThreadLocalPtr { ...@@ -434,7 +442,7 @@ class ThreadLocalPtr {
// elements of this ThreadLocal instance. Holds a global lock for each <Tag> // elements of this ThreadLocal instance. Holds a global lock for each <Tag>
Accessor accessAllThreads() const { Accessor accessAllThreads() const {
static_assert( static_assert(
!std::is_same<Tag, void>::value, AccessAllThreadsEnabled::value,
"Must use a unique Tag to use the accessAllThreads feature"); "Must use a unique Tag to use the accessAllThreads feature");
return Accessor(id_.getOrAllocate(StaticMeta::instance())); return Accessor(id_.getOrAllocate(StaticMeta::instance()));
} }
...@@ -448,6 +456,13 @@ class ThreadLocalPtr { ...@@ -448,6 +456,13 @@ class ThreadLocalPtr {
ThreadLocalPtr(const ThreadLocalPtr&) = delete; ThreadLocalPtr(const ThreadLocalPtr&) = delete;
ThreadLocalPtr& operator=(const ThreadLocalPtr&) = delete; ThreadLocalPtr& operator=(const ThreadLocalPtr&) = delete;
static auto getAccessAllThreadsLockReadHolderIfEnabled() {
return SharedMutex::ReadHolder(
AccessAllThreadsEnabled::value
? &StaticMeta::instance().accessAllThreadsLock_
: nullptr);
}
mutable typename StaticMeta::EntryID id_; mutable typename StaticMeta::EntryID id_;
}; };
......
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
#include <thread> #include <thread>
#include <unordered_map> #include <unordered_map>
#include <boost/thread/barrier.hpp>
#include <glog/logging.h> #include <glog/logging.h>
#include <folly/Memory.h> #include <folly/Memory.h>
...@@ -522,6 +523,79 @@ TEST(ThreadLocal, Stress) { ...@@ -522,6 +523,79 @@ TEST(ThreadLocal, Stress) {
EXPECT_EQ(numFillObjects * numThreads * numReps, gDestroyed); EXPECT_EQ(numFillObjects * numThreads * numReps, gDestroyed);
} }
struct StressAccessTag {};
using TLPInt = ThreadLocalPtr<int, Tag>;
static void tlpIntCustomDeleter(int* p, TLPDestructionMode /*unused*/) {
delete p;
}
template <typename Op, typename Check>
void StresAccessTest(Op op, Check check) {
static constexpr size_t kNumThreads = 16;
static constexpr size_t kNumLoops = 10000;
TLPInt ptr;
ptr.reset(new int(0));
std::atomic<bool> running{true};
boost::barrier barrier(kNumThreads + 1);
std::vector<std::thread> threads;
for (size_t k = 0; k < kNumThreads; ++k) {
threads.emplace_back([&] {
ptr.reset(new int(1));
barrier.wait();
while (running.load()) {
op(ptr);
}
});
}
// wait for the threads to be up and running
barrier.wait();
for (size_t n = 0; n < kNumLoops; n++) {
int sum = 0;
auto accessor = ptr.accessAllThreads();
for (auto& i : accessor) {
sum += i;
}
check(sum, kNumThreads);
}
running.store(false);
for (auto& t : threads) {
t.join();
}
}
TEST(ThreadLocal, StressAccessReset) {
StresAccessTest(
[](TLPInt& ptr) { ptr.reset(new int(1)); },
[](size_t sum, size_t numThreads) { EXPECT_EQ(sum, numThreads); });
}
TEST(ThreadLocal, StressAccessResetDeleter) {
StresAccessTest(
[](TLPInt& ptr) { ptr.reset(new int(1), tlpIntCustomDeleter); },
[](size_t sum, size_t numThreads) { EXPECT_EQ(sum, numThreads); });
}
TEST(ThreadLocal, StressAccessRelease) {
StresAccessTest(
[](TLPInt& ptr) {
auto* p = ptr.release();
delete p;
ptr.reset(new int(1));
},
[](size_t sum, size_t numThreads) { EXPECT_LE(sum, numThreads); });
}
// Yes, threads and fork don't mix // Yes, threads and fork don't mix
// (http://cppwisdom.quora.com/Why-threads-and-fork-dont-mix) but if you're // (http://cppwisdom.quora.com/Why-threads-and-fork-dont-mix) but if you're
// stupid or desperate enough to try, we shouldn't stand in your way. // stupid or desperate enough to try, we shouldn't stand in your way.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment