Commit 2642bd3d authored by Andrii Grynenko's avatar Andrii Grynenko Committed by Facebook Github Bot

Fix a race in Observable context destruction

Summary: In the subscribe callback It's possible that we lock the Context shared_ptr and while update is running, all other shared_ptr's are released. This will result in Context to be destroyed from the wrong thread (thread runnning subcribe callback), which is not desired.

Reviewed By: yfeldblum

Differential Revision: D4964605

fbshipit-source-id: 285327a6873ccb7393fa3067ba7e612c29dbc454
parent 79cb7776
......@@ -22,7 +22,9 @@ template <typename Observable, typename Traits>
class ObserverCreator<Observable, Traits>::Context {
public:
template <typename... Args>
Context(Args&&... args) : observable_(std::forward<Args>(args)...) {}
Context(Args&&... args) : observable_(std::forward<Args>(args)...) {
updateValue();
}
~Context() {
if (value_.copy()) {
......@@ -47,21 +49,11 @@ class ObserverCreator<Observable, Traits>::Context {
// callbacks (getting new value from observable and storing it into value_
// is not atomic).
std::lock_guard<std::mutex> lg(updateMutex_);
{
auto newValue = Traits::get(observable_);
if (!newValue) {
throw std::logic_error("Observable returned nullptr.");
}
value_.swap(newValue);
}
updateValue();
bool expected = false;
if (updateRequested_.compare_exchange_strong(expected, true)) {
if (auto core = coreWeak_.lock()) {
observer_detail::ObserverManager::scheduleRefreshNewVersion(
std::move(core));
}
observer_detail::ObserverManager::scheduleRefreshNewVersion(coreWeak_);
}
}
......@@ -71,6 +63,14 @@ class ObserverCreator<Observable, Traits>::Context {
}
private:
void updateValue() {
auto newValue = Traits::get(observable_);
if (!newValue) {
throw std::logic_error("Observable returned nullptr.");
}
value_.swap(newValue);
}
folly::Synchronized<std::shared_ptr<const T>> value_;
std::atomic<bool> updateRequested_{false};
......@@ -89,24 +89,68 @@ ObserverCreator<Observable, Traits>::ObserverCreator(Args&&... args)
template <typename Observable, typename Traits>
Observer<typename ObserverCreator<Observable, Traits>::T>
ObserverCreator<Observable, Traits>::getObserver()&& {
auto core = observer_detail::Core::create([context = context_]() {
// This master shared_ptr allows grabbing derived weak_ptrs, pointing to the
// the same Context object, but using a separate reference count. Master
// shared_ptr destructor then blocks until all shared_ptrs obtained from
// derived weak_ptrs are released.
class ContextMasterPointer {
public:
explicit ContextMasterPointer(std::shared_ptr<Context> context)
: contextMaster_(std::move(context)),
context_(
contextMaster_.get(),
[destroyBaton = destroyBaton_](Context*) {
destroyBaton->post();
}) {}
~ContextMasterPointer() {
if (context_) {
context_.reset();
destroyBaton_->wait();
}
}
ContextMasterPointer(const ContextMasterPointer&) = delete;
ContextMasterPointer(ContextMasterPointer&&) = default;
ContextMasterPointer& operator=(const ContextMasterPointer&) = delete;
ContextMasterPointer& operator=(ContextMasterPointer&&) = default;
Context* operator->() const {
return contextMaster_.get();
}
std::weak_ptr<Context> get_weak() {
return context_;
}
private:
std::shared_ptr<folly::Baton<>> destroyBaton_{
std::make_shared<folly::Baton<>>()};
std::shared_ptr<Context> contextMaster_;
std::shared_ptr<Context> context_;
};
// We want to make sure that Context can only be destroyed when Core is
// destroyed. So we have to avoid the situation when subscribe callback is
// locking Context shared_ptr and remains the last to release it.
// We solve this by having Core hold the master shared_ptr and subscription
// callback gets derived weak_ptr.
ContextMasterPointer contextMaster(context_);
auto contextWeak = contextMaster.get_weak();
auto observer = makeObserver([context = std::move(contextMaster)]() {
return context->get();
});
context_->setCore(core);
context_->subscribe([contextWeak = std::weak_ptr<Context>(context_)] {
context_->setCore(observer.core_);
context_->subscribe([contextWeak = std::move(contextWeak)] {
if (auto context = contextWeak.lock()) {
context->update();
}
});
// Do an extra update in case observable was updated between observer creation
// and setting updates callback.
context_->update();
context_.reset();
DCHECK(core->getVersion() > 0);
return Observer<T>(std::move(core));
return observer;
}
}
}
......@@ -38,10 +38,10 @@ Observer<observer_detail::ResultOfUnwrapSharedPtr<F>> makeObserver(
F&& creator) {
auto core = observer_detail::Core::
create([creator = std::forward<F>(creator)]() mutable {
return std::static_pointer_cast<void>(creator());
return std::static_pointer_cast<const void>(creator());
});
observer_detail::ObserverManager::scheduleRefreshNewVersion(core);
observer_detail::ObserverManager::initCore(core);
return Observer<observer_detail::ResultOfUnwrapSharedPtr<F>>(core);
}
......
......@@ -134,6 +134,9 @@ class Observer {
}
private:
template <typename Observable, typename Traits>
friend class ObserverCreator;
observer_detail::Core::Ptr core_;
};
......
......@@ -106,28 +106,35 @@ class ObserverManager::NextQueue {
explicit NextQueue(ObserverManager& manager)
: manager_(manager), queue_(kNextQueueSize) {
thread_ = std::thread([&]() {
Core::Ptr queueCore;
Core::WeakPtr queueCoreWeak;
while (true) {
queue_.blockingRead(queueCore);
if (!queueCore) {
queue_.blockingRead(queueCoreWeak);
if (stop_) {
return;
}
std::vector<Core::Ptr> cores;
cores.emplace_back(std::move(queueCore));
{
auto queueCore = queueCoreWeak.lock();
if (!queueCore) {
continue;
}
cores.emplace_back(std::move(queueCore));
}
{
SharedMutexReadPriority::WriteHolder wh(manager_.versionMutex_);
// We can't pick more tasks from the queue after we bumped the
// version, so we have to do this while holding the lock.
while (cores.size() < kNextQueueSize && queue_.read(queueCore)) {
if (!queueCore) {
while (cores.size() < kNextQueueSize && queue_.read(queueCoreWeak)) {
if (stop_) {
return;
}
cores.emplace_back(std::move(queueCore));
if (auto queueCore = queueCoreWeak.lock()) {
cores.emplace_back(std::move(queueCore));
}
}
++manager_.version_;
......@@ -140,20 +147,22 @@ class ObserverManager::NextQueue {
});
}
void add(Core::Ptr core) {
void add(Core::WeakPtr core) {
queue_.blockingWrite(std::move(core));
}
~NextQueue() {
// Emtpy element signals thread to terminate
queue_.blockingWrite(nullptr);
stop_ = true;
// Write to the queue to notify the thread.
queue_.blockingWrite(Core::WeakPtr());
thread_.join();
}
private:
ObserverManager& manager_;
MPMCQueue<Core::Ptr> queue_;
MPMCQueue<Core::WeakPtr> queue_;
std::thread thread_;
std::atomic<bool> stop_{false};
};
ObserverManager::ObserverManager() {
......@@ -172,7 +181,7 @@ void ObserverManager::scheduleCurrent(Function<void()> task) {
currentQueue_->add(std::move(task));
}
void ObserverManager::scheduleNext(Core::Ptr core) {
void ObserverManager::scheduleNext(Core::WeakPtr core) {
nextQueue_->add(std::move(core));
}
......
......@@ -93,19 +93,19 @@ class ObserverManager {
return future;
}
static void scheduleRefreshNewVersion(Core::Ptr core) {
if (core->getVersion() == 0) {
scheduleRefresh(std::move(core), 1).get();
return;
}
static void scheduleRefreshNewVersion(Core::WeakPtr coreWeak) {
auto instance = getInstance();
if (!instance) {
return;
}
instance->scheduleNext(std::move(core));
instance->scheduleNext(std::move(coreWeak));
}
static void initCore(Core::Ptr core) {
DCHECK(core->getVersion() == 0);
scheduleRefresh(std::move(core), 1).get();
}
class DependencyRecorder {
......@@ -189,7 +189,7 @@ class ObserverManager {
struct Singleton;
void scheduleCurrent(Function<void()>);
void scheduleNext(Core::Ptr);
void scheduleNext(Core::WeakPtr);
class CurrentQueue;
class NextQueue;
......
......@@ -262,3 +262,62 @@ TEST(Observer, TLObserver) {
k = std::make_unique<folly::observer::TLObserver<int>>(createTLObserver(41));
EXPECT_EQ(41, ***k);
}
TEST(Observer, SubscribeCallback) {
static auto mainThreadId = std::this_thread::get_id();
static std::function<void()> updatesCob;
static bool slowGet = false;
static std::atomic<size_t> getCallsStart{0};
static std::atomic<size_t> getCallsFinish{0};
struct Observable {
~Observable() {
EXPECT_EQ(mainThreadId, std::this_thread::get_id());
}
};
struct Traits {
using element_type = int;
static std::shared_ptr<const int> get(Observable&) {
++getCallsStart;
if (slowGet) {
/* sleep override */ std::this_thread::sleep_for(
std::chrono::seconds{2});
}
++getCallsFinish;
return std::make_shared<const int>(42);
}
static void subscribe(Observable&, std::function<void()> cob) {
updatesCob = std::move(cob);
}
static void unsubscribe(Observable&) {}
};
std::thread cobThread;
{
auto observer =
folly::observer::ObserverCreator<Observable, Traits>().getObserver();
EXPECT_TRUE(updatesCob);
EXPECT_EQ(2, getCallsStart);
EXPECT_EQ(2, getCallsFinish);
updatesCob();
EXPECT_EQ(3, getCallsStart);
EXPECT_EQ(3, getCallsFinish);
slowGet = true;
cobThread = std::thread([] { updatesCob(); });
/* sleep override */ std::this_thread::sleep_for(std::chrono::seconds{1});
EXPECT_EQ(4, getCallsStart);
EXPECT_EQ(3, getCallsFinish);
// Observer is destroyed here
}
// Make sure that destroying the observer actually joined the updates callback
EXPECT_EQ(4, getCallsStart);
EXPECT_EQ(4, getCallsFinish);
cobThread.join();
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment