Commit 5ff51e97 authored by Andrii Grynenko's avatar Andrii Grynenko Committed by facebook-github-bot-1

Fix TLRefCount race around thread local destruction and fix RefCount unit test

Reviewed By: pavlo-fb

Differential Revision: D2708425

fb-gh-sync-id: 665d077210503df4f4e8aa8f88ce5b9b277582f3
parent 623cc983
@@ -15,6 +15,7 @@
  */
 #pragma once
+#include <folly/Baton.h>
 #include <folly/ThreadLocal.h>
 namespace folly {
@@ -26,6 +27,10 @@ class TLRefCount {
   TLRefCount() :
       localCount_([&]() {
         return new LocalRefCount(*this);
+      }),
+      collectGuard_(&collectBaton_, [](void* p) {
+        auto baton = reinterpret_cast<folly::Baton<>*>(p);
+        baton->post();
       }) {
   }
@@ -91,6 +96,9 @@ class TLRefCount {
       count.collect();
     }
+    collectGuard_.reset();
+    collectBaton_.wait();
     state_ = State::GLOBAL;
   }
@@ -106,7 +114,11 @@ class TLRefCount {
   class LocalRefCount {
    public:
     explicit LocalRefCount(TLRefCount& refCount) :
-        refCount_(refCount) {}
+        refCount_(refCount) {
+      std::lock_guard<std::mutex> lg(refCount.globalMutex_);
+      collectGuard_ = refCount.collectGuard_;
+    }
     ~LocalRefCount() {
       collect();
@@ -115,13 +127,13 @@ class TLRefCount {
     void collect() {
       std::lock_guard<std::mutex> lg(collectMutex_);
-      if (collectDone_) {
+      if (!collectGuard_) {
         return;
       }
       collectCount_ = count_;
       refCount_.globalCount_ += collectCount_;
-      collectDone_ = true;
+      collectGuard_.reset();
     }
     bool operator++() {
@@ -143,7 +155,7 @@ class TLRefCount {
       if (UNLIKELY(refCount_.state_.load() != State::LOCAL)) {
         std::lock_guard<std::mutex> lg(collectMutex_);
-        if (!collectDone_) {
+        if (collectGuard_) {
           return true;
         }
         if (collectCount_ != count) {
@@ -159,13 +171,15 @@ class TLRefCount {
     std::mutex collectMutex_;
     Int collectCount_{0};
-    bool collectDone_{false};
+    std::shared_ptr<void> collectGuard_;
   };
   std::atomic<State> state_{State::LOCAL};
   folly::ThreadLocal<LocalRefCount, TLRefCount> localCount_;
   std::atomic<int64_t> globalCount_{1};
   std::mutex globalMutex_;
+  folly::Baton<> collectBaton_;
+  std::shared_ptr<void> collectGuard_;
 };
 }
@@ -35,12 +35,16 @@ void basicTest() {
   folly::Baton<> b;
   std::vector<std::thread> ts;
+  folly::Baton<> threadBatons[numThreads];
   for (size_t t = 0; t < numThreads; ++t) {
-    ts.emplace_back([&count, &b, &got0, numIters, t]() {
+    ts.emplace_back([&count, &b, &got0, numIters, t, &threadBatons]() {
       for (size_t i = 0; i < numIters; ++i) {
         auto ret = ++count;
         EXPECT_TRUE(ret > 1);
+        if (i == 0) {
+          threadBatons[t].post();
+        }
       }
       if (t == 0) {
@@ -58,10 +62,14 @@ void basicTest() {
     });
   }
+  for (size_t t = 0; t < numThreads; ++t) {
+    threadBatons[t].wait();
+  }
   b.wait();
   count.useGlobal();
-  EXPECT_TRUE(--count > 0);
+  --count;
   for (auto& t: ts) {
     t.join();
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment