Commit 9b8488e8 authored by Petr Lapukhov's avatar Petr Lapukhov Committed by Facebook Github Bot 3

Allow accept callbacks to be short-circuited in primary event-base

Summary:
It looks like we were effectively avoiding short-circuiting callbacks submitted for execution in primary event-base (evb == nulptr). The check was there, but it was never effective, since on `addAcceptCallback` we would mask the `nullptr` with our event base pointer.

I see two ways to fix that: either modify the check

    if (info->eventBase == nullptr) { ...} on line 834

to compare to the presently attached event base, or store `eventBase = nullptr` into callbacks_ list (CallbackInfo struct). The second approach requires more changes (implemented here) but allows the caller to still submit callbacks for execution via notification queue event in primary event base by supplying eventBase parameter != nullptr in addAcceptCallback. I therefore chose the second approach.

The existing unit-tests needed modification to avoid using the "broken" nullptr semantics (most cases were assuming it would be using notification queue signaling). I quickly looked at fbcode, and it looks like we only have a few cases of addAcceptCallback() with nullptr, the unit-tests for those are passing.

NOTE: The removeAcceptCallback() semantics is different with regards to eventBase; nullptr here means "scan all callbacks regardless of event-base they belong to".

Reviewed By: djwatson

Differential Revision: D3714697

fbshipit-source-id: 2362bcff86a7e0604914b1cb7f1471fe4d03e78e
parent 12ace861
......@@ -125,7 +125,7 @@ class AsyncServerSocket::BackoffTimeout : public AsyncTimeout {
public:
// Disallow copy, move, and default constructors.
BackoffTimeout(BackoffTimeout&&) = delete;
BackoffTimeout(AsyncServerSocket* socket)
explicit BackoffTimeout(AsyncServerSocket* socket)
: AsyncTimeout(socket->getEventBase()), socket_(socket) {}
void timeoutExpired() noexcept override { socket_->backoffTimeoutExpired(); }
......@@ -219,7 +219,14 @@ int AsyncServerSocket::stopAccepting(int shutdownFlags) {
for (std::vector<CallbackInfo>::iterator it = callbacksCopy.begin();
it != callbacksCopy.end();
++it) {
it->consumer->stop(it->eventBase, it->callback);
// consumer may not be set if we are running in primary event base
if (it->consumer) {
DCHECK(it->eventBase);
it->consumer->stop(it->eventBase, it->callback);
} else {
DCHECK(it->callback);
it->callback->acceptStopped();
}
}
return result;
......@@ -513,12 +520,23 @@ void AsyncServerSocket::addAcceptCallback(AcceptCallback *callback,
// start accepting once the callback is installed.
bool runStartAccepting = accepting_ && callbacks_.empty();
callbacks_.emplace_back(callback, eventBase);
SCOPE_SUCCESS {
// If this is the first accept callback and we are supposed to be accepting,
// start accepting.
if (runStartAccepting) {
startAccepting();
}
};
if (!eventBase) {
eventBase = eventBase_; // Run in AsyncServerSocket's eventbase
// Run in AsyncServerSocket's eventbase; notify that we are
// starting to accept connections
callback->acceptStarted();
return;
}
callbacks_.emplace_back(callback, eventBase);
// Start the remote acceptor.
//
// It would be nice if we could avoid starting the remote acceptor if
......@@ -538,12 +556,6 @@ void AsyncServerSocket::addAcceptCallback(AcceptCallback *callback,
throw;
}
callbacks_.back().consumer = acceptor;
// If this is the first accept callback and we are supposed to be accepting,
// start accepting.
if (runStartAccepting) {
startAccepting();
}
}
void AsyncServerSocket::removeAcceptCallback(AcceptCallback *callback,
......@@ -590,7 +602,16 @@ void AsyncServerSocket::removeAcceptCallback(AcceptCallback *callback,
}
}
info.consumer->stop(info.eventBase, info.callback);
if (info.consumer) {
// consumer could be nullptr is we run callbacks in primary event
// base
DCHECK(info.eventBase);
info.consumer->stop(info.eventBase, info.callback);
} else {
// callback invoked in the primary event base, just call directly
DCHECK(info.callback);
callback->acceptStopped();
}
// If we are supposed to be accepting but the last accept callback
// was removed, unregister for events until a callback is added.
......
......@@ -1733,12 +1733,12 @@ TEST(AsyncSocketTest, ServerAcceptOptions) {
TestAcceptCallback acceptCallback;
acceptCallback.setConnectionAcceptedFn(
[&](int /* fd */, const folly::SocketAddress& /* addr */) {
serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
});
acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
});
serverSocket->addAcceptCallback(&acceptCallback, nullptr);
serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
serverSocket->startAccepting();
// Connect to the server socket
......@@ -1850,13 +1850,13 @@ TEST(AsyncSocketTest, RemoveAcceptCallback) {
serverSocket->removeAcceptCallback(&cb7, nullptr);
});
serverSocket->addAcceptCallback(&cb1, nullptr);
serverSocket->addAcceptCallback(&cb2, nullptr);
serverSocket->addAcceptCallback(&cb3, nullptr);
serverSocket->addAcceptCallback(&cb4, nullptr);
serverSocket->addAcceptCallback(&cb5, nullptr);
serverSocket->addAcceptCallback(&cb6, nullptr);
serverSocket->addAcceptCallback(&cb7, nullptr);
serverSocket->addAcceptCallback(&cb1, &eventBase);
serverSocket->addAcceptCallback(&cb2, &eventBase);
serverSocket->addAcceptCallback(&cb3, &eventBase);
serverSocket->addAcceptCallback(&cb4, &eventBase);
serverSocket->addAcceptCallback(&cb5, &eventBase);
serverSocket->addAcceptCallback(&cb6, &eventBase);
serverSocket->addAcceptCallback(&cb7, &eventBase);
serverSocket->startAccepting();
// Make several connections to the socket
......@@ -1959,14 +1959,14 @@ TEST(AsyncSocketTest, OtherThreadAcceptCallback) {
cb1.setConnectionAcceptedFn(
[&](int /* fd */, const folly::SocketAddress& /* addr */) {
CHECK_EQ(thread_id, std::this_thread::get_id());
serverSocket->removeAcceptCallback(&cb1, nullptr);
serverSocket->removeAcceptCallback(&cb1, &eventBase);
});
cb1.setAcceptStoppedFn([&](){
CHECK_EQ(thread_id, std::this_thread::get_id());
});
// Test having callbacks remove other callbacks before them on the list,
serverSocket->addAcceptCallback(&cb1, nullptr);
serverSocket->addAcceptCallback(&cb1, &eventBase);
serverSocket->startAccepting();
// Make several connections to the socket
......@@ -1999,20 +1999,22 @@ TEST(AsyncSocketTest, OtherThreadAcceptCallback) {
}
void serverSocketSanityTest(AsyncServerSocket* serverSocket) {
EventBase* eventBase = serverSocket->getEventBase();
CHECK(eventBase);
// Add a callback to accept one connection then stop accepting
TestAcceptCallback acceptCallback;
acceptCallback.setConnectionAcceptedFn(
[&](int /* fd */, const folly::SocketAddress& /* addr */) {
serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
serverSocket->removeAcceptCallback(&acceptCallback, eventBase);
});
acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
serverSocket->removeAcceptCallback(&acceptCallback, eventBase);
});
serverSocket->addAcceptCallback(&acceptCallback, nullptr);
serverSocket->addAcceptCallback(&acceptCallback, eventBase);
serverSocket->startAccepting();
// Connect to the server socket
EventBase* eventBase = serverSocket->getEventBase();
folly::SocketAddress serverAddress;
serverSocket->getAddress(&serverAddress);
AsyncSocket::UniquePtr socket(new AsyncSocket(eventBase, serverAddress));
......@@ -2181,12 +2183,12 @@ TEST(AsyncSocketTest, UnixDomainSocketTest) {
TestAcceptCallback acceptCallback;
acceptCallback.setConnectionAcceptedFn(
[&](int /* fd */, const folly::SocketAddress& /* addr */) {
serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
});
acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
});
serverSocket->addAcceptCallback(&acceptCallback, nullptr);
serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
serverSocket->startAccepting();
// Connect to the server socket
......@@ -2232,7 +2234,7 @@ TEST(AsyncSocketTest, ConnectionEventCallbackDefault) {
acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
});
serverSocket->addAcceptCallback(&acceptCallback, nullptr);
serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
serverSocket->startAccepting();
// Connect to the server socket
......@@ -2253,6 +2255,61 @@ TEST(AsyncSocketTest, ConnectionEventCallbackDefault) {
ASSERT_EQ(connectionEventCallback.getBackoffError(), 0);
}
TEST(AsyncSocketTest, CallbackInPrimaryEventBase) {
EventBase eventBase;
TestConnectionEventCallback connectionEventCallback;
// Create a server socket
std::shared_ptr<AsyncServerSocket> serverSocket(
AsyncServerSocket::newSocket(&eventBase));
serverSocket->setConnectionEventCallback(&connectionEventCallback);
serverSocket->bind(0);
serverSocket->listen(16);
folly::SocketAddress serverAddress;
serverSocket->getAddress(&serverAddress);
// Add a callback to accept one connection then stop the loop
TestAcceptCallback acceptCallback;
acceptCallback.setConnectionAcceptedFn(
[&](int /* fd */, const folly::SocketAddress& /* addr */) {
serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
});
acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
});
bool acceptStartedFlag{false};
acceptCallback.setAcceptStartedFn([&acceptStartedFlag](){
acceptStartedFlag = true;
});
bool acceptStoppedFlag{false};
acceptCallback.setAcceptStoppedFn([&acceptStoppedFlag](){
acceptStoppedFlag = true;
});
serverSocket->addAcceptCallback(&acceptCallback, nullptr);
serverSocket->startAccepting();
// Connect to the server socket
std::shared_ptr<AsyncSocket> socket(
AsyncSocket::newSocket(&eventBase, serverAddress));
eventBase.loop();
ASSERT_TRUE(acceptStartedFlag);
ASSERT_TRUE(acceptStoppedFlag);
// Validate the connection event counters
ASSERT_EQ(connectionEventCallback.getConnectionAccepted(), 1);
ASSERT_EQ(connectionEventCallback.getConnectionAcceptedError(), 0);
ASSERT_EQ(connectionEventCallback.getConnectionDropped(), 0);
ASSERT_EQ(
connectionEventCallback.getConnectionEnqueuedForAcceptCallback(), 0);
ASSERT_EQ(connectionEventCallback.getConnectionDequeuedByAcceptCallback(), 0);
ASSERT_EQ(connectionEventCallback.getBackoffStarted(), 0);
ASSERT_EQ(connectionEventCallback.getBackoffEnded(), 0);
ASSERT_EQ(connectionEventCallback.getBackoffError(), 0);
}
/**
* Test AsyncServerSocket::getNumPendingMessagesInQueue()
*/
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment