Commit 9b8488e8 authored by Petr Lapukhov's avatar Petr Lapukhov Committed by Facebook Github Bot 3

Allow accept callbacks to be short-circuited in primary event-base

Summary:
It looks like we were effectively avoiding short-circuiting callbacks submitted for execution in primary event-base (evb == nulptr). The check was there, but it was never effective, since on `addAcceptCallback` we would mask the `nullptr` with our event base pointer.

I see two ways to fix that: either modify the check

    if (info->eventBase == nullptr) { ...} on line 834

to compare to the presently attached event base, or store `eventBase = nullptr` into callbacks_ list (CallbackInfo struct). The second approach requires more changes (implemented here) but allows the caller to still submit callbacks for execution via notification queue event in primary event base by supplying eventBase parameter != nullptr in addAcceptCallback. I therefore chose the second approach.

The existing unit-tests needed modification to avoid using the "broken" nullptr semantics (most cases were assuming it would be using notification queue signaling). I quickly looked at fbcode, and it looks like we only have a few cases of addAcceptCallback() with nullptr, the unit-tests for those are passing.

NOTE: The removeAcceptCallback() semantics is different with regards to eventBase; nullptr here means "scan all callbacks regardless of event-base they belong to".

Reviewed By: djwatson

Differential Revision: D3714697

fbshipit-source-id: 2362bcff86a7e0604914b1cb7f1471fe4d03e78e
parent 12ace861
...@@ -125,7 +125,7 @@ class AsyncServerSocket::BackoffTimeout : public AsyncTimeout { ...@@ -125,7 +125,7 @@ class AsyncServerSocket::BackoffTimeout : public AsyncTimeout {
public: public:
// Disallow copy, move, and default constructors. // Disallow copy, move, and default constructors.
BackoffTimeout(BackoffTimeout&&) = delete; BackoffTimeout(BackoffTimeout&&) = delete;
BackoffTimeout(AsyncServerSocket* socket) explicit BackoffTimeout(AsyncServerSocket* socket)
: AsyncTimeout(socket->getEventBase()), socket_(socket) {} : AsyncTimeout(socket->getEventBase()), socket_(socket) {}
void timeoutExpired() noexcept override { socket_->backoffTimeoutExpired(); } void timeoutExpired() noexcept override { socket_->backoffTimeoutExpired(); }
...@@ -219,7 +219,14 @@ int AsyncServerSocket::stopAccepting(int shutdownFlags) { ...@@ -219,7 +219,14 @@ int AsyncServerSocket::stopAccepting(int shutdownFlags) {
for (std::vector<CallbackInfo>::iterator it = callbacksCopy.begin(); for (std::vector<CallbackInfo>::iterator it = callbacksCopy.begin();
it != callbacksCopy.end(); it != callbacksCopy.end();
++it) { ++it) {
it->consumer->stop(it->eventBase, it->callback); // consumer may not be set if we are running in primary event base
if (it->consumer) {
DCHECK(it->eventBase);
it->consumer->stop(it->eventBase, it->callback);
} else {
DCHECK(it->callback);
it->callback->acceptStopped();
}
} }
return result; return result;
...@@ -513,12 +520,23 @@ void AsyncServerSocket::addAcceptCallback(AcceptCallback *callback, ...@@ -513,12 +520,23 @@ void AsyncServerSocket::addAcceptCallback(AcceptCallback *callback,
// start accepting once the callback is installed. // start accepting once the callback is installed.
bool runStartAccepting = accepting_ && callbacks_.empty(); bool runStartAccepting = accepting_ && callbacks_.empty();
callbacks_.emplace_back(callback, eventBase);
SCOPE_SUCCESS {
// If this is the first accept callback and we are supposed to be accepting,
// start accepting.
if (runStartAccepting) {
startAccepting();
}
};
if (!eventBase) { if (!eventBase) {
eventBase = eventBase_; // Run in AsyncServerSocket's eventbase // Run in AsyncServerSocket's eventbase; notify that we are
// starting to accept connections
callback->acceptStarted();
return;
} }
callbacks_.emplace_back(callback, eventBase);
// Start the remote acceptor. // Start the remote acceptor.
// //
// It would be nice if we could avoid starting the remote acceptor if // It would be nice if we could avoid starting the remote acceptor if
...@@ -538,12 +556,6 @@ void AsyncServerSocket::addAcceptCallback(AcceptCallback *callback, ...@@ -538,12 +556,6 @@ void AsyncServerSocket::addAcceptCallback(AcceptCallback *callback,
throw; throw;
} }
callbacks_.back().consumer = acceptor; callbacks_.back().consumer = acceptor;
// If this is the first accept callback and we are supposed to be accepting,
// start accepting.
if (runStartAccepting) {
startAccepting();
}
} }
void AsyncServerSocket::removeAcceptCallback(AcceptCallback *callback, void AsyncServerSocket::removeAcceptCallback(AcceptCallback *callback,
...@@ -590,7 +602,16 @@ void AsyncServerSocket::removeAcceptCallback(AcceptCallback *callback, ...@@ -590,7 +602,16 @@ void AsyncServerSocket::removeAcceptCallback(AcceptCallback *callback,
} }
} }
info.consumer->stop(info.eventBase, info.callback); if (info.consumer) {
// consumer could be nullptr is we run callbacks in primary event
// base
DCHECK(info.eventBase);
info.consumer->stop(info.eventBase, info.callback);
} else {
// callback invoked in the primary event base, just call directly
DCHECK(info.callback);
callback->acceptStopped();
}
// If we are supposed to be accepting but the last accept callback // If we are supposed to be accepting but the last accept callback
// was removed, unregister for events until a callback is added. // was removed, unregister for events until a callback is added.
......
...@@ -1733,12 +1733,12 @@ TEST(AsyncSocketTest, ServerAcceptOptions) { ...@@ -1733,12 +1733,12 @@ TEST(AsyncSocketTest, ServerAcceptOptions) {
TestAcceptCallback acceptCallback; TestAcceptCallback acceptCallback;
acceptCallback.setConnectionAcceptedFn( acceptCallback.setConnectionAcceptedFn(
[&](int /* fd */, const folly::SocketAddress& /* addr */) { [&](int /* fd */, const folly::SocketAddress& /* addr */) {
serverSocket->removeAcceptCallback(&acceptCallback, nullptr); serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
}); });
acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) { acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
serverSocket->removeAcceptCallback(&acceptCallback, nullptr); serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
}); });
serverSocket->addAcceptCallback(&acceptCallback, nullptr); serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
serverSocket->startAccepting(); serverSocket->startAccepting();
// Connect to the server socket // Connect to the server socket
...@@ -1850,13 +1850,13 @@ TEST(AsyncSocketTest, RemoveAcceptCallback) { ...@@ -1850,13 +1850,13 @@ TEST(AsyncSocketTest, RemoveAcceptCallback) {
serverSocket->removeAcceptCallback(&cb7, nullptr); serverSocket->removeAcceptCallback(&cb7, nullptr);
}); });
serverSocket->addAcceptCallback(&cb1, nullptr); serverSocket->addAcceptCallback(&cb1, &eventBase);
serverSocket->addAcceptCallback(&cb2, nullptr); serverSocket->addAcceptCallback(&cb2, &eventBase);
serverSocket->addAcceptCallback(&cb3, nullptr); serverSocket->addAcceptCallback(&cb3, &eventBase);
serverSocket->addAcceptCallback(&cb4, nullptr); serverSocket->addAcceptCallback(&cb4, &eventBase);
serverSocket->addAcceptCallback(&cb5, nullptr); serverSocket->addAcceptCallback(&cb5, &eventBase);
serverSocket->addAcceptCallback(&cb6, nullptr); serverSocket->addAcceptCallback(&cb6, &eventBase);
serverSocket->addAcceptCallback(&cb7, nullptr); serverSocket->addAcceptCallback(&cb7, &eventBase);
serverSocket->startAccepting(); serverSocket->startAccepting();
// Make several connections to the socket // Make several connections to the socket
...@@ -1959,14 +1959,14 @@ TEST(AsyncSocketTest, OtherThreadAcceptCallback) { ...@@ -1959,14 +1959,14 @@ TEST(AsyncSocketTest, OtherThreadAcceptCallback) {
cb1.setConnectionAcceptedFn( cb1.setConnectionAcceptedFn(
[&](int /* fd */, const folly::SocketAddress& /* addr */) { [&](int /* fd */, const folly::SocketAddress& /* addr */) {
CHECK_EQ(thread_id, std::this_thread::get_id()); CHECK_EQ(thread_id, std::this_thread::get_id());
serverSocket->removeAcceptCallback(&cb1, nullptr); serverSocket->removeAcceptCallback(&cb1, &eventBase);
}); });
cb1.setAcceptStoppedFn([&](){ cb1.setAcceptStoppedFn([&](){
CHECK_EQ(thread_id, std::this_thread::get_id()); CHECK_EQ(thread_id, std::this_thread::get_id());
}); });
// Test having callbacks remove other callbacks before them on the list, // Test having callbacks remove other callbacks before them on the list,
serverSocket->addAcceptCallback(&cb1, nullptr); serverSocket->addAcceptCallback(&cb1, &eventBase);
serverSocket->startAccepting(); serverSocket->startAccepting();
// Make several connections to the socket // Make several connections to the socket
...@@ -1999,20 +1999,22 @@ TEST(AsyncSocketTest, OtherThreadAcceptCallback) { ...@@ -1999,20 +1999,22 @@ TEST(AsyncSocketTest, OtherThreadAcceptCallback) {
} }
void serverSocketSanityTest(AsyncServerSocket* serverSocket) { void serverSocketSanityTest(AsyncServerSocket* serverSocket) {
EventBase* eventBase = serverSocket->getEventBase();
CHECK(eventBase);
// Add a callback to accept one connection then stop accepting // Add a callback to accept one connection then stop accepting
TestAcceptCallback acceptCallback; TestAcceptCallback acceptCallback;
acceptCallback.setConnectionAcceptedFn( acceptCallback.setConnectionAcceptedFn(
[&](int /* fd */, const folly::SocketAddress& /* addr */) { [&](int /* fd */, const folly::SocketAddress& /* addr */) {
serverSocket->removeAcceptCallback(&acceptCallback, nullptr); serverSocket->removeAcceptCallback(&acceptCallback, eventBase);
}); });
acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) { acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
serverSocket->removeAcceptCallback(&acceptCallback, nullptr); serverSocket->removeAcceptCallback(&acceptCallback, eventBase);
}); });
serverSocket->addAcceptCallback(&acceptCallback, nullptr); serverSocket->addAcceptCallback(&acceptCallback, eventBase);
serverSocket->startAccepting(); serverSocket->startAccepting();
// Connect to the server socket // Connect to the server socket
EventBase* eventBase = serverSocket->getEventBase();
folly::SocketAddress serverAddress; folly::SocketAddress serverAddress;
serverSocket->getAddress(&serverAddress); serverSocket->getAddress(&serverAddress);
AsyncSocket::UniquePtr socket(new AsyncSocket(eventBase, serverAddress)); AsyncSocket::UniquePtr socket(new AsyncSocket(eventBase, serverAddress));
...@@ -2181,12 +2183,12 @@ TEST(AsyncSocketTest, UnixDomainSocketTest) { ...@@ -2181,12 +2183,12 @@ TEST(AsyncSocketTest, UnixDomainSocketTest) {
TestAcceptCallback acceptCallback; TestAcceptCallback acceptCallback;
acceptCallback.setConnectionAcceptedFn( acceptCallback.setConnectionAcceptedFn(
[&](int /* fd */, const folly::SocketAddress& /* addr */) { [&](int /* fd */, const folly::SocketAddress& /* addr */) {
serverSocket->removeAcceptCallback(&acceptCallback, nullptr); serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
}); });
acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) { acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
serverSocket->removeAcceptCallback(&acceptCallback, nullptr); serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
}); });
serverSocket->addAcceptCallback(&acceptCallback, nullptr); serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
serverSocket->startAccepting(); serverSocket->startAccepting();
// Connect to the server socket // Connect to the server socket
...@@ -2232,7 +2234,7 @@ TEST(AsyncSocketTest, ConnectionEventCallbackDefault) { ...@@ -2232,7 +2234,7 @@ TEST(AsyncSocketTest, ConnectionEventCallbackDefault) {
acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) { acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
serverSocket->removeAcceptCallback(&acceptCallback, nullptr); serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
}); });
serverSocket->addAcceptCallback(&acceptCallback, nullptr); serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
serverSocket->startAccepting(); serverSocket->startAccepting();
// Connect to the server socket // Connect to the server socket
...@@ -2253,6 +2255,61 @@ TEST(AsyncSocketTest, ConnectionEventCallbackDefault) { ...@@ -2253,6 +2255,61 @@ TEST(AsyncSocketTest, ConnectionEventCallbackDefault) {
ASSERT_EQ(connectionEventCallback.getBackoffError(), 0); ASSERT_EQ(connectionEventCallback.getBackoffError(), 0);
} }
TEST(AsyncSocketTest, CallbackInPrimaryEventBase) {
EventBase eventBase;
TestConnectionEventCallback connectionEventCallback;
// Create a server socket
std::shared_ptr<AsyncServerSocket> serverSocket(
AsyncServerSocket::newSocket(&eventBase));
serverSocket->setConnectionEventCallback(&connectionEventCallback);
serverSocket->bind(0);
serverSocket->listen(16);
folly::SocketAddress serverAddress;
serverSocket->getAddress(&serverAddress);
// Add a callback to accept one connection then stop the loop
TestAcceptCallback acceptCallback;
acceptCallback.setConnectionAcceptedFn(
[&](int /* fd */, const folly::SocketAddress& /* addr */) {
serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
});
acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
});
bool acceptStartedFlag{false};
acceptCallback.setAcceptStartedFn([&acceptStartedFlag](){
acceptStartedFlag = true;
});
bool acceptStoppedFlag{false};
acceptCallback.setAcceptStoppedFn([&acceptStoppedFlag](){
acceptStoppedFlag = true;
});
serverSocket->addAcceptCallback(&acceptCallback, nullptr);
serverSocket->startAccepting();
// Connect to the server socket
std::shared_ptr<AsyncSocket> socket(
AsyncSocket::newSocket(&eventBase, serverAddress));
eventBase.loop();
ASSERT_TRUE(acceptStartedFlag);
ASSERT_TRUE(acceptStoppedFlag);
// Validate the connection event counters
ASSERT_EQ(connectionEventCallback.getConnectionAccepted(), 1);
ASSERT_EQ(connectionEventCallback.getConnectionAcceptedError(), 0);
ASSERT_EQ(connectionEventCallback.getConnectionDropped(), 0);
ASSERT_EQ(
connectionEventCallback.getConnectionEnqueuedForAcceptCallback(), 0);
ASSERT_EQ(connectionEventCallback.getConnectionDequeuedByAcceptCallback(), 0);
ASSERT_EQ(connectionEventCallback.getBackoffStarted(), 0);
ASSERT_EQ(connectionEventCallback.getBackoffEnded(), 0);
ASSERT_EQ(connectionEventCallback.getBackoffError(), 0);
}
/** /**
* Test AsyncServerSocket::getNumPendingMessagesInQueue() * Test AsyncServerSocket::getNumPendingMessagesInQueue()
*/ */
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment