Commit 608edf73 authored by Dave Watson's avatar Dave Watson Committed by Facebook Github Bot

Fix fork behavior

Summary:
hazptr_priv_list needs to be reinitialized on fork, as well as ODR violations.

This broke stuff when trying to land D7164130

TL;DR the thread_locals need to be in a .cpp file, the templates don't always get merged correctly, especially in the face of dlopen.

You can, however, use a thread_local * pointer cache to the object.  This is basically what folly::SingletonThreadLocal does.

Folly's ThreadLocal runs destructors when forking even too, so nothing special is required.

Reviewed By: yfeldblum

Differential Revision: D7256905

fbshipit-source-id: c817167b5c3db27fa929feaa39295fd939c1cb4c
parent bb5932ed
......@@ -84,6 +84,7 @@
#define HAZPTR_STATS false
#endif
#include <folly/SingletonThreadLocal.h>
#include <folly/concurrency/CacheLocality.h>
#include <folly/experimental/hazptr/debug.h>
#include <folly/synchronization/AsymmetricMemoryBarrier.h>
......@@ -161,8 +162,8 @@ bool hazptr_tc_enabled();
bool hazptr_priv_enabled();
hazptr_tc* hazptr_tc_tls();
void hazptr_tc_init();
void hazptr_tc_shutdown();
void hazptr_tc_init(hazptr_tc& tc);
void hazptr_tc_shutdown(hazptr_tc& tc);
hazptr_rec* hazptr_tc_try_get();
bool hazptr_tc_try_put(hazptr_rec* hprec);
......@@ -317,8 +318,8 @@ static_assert(
folly::kMscVer || std::is_trivial<hazptr_priv>::value,
"hazptr_priv must be trivial to avoid a branch to check initialization");
void hazptr_priv_init();
void hazptr_priv_shutdown();
void hazptr_priv_init(hazptr_priv& priv);
void hazptr_priv_shutdown(hazptr_priv& priv);
bool hazptr_priv_try_retire(hazptr_obj* obj);
inline void hazptr_priv_list::insert(hazptr_priv* rec) {
......@@ -364,45 +365,30 @@ inline void hazptr_priv_list::collect(hazptr_obj*& head, hazptr_obj*& tail) {
}
}
/** hazptr_tls_life */
struct hazptr_tls_life {
hazptr_tls_life();
~hazptr_tls_life();
};
void tls_life_odr_use();
/** tls globals */
#if HAZPTR_ENABLE_TLS
#define HAZPTR_TLS_EXPANSION thread_local
#else
#define HAZPTR_TLS_EXPANSION
#endif
FOLLY_PUSH_WARNING
#if __clang__
FOLLY_GCC_DISABLE_WARNING("-Wglobal-constructors")
#endif
template <typename>
struct hazptr_tls_globals_ {
static HAZPTR_TLS_EXPANSION hazptr_tls_state tls_state;
static HAZPTR_TLS_EXPANSION hazptr_tc tc;
static HAZPTR_TLS_EXPANSION hazptr_priv priv;
static HAZPTR_TLS_EXPANSION hazptr_tls_life tls_life; // last
hazptr_tls_state tls_state{TLS_UNINITIALIZED};
hazptr_tc tc;
hazptr_priv priv;
hazptr_tls_globals_() {
HAZPTR_DEBUG_PRINT(this);
tls_state = TLS_ALIVE;
hazptr_tc_init(tc);
hazptr_priv_init(priv);
}
~hazptr_tls_globals_() {
HAZPTR_DEBUG_PRINT(this);
CHECK(tls_state == TLS_ALIVE);
hazptr_tc_shutdown(tc);
hazptr_priv_shutdown(priv);
tls_state = TLS_DESTROYED;
}
};
template <typename T>
HAZPTR_TLS_EXPANSION hazptr_tls_state hazptr_tls_globals_<T>::tls_state =
TLS_UNINITIALIZED;
template <typename T>
HAZPTR_TLS_EXPANSION hazptr_tc hazptr_tls_globals_<T>::tc;
template <typename T>
HAZPTR_TLS_EXPANSION hazptr_priv hazptr_tls_globals_<T>::priv;
template <typename T>
HAZPTR_TLS_EXPANSION hazptr_tls_life hazptr_tls_globals_<T>::tls_life; // last
FOLLY_POP_WARNING
#undef HAZPTR_TLS_EXPANSION
using hazptr_tls_globals = hazptr_tls_globals_<void>;
FOLLY_ALWAYS_INLINE hazptr_tls_globals_& hazptr_tls_globals() {
return folly::SingletonThreadLocal<hazptr_tls_globals_, void>::get();
}
/**
* hazptr_domain
......@@ -1195,19 +1181,17 @@ FOLLY_ALWAYS_INLINE size_t hazptr_tc::count() {
/** hazptr_tc free functions */
FOLLY_ALWAYS_INLINE hazptr_tc* hazptr_tc_tls() {
HAZPTR_DEBUG_PRINT(hazptr_tls_globals::tls_state);
if (LIKELY(hazptr_tls_globals::tls_state == TLS_ALIVE)) {
HAZPTR_DEBUG_PRINT(hazptr_tls_globals::tls_state);
return &hazptr_tls_globals::tc;
} else if (hazptr_tls_globals::tls_state == TLS_UNINITIALIZED) {
tls_life_odr_use();
return &hazptr_tls_globals::tc;
HAZPTR_DEBUG_PRINT(hazptr_tls_globals().tls_state);
if (LIKELY(hazptr_tls_globals().tls_state == TLS_ALIVE)) {
HAZPTR_DEBUG_PRINT(hazptr_tls_globals().tls_state);
return &hazptr_tls_globals().tc;
} else if (hazptr_tls_globals().tls_state == TLS_UNINITIALIZED) {
return &hazptr_tls_globals().tc;
}
return nullptr;
}
inline void hazptr_tc_init() {
auto& tc = hazptr_tls_globals::tc;
inline void hazptr_tc_init(hazptr_tc& tc) {
HAZPTR_DEBUG_PRINT(&tc);
tc.count_ = 0;
if (kIsDebug) {
......@@ -1215,8 +1199,7 @@ inline void hazptr_tc_init() {
}
}
inline void hazptr_tc_shutdown() {
auto& tc = hazptr_tls_globals::tc;
inline void hazptr_tc_shutdown(hazptr_tc& tc) {
HAZPTR_DEBUG_PRINT(&tc);
for (size_t i = 0; i < tc.count_; ++i) {
tc.entry_[i].evict();
......@@ -1225,22 +1208,21 @@ inline void hazptr_tc_shutdown() {
FOLLY_ALWAYS_INLINE hazptr_rec* hazptr_tc_try_get() {
HAZPTR_DEBUG_PRINT(TLS_UNINITIALIZED << TLS_ALIVE << TLS_DESTROYED);
HAZPTR_DEBUG_PRINT(hazptr_tls_globals::tls_state);
if (LIKELY(hazptr_tls_globals::tls_state == TLS_ALIVE)) {
HAZPTR_DEBUG_PRINT(hazptr_tls_globals::tls_state);
return hazptr_tls_globals::tc.get();
} else if (hazptr_tls_globals::tls_state == TLS_UNINITIALIZED) {
tls_life_odr_use();
return hazptr_tls_globals::tc.get();
HAZPTR_DEBUG_PRINT(hazptr_tls_globals().tls_state);
if (LIKELY(hazptr_tls_globals().tls_state == TLS_ALIVE)) {
HAZPTR_DEBUG_PRINT(hazptr_tls_globals().tls_state);
return hazptr_tls_globals().tc.get();
} else if (hazptr_tls_globals().tls_state == TLS_UNINITIALIZED) {
return hazptr_tls_globals().tc.get();
}
return nullptr;
}
FOLLY_ALWAYS_INLINE bool hazptr_tc_try_put(hazptr_rec* hprec) {
HAZPTR_DEBUG_PRINT(hazptr_tls_globals::tls_state);
if (LIKELY(hazptr_tls_globals::tls_state == TLS_ALIVE)) {
HAZPTR_DEBUG_PRINT(hazptr_tls_globals::tls_state);
return hazptr_tls_globals::tc.put(hprec);
HAZPTR_DEBUG_PRINT(hazptr_tls_globals().tls_state);
if (LIKELY(hazptr_tls_globals().tls_state == TLS_ALIVE)) {
HAZPTR_DEBUG_PRINT(hazptr_tls_globals().tls_state);
return hazptr_tls_globals().tc.put(hprec);
}
return false;
}
......@@ -1249,14 +1231,12 @@ FOLLY_ALWAYS_INLINE bool hazptr_tc_try_put(hazptr_rec* hprec) {
* hazptr_priv
*/
inline void hazptr_priv_init() {
auto& priv = hazptr_tls_globals::priv;
inline void hazptr_priv_init(hazptr_priv& priv) {
HAZPTR_DEBUG_PRINT(&priv);
priv.init();
}
inline void hazptr_priv_shutdown() {
auto& priv = hazptr_tls_globals::priv;
inline void hazptr_priv_shutdown(hazptr_priv& priv) {
HAZPTR_DEBUG_PRINT(&priv);
DCHECK(priv.active());
priv.clear_active();
......@@ -1267,46 +1247,19 @@ inline void hazptr_priv_shutdown() {
}
inline bool hazptr_priv_try_retire(hazptr_obj* obj) {
HAZPTR_DEBUG_PRINT(hazptr_tls_globals::tls_state);
if (hazptr_tls_globals::tls_state == TLS_ALIVE) {
HAZPTR_DEBUG_PRINT(hazptr_tls_globals::tls_state);
hazptr_tls_globals::priv.push(obj);
HAZPTR_DEBUG_PRINT(hazptr_tls_globals().tls_state);
if (hazptr_tls_globals().tls_state == TLS_ALIVE) {
HAZPTR_DEBUG_PRINT(hazptr_tls_globals().tls_state);
hazptr_tls_globals().priv.push(obj);
return true;
} else if (hazptr_tls_globals::tls_state == TLS_UNINITIALIZED) {
HAZPTR_DEBUG_PRINT(hazptr_tls_globals::tls_state);
tls_life_odr_use();
hazptr_tls_globals::priv.push(obj);
} else if (hazptr_tls_globals().tls_state == TLS_UNINITIALIZED) {
HAZPTR_DEBUG_PRINT(hazptr_tls_globals().tls_state);
hazptr_tls_globals().priv.push(obj);
return true;
}
return false;
}
/** hazptr_tls_life */
inline void tls_life_odr_use() {
HAZPTR_DEBUG_PRINT(hazptr_tls_globals::tls_state);
CHECK(hazptr_tls_globals::tls_state == TLS_UNINITIALIZED);
auto volatile tlsOdrUse = &hazptr_tls_globals::tls_life;
CHECK(tlsOdrUse != nullptr);
HAZPTR_DEBUG_PRINT(tlsOdrUse);
}
inline hazptr_tls_life::hazptr_tls_life() {
HAZPTR_DEBUG_PRINT(this);
CHECK(hazptr_tls_globals::tls_state == TLS_UNINITIALIZED);
hazptr_tc_init();
hazptr_priv_init();
hazptr_tls_globals::tls_state = TLS_ALIVE;
}
inline hazptr_tls_life::~hazptr_tls_life() {
HAZPTR_DEBUG_PRINT(this);
CHECK(hazptr_tls_globals::tls_state == TLS_ALIVE);
hazptr_tc_shutdown();
hazptr_priv_shutdown();
hazptr_tls_globals::tls_state = TLS_DESTROYED;
}
/** hazptr_obj_batch */
/* Only for default domain. Supports only hazptr_obj_base_refcounted
* and a thread-safe access only, for now. */
......
......@@ -29,6 +29,8 @@
#include <folly/portability/GFlags.h>
#include <folly/portability/GTest.h>
#include <condition_variable>
#include <thread>
DEFINE_int32(num_threads, 5, "Number of threads");
......@@ -674,3 +676,60 @@ TEST_F(HazptrTest, FreeFunctionCleanup) {
CHECK_EQ(destroyed.load(), 2);
}
}
TEST_F(HazptrTest, ForkTest) {
struct Foo : hazptr_obj_base<Foo> {
int a;
};
std::mutex m;
std::condition_variable cv;
std::condition_variable cv2;
bool ready = false;
bool ready2 = false;
auto mkthread = [&]() {
hazptr_holder h;
auto p = new Foo;
std::atomic<Foo*> ap{p};
h.get_protected<Foo>(p);
p->retire();
{
std::unique_lock<std::mutex> lk(m);
ready = true;
cv.notify_one();
cv2.wait(lk, [&] { return ready2; });
}
};
std::thread t(mkthread);
hazptr_holder h;
auto p = new Foo;
std::atomic<Foo*> ap{p};
h.get_protected<Foo>(p);
p->retire();
{
std::unique_lock<std::mutex> lk(m);
cv.wait(lk, [&] { return ready; });
}
auto pid = fork();
CHECK_GE(pid, 0);
if (pid) {
{
std::lock_guard<std::mutex> g(m);
ready2 = true;
cv2.notify_one();
}
t.join();
int status;
wait(&status);
CHECK_EQ(status, 0);
} else {
// child
std::thread tchild(mkthread);
{
std::lock_guard<std::mutex> g(m);
ready2 = true;
cv2.notify_one();
}
tchild.join();
_exit(0); // Do not print gtest results
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment