Commit 88ae9ac7 authored by Tudor Bosman's avatar Tudor Bosman Committed by Sara Golemon

Make ThreadLocalPtr behave sanely around fork()

Summary:
Threads and fork still don't mix, but we shouldn't help you shoot yourself in
the foot if you decide to do it.

Test Plan: test added

Reviewed By: mshneer@fb.com

FB internal diff: D911224
parent 206a0372
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <glog/logging.h> #include <glog/logging.h>
#include "folly/Exception.h"
#include "folly/Foreach.h" #include "folly/Foreach.h"
#include "folly/Malloc.h" #include "folly/Malloc.h"
...@@ -146,9 +147,9 @@ struct StaticMeta { ...@@ -146,9 +147,9 @@ struct StaticMeta {
static StaticMeta<Tag>& instance() { static StaticMeta<Tag>& instance() {
// Leak it on exit, there's only one per process and we don't have to // Leak it on exit, there's only one per process and we don't have to
// worry about synchronization with exiting threads. // worry about synchronization with exiting threads.
static bool constructed = (inst = new StaticMeta<Tag>()); static bool constructed = (inst_ = new StaticMeta<Tag>());
(void)constructed; // suppress unused warning (void)constructed; // suppress unused warning
return *inst; return *inst_;
} }
int nextId_; int nextId_;
...@@ -171,33 +172,36 @@ struct StaticMeta { ...@@ -171,33 +172,36 @@ struct StaticMeta {
} }
static __thread ThreadEntry threadEntry_; static __thread ThreadEntry threadEntry_;
static StaticMeta<Tag>* inst; static StaticMeta<Tag>* inst_;
StaticMeta() : nextId_(1) { StaticMeta() : nextId_(1) {
head_.next = head_.prev = &head_; head_.next = head_.prev = &head_;
int ret = pthread_key_create(&pthreadKey_, &onThreadExit); int ret = pthread_key_create(&pthreadKey_, &onThreadExit);
if (ret != 0) { checkPosixError(ret, "pthread_key_create failed");
std::string msg;
switch (ret) { ret = pthread_atfork(/*prepare*/ &StaticMeta::preFork,
case EAGAIN: /*parent*/ &StaticMeta::onForkParent,
char buf[100]; /*child*/ &StaticMeta::onForkChild);
snprintf(buf, sizeof(buf), "PTHREAD_KEYS_MAX (%d) is exceeded", checkPosixError(ret, "pthread_atfork failed");
PTHREAD_KEYS_MAX);
msg = buf;
break;
case ENOMEM:
msg = "Out-of-memory";
break;
default:
msg = "(unknown error)";
}
throw std::runtime_error("pthread_key_create failed: " + msg);
}
} }
~StaticMeta() { ~StaticMeta() {
LOG(FATAL) << "StaticMeta lives forever!"; LOG(FATAL) << "StaticMeta lives forever!";
} }
static void preFork(void) {
instance().lock_.lock(); // Make sure it's created
}
static void onForkParent(void) {
inst_->lock_.unlock();
}
static void onForkChild(void) {
inst_->head_.next = inst_->head_.prev = &inst_->head_;
inst_->push_back(&threadEntry_); // only the current thread survives
inst_->lock_.unlock();
}
static void onThreadExit(void* ptr) { static void onThreadExit(void* ptr) {
auto & meta = instance(); auto & meta = instance();
DCHECK_EQ(ptr, &meta); DCHECK_EQ(ptr, &meta);
...@@ -328,7 +332,7 @@ struct StaticMeta { ...@@ -328,7 +332,7 @@ struct StaticMeta {
}; };
template <class Tag> __thread ThreadEntry StaticMeta<Tag>::threadEntry_ = {0}; template <class Tag> __thread ThreadEntry StaticMeta<Tag>::threadEntry_ = {0};
template <class Tag> StaticMeta<Tag>* StaticMeta<Tag>::inst = nullptr; template <class Tag> StaticMeta<Tag>* StaticMeta<Tag>::inst_ = nullptr;
} // namespace threadlocal_detail } // namespace threadlocal_detail
} // namespace folly } // namespace folly
......
...@@ -16,6 +16,8 @@ ...@@ -16,6 +16,8 @@
#include "folly/ThreadLocal.h" #include "folly/ThreadLocal.h"
#include <sys/types.h>
#include <sys/wait.h>
#include <map> #include <map>
#include <unordered_map> #include <unordered_map>
#include <set> #include <set>
...@@ -23,6 +25,7 @@ ...@@ -23,6 +25,7 @@
#include <mutex> #include <mutex>
#include <condition_variable> #include <condition_variable>
#include <thread> #include <thread>
#include <unistd.h>
#include <boost/thread/tss.hpp> #include <boost/thread/tss.hpp>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <gflags/gflags.h> #include <gflags/gflags.h>
...@@ -295,6 +298,103 @@ TEST(ThreadLocal, Movable2) { ...@@ -295,6 +298,103 @@ TEST(ThreadLocal, Movable2) {
EXPECT_EQ(4, tls.size()); EXPECT_EQ(4, tls.size());
} }
// Yes, threads and fork don't mix
// (http://cppwisdom.quora.com/Why-threads-and-fork-dont-mix) but if you're
// stupid or desperate enough to try, we shouldn't stand in your way.
namespace {
class HoldsOne {
public:
HoldsOne() : value_(1) { }
// Do an actual access to catch the buggy case where this == nullptr
int value() const { return value_; }
private:
int value_;
};
struct HoldsOneTag {};
ThreadLocal<HoldsOne, HoldsOneTag> ptr;
int totalValue() {
int value = 0;
for (auto& p : ptr.accessAllThreads()) {
value += p.value();
}
return value;
}
} // namespace
TEST(ThreadLocal, Fork) {
EXPECT_EQ(1, ptr->value()); // ensure created
EXPECT_EQ(1, totalValue());
// Spawn a new thread
std::mutex mutex;
bool started = false;
std::condition_variable startedCond;
bool stopped = false;
std::condition_variable stoppedCond;
std::thread t([&] () {
EXPECT_EQ(1, ptr->value()); // ensure created
{
std::unique_lock<std::mutex> lock(mutex);
started = true;
startedCond.notify_all();
}
{
std::unique_lock<std::mutex> lock(mutex);
while (!stopped) {
stoppedCond.wait(lock);
}
}
});
{
std::unique_lock<std::mutex> lock(mutex);
while (!started) {
startedCond.wait(lock);
}
}
EXPECT_EQ(2, totalValue());
pid_t pid = fork();
if (pid == 0) {
// in child
int v = totalValue();
// exit successfully if v == 1 (one thread)
// diagnostic error code otherwise :)
switch (v) {
case 1: _exit(0);
case 0: _exit(1);
}
_exit(2);
} else if (pid > 0) {
// in parent
int status;
EXPECT_EQ(pid, waitpid(pid, &status, 0));
EXPECT_TRUE(WIFEXITED(status));
EXPECT_EQ(0, WEXITSTATUS(status));
} else {
EXPECT_TRUE(false) << "fork failed";
}
EXPECT_EQ(2, totalValue());
{
std::unique_lock<std::mutex> lock(mutex);
stopped = true;
stoppedCond.notify_all();
}
t.join();
EXPECT_EQ(1, totalValue());
}
// Simple reference implementation using pthread_get_specific // Simple reference implementation using pthread_get_specific
template<typename T> template<typename T>
class PThreadGetSpecific { class PThreadGetSpecific {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment