Commit 7765d9b4 authored by Andrii Grynenko's avatar Andrii Grynenko Committed by Facebook GitHub Bot

Avoid crashes if Observable::get throws.

Reviewed By: praihan

Differential Revision: D32907070

fbshipit-source-id: 0341ad85fe59c5bbea2a54ed31adf76949f82604
parent 7685dc55
...@@ -48,30 +48,37 @@ class ObserverCreatorContext { ...@@ -48,30 +48,37 @@ class ObserverCreatorContext {
return state->value; return state->value;
} }
observer_detail::Core::Ptr update() { observer_detail::Core::Ptr update() noexcept {
// This mutex ensures there's no race condition between initial update() try {
// call and update() calls from the subsciption callback. // This mutex ensures there's no race condition between initial update()
// // call and update() calls from the subsciption callback.
// Additionally it helps avoid races between two different subscription //
// callbacks (getting new value from observable and storing it into value_ // Additionally it helps avoid races between two different subscription
// is not atomic). // callbacks (getting new value from observable and storing it into value_
// // is not atomic).
// Note that state_ lock is acquired only after Traits::get. Traits::get //
// is running application code (that may acquire locks) and so it's // Note that state_ lock is acquired only after Traits::get. Traits::get
// important to not hold state_ lock while running it to avoid possible lock // is running application code (that may acquire locks) and so it's
// inversion with another code path that needs state_ lock (e.g. get()). // important to not hold state_ lock while running it to avoid possible
std::lock_guard<std::mutex> updateLockGuard(updateLock_); // lock inversion with another code path that needs state_ lock (e.g.
auto newValue = Traits::get(observable_); // get()).
std::lock_guard<std::mutex> updateLockGuard(updateLock_);
auto newValue = Traits::get(observable_);
auto state = state_.lock();
if (!state->updateValue(std::move(newValue))) {
// Value didn't change, so we can skip the version update.
return nullptr;
}
auto state = state_.lock(); if (!std::exchange(state->updateRequested, true)) {
if (!state->updateValue(std::move(newValue))) { return coreWeak_.lock();
// Value didn't change, so we can skip the version update. }
return nullptr; } catch (...) {
LOG(ERROR) << "Observer update failed: "
<< folly::exceptionStr(std::current_exception());
} }
if (!std::exchange(state->updateRequested, true)) {
return coreWeak_.lock();
}
return nullptr; return nullptr;
} }
......
...@@ -14,8 +14,10 @@ ...@@ -14,8 +14,10 @@
* limitations under the License. * limitations under the License.
*/ */
#include <stdexcept>
#include <thread> #include <thread>
#include <utility>
#include <folly/Singleton.h> #include <folly/Singleton.h>
#include <folly/experimental/observer/Observer.h> #include <folly/experimental/observer/Observer.h>
#include <folly/experimental/observer/SimpleObservable.h> #include <folly/experimental/observer/SimpleObservable.h>
...@@ -960,3 +962,54 @@ TEST(Observer, ObservableLockInversion) { ...@@ -960,3 +962,54 @@ TEST(Observer, ObservableLockInversion) {
updater.join(); updater.join();
} }
folly::Function<void()> throwingObservableCallback;
TEST(Observer, ObservableGetThrow) {
struct ThrowingObservable {
using element_type = size_t;
std::shared_ptr<const size_t> get() {
if (getCalled_.exchange(true)) {
throw std::logic_error("Transient error");
}
return std::make_shared<const size_t>(42);
}
void subscribe(folly::Function<void()> cb) {
throwingObservableCallback = std::move(cb);
}
void unsubscribe() { throwingObservableCallback = nullptr; }
private:
std::atomic<bool> getCalled_{false};
};
auto observer =
folly::observer::ObserverCreator<ThrowingObservable>().getObserver();
EXPECT_EQ(42, **observer);
folly::observer_detail::ObserverManager::waitForAllUpdates();
EXPECT_EQ(42, **observer);
throwingObservableCallback();
folly::observer_detail::ObserverManager::waitForAllUpdates();
EXPECT_EQ(42, **observer);
struct ExpectedException {};
struct AlwaysThrowingObservable {
using element_type = size_t;
std::shared_ptr<const size_t> get() { throw ExpectedException(); }
void subscribe(folly::Function<void()>) {}
void unsubscribe() {}
};
EXPECT_THROW(
folly::observer::ObserverCreator<AlwaysThrowingObservable>()
.getObserver(),
ExpectedException);
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment