Commit 2642bd3d authored by Andrii Grynenko's avatar Andrii Grynenko Committed by Facebook Github Bot

Fix a race in Observable context destruction

Summary: In the subscribe callback It's possible that we lock the Context shared_ptr and while update is running, all other shared_ptr's are released. This will result in Context to be destroyed from the wrong thread (thread runnning subcribe callback), which is not desired.

Reviewed By: yfeldblum

Differential Revision: D4964605

fbshipit-source-id: 285327a6873ccb7393fa3067ba7e612c29dbc454
parent 79cb7776
...@@ -22,7 +22,9 @@ template <typename Observable, typename Traits> ...@@ -22,7 +22,9 @@ template <typename Observable, typename Traits>
class ObserverCreator<Observable, Traits>::Context { class ObserverCreator<Observable, Traits>::Context {
public: public:
template <typename... Args> template <typename... Args>
Context(Args&&... args) : observable_(std::forward<Args>(args)...) {} Context(Args&&... args) : observable_(std::forward<Args>(args)...) {
updateValue();
}
~Context() { ~Context() {
if (value_.copy()) { if (value_.copy()) {
...@@ -47,21 +49,11 @@ class ObserverCreator<Observable, Traits>::Context { ...@@ -47,21 +49,11 @@ class ObserverCreator<Observable, Traits>::Context {
// callbacks (getting new value from observable and storing it into value_ // callbacks (getting new value from observable and storing it into value_
// is not atomic). // is not atomic).
std::lock_guard<std::mutex> lg(updateMutex_); std::lock_guard<std::mutex> lg(updateMutex_);
updateValue();
{
auto newValue = Traits::get(observable_);
if (!newValue) {
throw std::logic_error("Observable returned nullptr.");
}
value_.swap(newValue);
}
bool expected = false; bool expected = false;
if (updateRequested_.compare_exchange_strong(expected, true)) { if (updateRequested_.compare_exchange_strong(expected, true)) {
if (auto core = coreWeak_.lock()) { observer_detail::ObserverManager::scheduleRefreshNewVersion(coreWeak_);
observer_detail::ObserverManager::scheduleRefreshNewVersion(
std::move(core));
}
} }
} }
...@@ -71,6 +63,14 @@ class ObserverCreator<Observable, Traits>::Context { ...@@ -71,6 +63,14 @@ class ObserverCreator<Observable, Traits>::Context {
} }
private: private:
void updateValue() {
auto newValue = Traits::get(observable_);
if (!newValue) {
throw std::logic_error("Observable returned nullptr.");
}
value_.swap(newValue);
}
folly::Synchronized<std::shared_ptr<const T>> value_; folly::Synchronized<std::shared_ptr<const T>> value_;
std::atomic<bool> updateRequested_{false}; std::atomic<bool> updateRequested_{false};
...@@ -89,24 +89,68 @@ ObserverCreator<Observable, Traits>::ObserverCreator(Args&&... args) ...@@ -89,24 +89,68 @@ ObserverCreator<Observable, Traits>::ObserverCreator(Args&&... args)
template <typename Observable, typename Traits> template <typename Observable, typename Traits>
Observer<typename ObserverCreator<Observable, Traits>::T> Observer<typename ObserverCreator<Observable, Traits>::T>
ObserverCreator<Observable, Traits>::getObserver()&& { ObserverCreator<Observable, Traits>::getObserver()&& {
auto core = observer_detail::Core::create([context = context_]() { // This master shared_ptr allows grabbing derived weak_ptrs, pointing to the
// the same Context object, but using a separate reference count. Master
// shared_ptr destructor then blocks until all shared_ptrs obtained from
// derived weak_ptrs are released.
class ContextMasterPointer {
public:
explicit ContextMasterPointer(std::shared_ptr<Context> context)
: contextMaster_(std::move(context)),
context_(
contextMaster_.get(),
[destroyBaton = destroyBaton_](Context*) {
destroyBaton->post();
}) {}
~ContextMasterPointer() {
if (context_) {
context_.reset();
destroyBaton_->wait();
}
}
ContextMasterPointer(const ContextMasterPointer&) = delete;
ContextMasterPointer(ContextMasterPointer&&) = default;
ContextMasterPointer& operator=(const ContextMasterPointer&) = delete;
ContextMasterPointer& operator=(ContextMasterPointer&&) = default;
Context* operator->() const {
return contextMaster_.get();
}
std::weak_ptr<Context> get_weak() {
return context_;
}
private:
std::shared_ptr<folly::Baton<>> destroyBaton_{
std::make_shared<folly::Baton<>>()};
std::shared_ptr<Context> contextMaster_;
std::shared_ptr<Context> context_;
};
// We want to make sure that Context can only be destroyed when Core is
// destroyed. So we have to avoid the situation when subscribe callback is
// locking Context shared_ptr and remains the last to release it.
// We solve this by having Core hold the master shared_ptr and subscription
// callback gets derived weak_ptr.
ContextMasterPointer contextMaster(context_);
auto contextWeak = contextMaster.get_weak();
auto observer = makeObserver([context = std::move(contextMaster)]() {
return context->get(); return context->get();
}); });
context_->setCore(core); context_->setCore(observer.core_);
context_->subscribe([contextWeak = std::move(contextWeak)] {
context_->subscribe([contextWeak = std::weak_ptr<Context>(context_)] {
if (auto context = contextWeak.lock()) { if (auto context = contextWeak.lock()) {
context->update(); context->update();
} }
}); });
// Do an extra update in case observable was updated between observer creation
// and setting updates callback.
context_->update(); context_->update();
context_.reset(); context_.reset();
DCHECK(core->getVersion() > 0); return observer;
return Observer<T>(std::move(core));
} }
} }
} }
...@@ -38,10 +38,10 @@ Observer<observer_detail::ResultOfUnwrapSharedPtr<F>> makeObserver( ...@@ -38,10 +38,10 @@ Observer<observer_detail::ResultOfUnwrapSharedPtr<F>> makeObserver(
F&& creator) { F&& creator) {
auto core = observer_detail::Core:: auto core = observer_detail::Core::
create([creator = std::forward<F>(creator)]() mutable { create([creator = std::forward<F>(creator)]() mutable {
return std::static_pointer_cast<void>(creator()); return std::static_pointer_cast<const void>(creator());
}); });
observer_detail::ObserverManager::scheduleRefreshNewVersion(core); observer_detail::ObserverManager::initCore(core);
return Observer<observer_detail::ResultOfUnwrapSharedPtr<F>>(core); return Observer<observer_detail::ResultOfUnwrapSharedPtr<F>>(core);
} }
......
...@@ -134,6 +134,9 @@ class Observer { ...@@ -134,6 +134,9 @@ class Observer {
} }
private: private:
template <typename Observable, typename Traits>
friend class ObserverCreator;
observer_detail::Core::Ptr core_; observer_detail::Core::Ptr core_;
}; };
......
...@@ -106,28 +106,35 @@ class ObserverManager::NextQueue { ...@@ -106,28 +106,35 @@ class ObserverManager::NextQueue {
explicit NextQueue(ObserverManager& manager) explicit NextQueue(ObserverManager& manager)
: manager_(manager), queue_(kNextQueueSize) { : manager_(manager), queue_(kNextQueueSize) {
thread_ = std::thread([&]() { thread_ = std::thread([&]() {
Core::Ptr queueCore; Core::WeakPtr queueCoreWeak;
while (true) { while (true) {
queue_.blockingRead(queueCore); queue_.blockingRead(queueCoreWeak);
if (stop_) {
if (!queueCore) {
return; return;
} }
std::vector<Core::Ptr> cores; std::vector<Core::Ptr> cores;
cores.emplace_back(std::move(queueCore)); {
auto queueCore = queueCoreWeak.lock();
if (!queueCore) {
continue;
}
cores.emplace_back(std::move(queueCore));
}
{ {
SharedMutexReadPriority::WriteHolder wh(manager_.versionMutex_); SharedMutexReadPriority::WriteHolder wh(manager_.versionMutex_);
// We can't pick more tasks from the queue after we bumped the // We can't pick more tasks from the queue after we bumped the
// version, so we have to do this while holding the lock. // version, so we have to do this while holding the lock.
while (cores.size() < kNextQueueSize && queue_.read(queueCore)) { while (cores.size() < kNextQueueSize && queue_.read(queueCoreWeak)) {
if (!queueCore) { if (stop_) {
return; return;
} }
cores.emplace_back(std::move(queueCore)); if (auto queueCore = queueCoreWeak.lock()) {
cores.emplace_back(std::move(queueCore));
}
} }
++manager_.version_; ++manager_.version_;
...@@ -140,20 +147,22 @@ class ObserverManager::NextQueue { ...@@ -140,20 +147,22 @@ class ObserverManager::NextQueue {
}); });
} }
void add(Core::Ptr core) { void add(Core::WeakPtr core) {
queue_.blockingWrite(std::move(core)); queue_.blockingWrite(std::move(core));
} }
~NextQueue() { ~NextQueue() {
// Emtpy element signals thread to terminate stop_ = true;
queue_.blockingWrite(nullptr); // Write to the queue to notify the thread.
queue_.blockingWrite(Core::WeakPtr());
thread_.join(); thread_.join();
} }
private: private:
ObserverManager& manager_; ObserverManager& manager_;
MPMCQueue<Core::Ptr> queue_; MPMCQueue<Core::WeakPtr> queue_;
std::thread thread_; std::thread thread_;
std::atomic<bool> stop_{false};
}; };
ObserverManager::ObserverManager() { ObserverManager::ObserverManager() {
...@@ -172,7 +181,7 @@ void ObserverManager::scheduleCurrent(Function<void()> task) { ...@@ -172,7 +181,7 @@ void ObserverManager::scheduleCurrent(Function<void()> task) {
currentQueue_->add(std::move(task)); currentQueue_->add(std::move(task));
} }
void ObserverManager::scheduleNext(Core::Ptr core) { void ObserverManager::scheduleNext(Core::WeakPtr core) {
nextQueue_->add(std::move(core)); nextQueue_->add(std::move(core));
} }
......
...@@ -93,19 +93,19 @@ class ObserverManager { ...@@ -93,19 +93,19 @@ class ObserverManager {
return future; return future;
} }
static void scheduleRefreshNewVersion(Core::Ptr core) { static void scheduleRefreshNewVersion(Core::WeakPtr coreWeak) {
if (core->getVersion() == 0) {
scheduleRefresh(std::move(core), 1).get();
return;
}
auto instance = getInstance(); auto instance = getInstance();
if (!instance) { if (!instance) {
return; return;
} }
instance->scheduleNext(std::move(core)); instance->scheduleNext(std::move(coreWeak));
}
static void initCore(Core::Ptr core) {
DCHECK(core->getVersion() == 0);
scheduleRefresh(std::move(core), 1).get();
} }
class DependencyRecorder { class DependencyRecorder {
...@@ -189,7 +189,7 @@ class ObserverManager { ...@@ -189,7 +189,7 @@ class ObserverManager {
struct Singleton; struct Singleton;
void scheduleCurrent(Function<void()>); void scheduleCurrent(Function<void()>);
void scheduleNext(Core::Ptr); void scheduleNext(Core::WeakPtr);
class CurrentQueue; class CurrentQueue;
class NextQueue; class NextQueue;
......
...@@ -262,3 +262,62 @@ TEST(Observer, TLObserver) { ...@@ -262,3 +262,62 @@ TEST(Observer, TLObserver) {
k = std::make_unique<folly::observer::TLObserver<int>>(createTLObserver(41)); k = std::make_unique<folly::observer::TLObserver<int>>(createTLObserver(41));
EXPECT_EQ(41, ***k); EXPECT_EQ(41, ***k);
} }
TEST(Observer, SubscribeCallback) {
static auto mainThreadId = std::this_thread::get_id();
static std::function<void()> updatesCob;
static bool slowGet = false;
static std::atomic<size_t> getCallsStart{0};
static std::atomic<size_t> getCallsFinish{0};
struct Observable {
~Observable() {
EXPECT_EQ(mainThreadId, std::this_thread::get_id());
}
};
struct Traits {
using element_type = int;
static std::shared_ptr<const int> get(Observable&) {
++getCallsStart;
if (slowGet) {
/* sleep override */ std::this_thread::sleep_for(
std::chrono::seconds{2});
}
++getCallsFinish;
return std::make_shared<const int>(42);
}
static void subscribe(Observable&, std::function<void()> cob) {
updatesCob = std::move(cob);
}
static void unsubscribe(Observable&) {}
};
std::thread cobThread;
{
auto observer =
folly::observer::ObserverCreator<Observable, Traits>().getObserver();
EXPECT_TRUE(updatesCob);
EXPECT_EQ(2, getCallsStart);
EXPECT_EQ(2, getCallsFinish);
updatesCob();
EXPECT_EQ(3, getCallsStart);
EXPECT_EQ(3, getCallsFinish);
slowGet = true;
cobThread = std::thread([] { updatesCob(); });
/* sleep override */ std::this_thread::sleep_for(std::chrono::seconds{1});
EXPECT_EQ(4, getCallsStart);
EXPECT_EQ(3, getCallsFinish);
// Observer is destroyed here
}
// Make sure that destroying the observer actually joined the updates callback
EXPECT_EQ(4, getCallsStart);
EXPECT_EQ(4, getCallsFinish);
cobThread.join();
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment