Commit f8908e51 authored by Philip Pronin, committed by Facebook GitHub Bot

Reverted commit D3755446

Summary:
Move ThreadLocal object destruction to occur under the lock to avoid races.
This causes a few cascading changes: the Tag lock needs to be a recursive_mutex
so that a new object can be constructed while another is being destroyed. Also,
forking requires a new mutex to avoid deadlocking on a recursive_mutex accessed
across a fork().
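
To make the locking hazard concrete, here is a minimal standalone sketch (all names hypothetical, not folly code) of the self-deadlock the recursive_mutex avoids: destruction runs with the meta lock held, and a destructor may itself construct a new tracked object that re-acquires the same lock. The separate fork mutex exists for a related reason: a recursive_mutex held across fork() would leave the child process with a lock owned by a thread that no longer exists.

    #include <mutex>

    // With a plain std::mutex, Tracked's constructor would deadlock: it
    // locks a mutex the current thread already holds. std::recursive_mutex
    // lets the owning thread re-acquire it.
    std::recursive_mutex gMetaLock;

    struct Tracked {
      Tracked() { std::lock_guard<std::recursive_mutex> g(gMetaLock); }
    };

    struct Holder {
      ~Holder() {
        // Destruction happens while gMetaLock is held (see main below);
        // constructing another Tracked re-enters the lock.
        Tracked replacement;
      }
    };

    int main() {
      std::lock_guard<std::recursive_mutex> g(gMetaLock);
      Holder h;
      return 0; // h is destroyed here, with gMetaLock still held by g
    }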

Reviewed By: andriigrynenko

Differential Revision: D3755446

fbshipit-source-id: f1f1f92175eb39e77aaa2add6915e5c9bb68d0fb
parent 8c3d0bc9
folly/ThreadCachedInt.h

@@ -63,11 +63,8 @@ class ThreadCachedInt : boost::noncopyable {
   // Reads the current value plus all the cached increments. Requires grabbing
   // a lock, so this is significantly slower than readFast().
   IntT readFull() const {
-    // This could race with thread destruction and so the access lock should be
-    // acquired before reading the current value
-    auto accessor = cache_.accessAllThreads();
     IntT ret = readFast();
-    for (const auto& cache : accessor) {
+    for (const auto& cache : cache_.accessAllThreads()) {
       if (!cache.reset_.load(std::memory_order_acquire)) {
         ret += cache.val_.load(std::memory_order_relaxed);
       }
@@ -85,11 +82,8 @@ class ThreadCachedInt : boost::noncopyable {
   // little off, however, but it should be much better than calling readFull()
   // and set(0) sequentially.
   IntT readFullAndReset() {
-    // This could race with thread destruction and so the access lock should be
-    // acquired before reading the current value
-    auto accessor = cache_.accessAllThreads();
     IntT ret = readFastAndReset();
-    for (auto& cache : accessor) {
+    for (auto& cache : cache_.accessAllThreads()) {
       if (!cache.reset_.load(std::memory_order_acquire)) {
         ret += cache.val_.load(std::memory_order_relaxed);
         cache.reset_.store(true, std::memory_order_release);
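
As a usage sketch of the API diffed above (the cache size and thread count here are illustrative): increments land in a per-thread cache, readFast() only sees values already flushed to the shared total, and readFull() walks every live thread's cache via accessAllThreads(), while readFullAndReset() additionally zeroes the caches.

    #include <folly/ThreadCachedInt.h>

    #include <cstdint>
    #include <cstdio>
    #include <thread>

    int main() {
      // Large cache size, so increments stay thread-local until read.
      folly::ThreadCachedInt<int64_t> counter(0, 1000);

      std::thread t([&] {
        for (int i = 0; i < 100; ++i) {
          counter.increment(1);
        }
      });
      t.join();

      // readFull() is guaranteed to include the cached increments;
      // readFullAndReset() would also reset the per-thread caches.
      std::printf("full = %lld\n",
                  static_cast<long long>(counter.readFull()));
      return 0;
    }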
folly/ThreadLocal.h

@@ -36,11 +36,10 @@
 
 #pragma once
 
-#include <boost/iterator/iterator_facade.hpp>
 #include <folly/Likely.h>
 #include <folly/Portability.h>
 #include <folly/ScopeGuard.h>
-#include <folly/SharedMutex.h>
+#include <boost/iterator/iterator_facade.hpp>
 #include <type_traits>
 #include <utility>
 
@@ -250,7 +249,6 @@ class ThreadLocalPtr {
    friend class ThreadLocalPtr<T,Tag>;
 
    threadlocal_detail::StaticMetaBase& meta_;
-   SharedMutex* accessAllThreadsLock_;
    std::mutex* lock_;
    uint32_t id_;
 
@@ -324,11 +322,9 @@ class ThreadLocalPtr {
 
    Accessor(Accessor&& other) noexcept
        : meta_(other.meta_),
-         accessAllThreadsLock_(other.accessAllThreadsLock_),
          lock_(other.lock_),
          id_(other.id_) {
      other.id_ = 0;
-     other.accessAllThreadsLock_ = nullptr;
      other.lock_ = nullptr;
    }
 
@@ -342,23 +338,20 @@ class ThreadLocalPtr {
      assert(&meta_ == &other.meta_);
      assert(lock_ == nullptr);
      using std::swap;
-     swap(accessAllThreadsLock_, other.accessAllThreadsLock_);
      swap(lock_, other.lock_);
      swap(id_, other.id_);
    }
 
    Accessor()
        : meta_(threadlocal_detail::StaticMeta<Tag>::instance()),
-         accessAllThreadsLock_(nullptr),
          lock_(nullptr),
-         id_(0) {}
+         id_(0) {
+   }
 
   private:
    explicit Accessor(uint32_t id)
        : meta_(threadlocal_detail::StaticMeta<Tag>::instance()),
-         accessAllThreadsLock_(&meta_.accessAllThreadsLock_),
          lock_(&meta_.lock_) {
-     accessAllThreadsLock_->lock();
      lock_->lock();
      id_ = id;
    }
@@ -366,11 +359,8 @@ class ThreadLocalPtr {
    void release() {
      if (lock_) {
        lock_->unlock();
-       DCHECK(accessAllThreadsLock_ != nullptr);
-       accessAllThreadsLock_->unlock();
        id_ = 0;
        lock_ = nullptr;
-       accessAllThreadsLock_ = nullptr;
      }
    }
  };
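
For context, a small sketch of how the Accessor above is used (tag and variable names are illustrative): accessAllThreads() hands back a RAII Accessor whose constructor takes the global lock and whose release(), run from the destructor, drops it, so iterating every thread's instance is race-free. folly requires a non-default Tag for this feature.

    #include <folly/ThreadLocal.h>

    #include <cstdio>

    struct CounterTag {}; // accessAllThreads() requires a unique tag

    folly::ThreadLocal<int, CounterTag> perThreadCount;

    int main() {
      *perThreadCount = 7; // this thread's instance

      long total = 0;
      // The Accessor returned here holds the lock for the whole loop.
      for (const int& c : perThreadCount.accessAllThreads()) {
        total += c;
      }
      std::printf("total = %ld\n", total);
      return 0;
    }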
folly/detail/ThreadLocalDetail.cpp

@@ -44,8 +44,6 @@ void StaticMetaBase::onThreadExit(void* ptr) {
     pthread_setspecific(meta.pthreadKey_, nullptr);
   };
 
-  {
-    SharedMutex::ReadHolder rlock(meta.accessAllThreadsLock_);
   {
     std::lock_guard<std::mutex> g(meta.lock_);
     meta.erase(&(*threadEntry));
@@ -64,7 +62,6 @@ void StaticMetaBase::onThreadExit(void* ptr) {
       }
     }
   }
-  }
   free(threadEntry->elements);
   threadEntry->elements = nullptr;
   threadEntry->meta = nullptr;
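
onThreadExit() above is driven by the standard pthread TLS destructor mechanism; here is a minimal standalone sketch of that registration (names hypothetical):

    #include <pthread.h>

    #include <cstdio>

    static pthread_key_t gKey;

    // The pthread runtime invokes this for each dying thread whose value
    // for gKey is non-null -- the same hook StaticMetaBase relies on.
    static void onExit(void* value) {
      std::printf("thread exited with value %p\n", value);
    }

    static void* worker(void*) {
      pthread_setspecific(gKey, reinterpret_cast<void*>(1));
      return nullptr; // onExit fires for this thread after it returns
    }

    int main() {
      pthread_key_create(&gKey, onExit);
      pthread_t t;
      pthread_create(&t, nullptr, worker, nullptr);
      pthread_join(t, nullptr);
      return 0;
    }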
folly/detail/ThreadLocalDetail.h

@@ -296,7 +296,6 @@ struct StaticMetaBase {
   uint32_t nextId_;
   std::vector<uint32_t> freeIds_;
   std::mutex lock_;
-  SharedMutex accessAllThreadsLock_;
   pthread_key_t pthreadKey_;
   ThreadEntry head_;
   ThreadEntry* (*threadEntry_)();
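
The nextId_/freeIds_/lock_ trio above implements slot-id recycling; a hedged sketch of the idea (hypothetical class, not the folly implementation):

    #include <cstdint>
    #include <mutex>
    #include <vector>

    class IdAllocator {
     public:
      // Prefer a recycled id from the free list; otherwise mint a new one.
      uint32_t allocate() {
        std::lock_guard<std::mutex> g(lock_);
        if (!freeIds_.empty()) {
          uint32_t id = freeIds_.back();
          freeIds_.pop_back();
          return id;
        }
        return nextId_++;
      }

      // Called when a slot (e.g. a ThreadLocalPtr) is destroyed.
      void release(uint32_t id) {
        std::lock_guard<std::mutex> g(lock_);
        freeIds_.push_back(id);
      }

     private:
      uint32_t nextId_ = 1;
      std::vector<uint32_t> freeIds_;
      std::mutex lock_;
    };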
folly/test/ThreadCachedIntTest.cpp

@@ -17,7 +17,6 @@
 #include <folly/ThreadCachedInt.h>
 
 #include <atomic>
-#include <condition_variable>
 #include <thread>
 
 #include <glog/logging.h>
@@ -29,139 +28,6 @@
 using namespace folly;
 
-using std::unique_ptr;
-using std::vector;
-
-using Counter = ThreadCachedInt<int64_t>;
-
-class ThreadCachedIntTest : public testing::Test {
- public:
-  uint32_t GetDeadThreadsTotal(const Counter& counter) {
-    return counter.readFast();
-  }
-};
-
-// Multithreaded tests. Creates a specified number of threads each of
-// which iterates a different amount and dies.
-
-namespace {
-// Set cacheSize to be large so cached data moves to target_ only when
-// thread dies.
-Counter g_counter_for_mt_slow(0, UINT32_MAX);
-Counter g_counter_for_mt_fast(0, UINT32_MAX);
-
-// Used to sync between threads. The value of this variable is the
-// maximum iteration index upto which Runner() is allowed to go.
-uint32_t g_sync_for_mt(0);
-std::condition_variable cv;
-std::mutex cv_m;
-
-// Performs the specified number of iterations. Within each
-// iteration, it increments counter 10 times. At the beginning of
-// each iteration it checks g_sync_for_mt to see if it can proceed,
-// otherwise goes into a loop sleeping and rechecking.
-void Runner(Counter* counter, uint32_t iterations) {
-  for (uint32_t i = 0; i < iterations; ++i) {
-    std::unique_lock<std::mutex> lk(cv_m);
-    cv.wait(lk, [i] { return i < g_sync_for_mt; });
-    for (uint32_t j = 0; j < 10; ++j) {
-      counter->increment(1);
-    }
-  }
-}
-
-}
-
-// Slow test with fewer threads where there are more busy waits and
-// many calls to readFull(). This attempts to test as many of the
-// code paths in Counter as possible to ensure that counter values are
-// properly passed from thread local state, both at calls to
-// readFull() and at thread death.
-TEST_F(ThreadCachedIntTest, MultithreadedSlow) {
-  static constexpr uint32_t kNumThreads = 20;
-  g_sync_for_mt = 0;
-  vector<unique_ptr<std::thread>> threads(kNumThreads);
-  // Creates kNumThreads threads. Each thread performs a different
-  // number of iterations in Runner() - threads[0] performs 1
-  // iteration, threads[1] performs 2 iterations, threads[2] performs
-  // 3 iterations, and so on.
-  for (uint32_t i = 0; i < kNumThreads; ++i) {
-    threads[i].reset(new std::thread(Runner, &g_counter_for_mt_slow, i + 1));
-  }
-  // Variable to grab current counter value.
-  int32_t counter_value;
-  // The expected value of the counter.
-  int32_t total = 0;
-  // The expected value of GetDeadThreadsTotal().
-  int32_t dead_total = 0;
-  // Each iteration of the following thread allows one additional
-  // iteration of the threads. Given that the threads perform
-  // different number of iterations from 1 through kNumThreads, one
-  // thread will complete in each of the iterations of the loop below.
-  for (uint32_t i = 0; i < kNumThreads; ++i) {
-    // Allow upto iteration i on all threads.
-    {
-      std::lock_guard<std::mutex> lk(cv_m);
-      g_sync_for_mt = i + 1;
-    }
-    cv.notify_all();
-    total += (kNumThreads - i) * 10;
-    // Loop until the counter reaches its expected value.
-    do {
-      counter_value = g_counter_for_mt_slow.readFull();
-    } while (counter_value < total);
-    // All threads have done what they can until iteration i, now make
-    // sure they don't go further by checking 10 more times in the
-    // following loop.
-    for (uint32_t j = 0; j < 10; ++j) {
-      counter_value = g_counter_for_mt_slow.readFull();
-      EXPECT_EQ(total, counter_value);
-    }
-    dead_total += (i + 1) * 10;
-    EXPECT_GE(dead_total, GetDeadThreadsTotal(g_counter_for_mt_slow));
-  }
-  // All threads are done.
-  for (uint32_t i = 0; i < kNumThreads; ++i) {
-    threads[i]->join();
-  }
-  counter_value = g_counter_for_mt_slow.readFull();
-  EXPECT_EQ(total, counter_value);
-  EXPECT_EQ(total, dead_total);
-  EXPECT_EQ(dead_total, GetDeadThreadsTotal(g_counter_for_mt_slow));
-}
-
-// Fast test with lots of threads and only one call to readFull()
-// at the end.
-TEST_F(ThreadCachedIntTest, MultithreadedFast) {
-  static constexpr uint32_t kNumThreads = 1000;
-  g_sync_for_mt = 0;
-  vector<unique_ptr<std::thread>> threads(kNumThreads);
-  // Creates kNumThreads threads. Each thread performs a different
-  // number of iterations in Runner() - threads[0] performs 1
-  // iteration, threads[1] performs 2 iterations, threads[2] performs
-  // 3 iterations, and so on.
-  for (uint32_t i = 0; i < kNumThreads; ++i) {
-    threads[i].reset(new std::thread(Runner, &g_counter_for_mt_fast, i + 1));
-  }
-  // Let the threads run to completion.
-  {
-    std::lock_guard<std::mutex> lk(cv_m);
-    g_sync_for_mt = kNumThreads;
-  }
-  cv.notify_all();
-  // The expected value of the counter.
-  uint32_t total = 0;
-  for (uint32_t i = 0; i < kNumThreads; ++i) {
-    total += (kNumThreads - i) * 10;
-  }
-  // Wait for all threads to complete.
-  for (uint32_t i = 0; i < kNumThreads; ++i) {
-    threads[i]->join();
-  }
-  int32_t counter_value = g_counter_for_mt_fast.readFull();
-  EXPECT_EQ(total, counter_value);
-  EXPECT_EQ(total, GetDeadThreadsTotal(g_counter_for_mt_fast));
-}
-
 TEST(ThreadCachedInt, SingleThreadedNotCached) {
   ThreadCachedInt<int64_t> val(0, 0);
   EXPECT_EQ(0, val.readFast());
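
The deleted tests gate worker progress with a condition_variable; a standalone sketch of that synchronization pattern (names hypothetical):

    #include <condition_variable>
    #include <cstdint>
    #include <cstdio>
    #include <mutex>
    #include <thread>

    std::mutex gGateMutex;
    std::condition_variable gGateCv;
    uint32_t gGate = 0; // max iteration index workers may reach

    void runner(uint32_t iterations) {
      for (uint32_t i = 0; i < iterations; ++i) {
        std::unique_lock<std::mutex> lk(gGateMutex);
        // Block until the main thread raises the gate past i.
        gGateCv.wait(lk, [i] { return i < gGate; });
        // ... per-iteration work would go here ...
      }
    }

    int main() {
      std::thread t(runner, 3);
      for (uint32_t step = 1; step <= 3; ++step) {
        {
          std::lock_guard<std::mutex> lk(gGateMutex);
          gGate = step;
        }
        gGateCv.notify_all(); // wake workers to re-check the predicate
      }
      t.join();
      std::printf("done\n");
      return 0;
    }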