Commit 5ff51e97 authored by Andrii Grynenko's avatar Andrii Grynenko Committed by facebook-github-bot-1

Fix TLRefCount race around thread local destruction and fix RefCount unit test

Reviewed By: pavlo-fb

Differential Revision: D2708425

fb-gh-sync-id: 665d077210503df4f4e8aa8f88ce5b9b277582f3
parent 623cc983
......@@ -15,6 +15,7 @@
*/
#pragma once
#include <folly/Baton.h>
#include <folly/ThreadLocal.h>
namespace folly {
......@@ -26,6 +27,10 @@ class TLRefCount {
TLRefCount() :
localCount_([&]() {
return new LocalRefCount(*this);
}),
collectGuard_(&collectBaton_, [](void* p) {
auto baton = reinterpret_cast<folly::Baton<>*>(p);
baton->post();
}) {
}
......@@ -91,6 +96,9 @@ class TLRefCount {
count.collect();
}
collectGuard_.reset();
collectBaton_.wait();
state_ = State::GLOBAL;
}
......@@ -106,7 +114,11 @@ class TLRefCount {
class LocalRefCount {
public:
explicit LocalRefCount(TLRefCount& refCount) :
refCount_(refCount) {}
refCount_(refCount) {
std::lock_guard<std::mutex> lg(refCount.globalMutex_);
collectGuard_ = refCount.collectGuard_;
}
~LocalRefCount() {
collect();
......@@ -115,13 +127,13 @@ class TLRefCount {
void collect() {
std::lock_guard<std::mutex> lg(collectMutex_);
if (collectDone_) {
if (!collectGuard_) {
return;
}
collectCount_ = count_;
refCount_.globalCount_ += collectCount_;
collectDone_ = true;
collectGuard_.reset();
}
bool operator++() {
......@@ -143,7 +155,7 @@ class TLRefCount {
if (UNLIKELY(refCount_.state_.load() != State::LOCAL)) {
std::lock_guard<std::mutex> lg(collectMutex_);
if (!collectDone_) {
if (collectGuard_) {
return true;
}
if (collectCount_ != count) {
......@@ -159,13 +171,15 @@ class TLRefCount {
std::mutex collectMutex_;
Int collectCount_{0};
bool collectDone_{false};
std::shared_ptr<void> collectGuard_;
};
std::atomic<State> state_{State::LOCAL};
folly::ThreadLocal<LocalRefCount, TLRefCount> localCount_;
std::atomic<int64_t> globalCount_{1};
std::mutex globalMutex_;
folly::Baton<> collectBaton_;
std::shared_ptr<void> collectGuard_;
};
}
......@@ -35,12 +35,16 @@ void basicTest() {
folly::Baton<> b;
std::vector<std::thread> ts;
folly::Baton<> threadBatons[numThreads];
for (size_t t = 0; t < numThreads; ++t) {
ts.emplace_back([&count, &b, &got0, numIters, t]() {
ts.emplace_back([&count, &b, &got0, numIters, t, &threadBatons]() {
for (size_t i = 0; i < numIters; ++i) {
auto ret = ++count;
EXPECT_TRUE(ret > 1);
if (i == 0) {
threadBatons[t].post();
}
}
if (t == 0) {
......@@ -58,10 +62,14 @@ void basicTest() {
});
}
for (size_t t = 0; t < numThreads; ++t) {
threadBatons[t].wait();
}
b.wait();
count.useGlobal();
EXPECT_TRUE(--count > 0);
--count;
for (auto& t: ts) {
t.join();
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment