Commit cd09aea8 authored by Misha Shneerson's avatar Misha Shneerson Committed by Facebook Github Bot

short-circuit connection callbacks onto same thread

Summary:
When we have a single thread handling both connections and requests, we end up
call the accept callback on the same thread. The code handling the
incoming connections does intend to short-circuit such connection
and not place them on EventBase's queue but it only checks only whether there is no
EventBase registerd for the accept callback. However, by default, we are
registering an EventBase for this callback so this short-circuit falls short
(yay puns).

Reviewed By: djwatson

Differential Revision: D8986246

fbshipit-source-id: 45b817669ae4fd908b39c93ae5b82bb9a14cc2ed
parent a4ef724b
...@@ -927,7 +927,7 @@ void AsyncServerSocket::dispatchSocket(int socket, SocketAddress&& address) { ...@@ -927,7 +927,7 @@ void AsyncServerSocket::dispatchSocket(int socket, SocketAddress&& address) {
// Short circuit if the callback is in the primary EventBase thread // Short circuit if the callback is in the primary EventBase thread
CallbackInfo* info = nextCallback(); CallbackInfo* info = nextCallback();
if (info->eventBase == nullptr) { if (info->eventBase == nullptr || info->eventBase == this->eventBase_) {
info->callback->connectionAccepted(socket, address); info->callback->connectionAccepted(socket, address);
return; return;
} }
...@@ -994,7 +994,7 @@ void AsyncServerSocket::dispatchError(const char* msgstr, int errnoValue) { ...@@ -994,7 +994,7 @@ void AsyncServerSocket::dispatchError(const char* msgstr, int errnoValue) {
while (true) { while (true) {
// Short circuit if the callback is in the primary EventBase thread // Short circuit if the callback is in the primary EventBase thread
if (info->eventBase == nullptr) { if (info->eventBase == nullptr || info->eventBase == this->eventBase_) {
std::runtime_error ex( std::runtime_error ex(
std::string(msgstr) + folly::to<std::string>(errnoValue)); std::string(msgstr) + folly::to<std::string>(errnoValue));
info->callback->acceptError(ex); info->callback->acceptError(ex);
......
...@@ -598,7 +598,9 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase { ...@@ -598,7 +598,9 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase {
} }
int64_t numMsgs = 0; int64_t numMsgs = 0;
for (const auto& callback : callbacks_) { for (const auto& callback : callbacks_) {
numMsgs += callback.consumer->getQueue()->size(); if (callback.consumer) {
numMsgs += callback.consumer->getQueue()->size();
}
} }
return numMsgs; return numMsgs;
} }
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include <folly/io/async/AsyncSocket.h> #include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/AsyncTimeout.h> #include <folly/io/async/AsyncTimeout.h>
#include <folly/io/async/EventBase.h> #include <folly/io/async/EventBase.h>
#include <folly/io/async/ScopedEventBaseThread.h>
#include <folly/experimental/TestUtil.h> #include <folly/experimental/TestUtil.h>
#include <folly/io/IOBuf.h> #include <folly/io/IOBuf.h>
...@@ -31,6 +32,7 @@ ...@@ -31,6 +32,7 @@
#include <folly/portability/GTest.h> #include <folly/portability/GTest.h>
#include <folly/portability/Sockets.h> #include <folly/portability/Sockets.h>
#include <folly/portability/Unistd.h> #include <folly/portability/Unistd.h>
#include <folly/synchronization/Baton.h>
#include <folly/test/SocketAddressTestHelper.h> #include <folly/test/SocketAddressTestHelper.h>
#include <boost/scoped_array.hpp> #include <boost/scoped_array.hpp>
...@@ -2161,8 +2163,8 @@ TEST(AsyncSocketTest, ConnectionEventCallbackDefault) { ...@@ -2161,8 +2163,8 @@ TEST(AsyncSocketTest, ConnectionEventCallbackDefault) {
ASSERT_EQ(connectionEventCallback.getConnectionAcceptedError(), 0); ASSERT_EQ(connectionEventCallback.getConnectionAcceptedError(), 0);
ASSERT_EQ(connectionEventCallback.getConnectionDropped(), 0); ASSERT_EQ(connectionEventCallback.getConnectionDropped(), 0);
ASSERT_EQ( ASSERT_EQ(
connectionEventCallback.getConnectionEnqueuedForAcceptCallback(), 1); connectionEventCallback.getConnectionEnqueuedForAcceptCallback(), 0);
ASSERT_EQ(connectionEventCallback.getConnectionDequeuedByAcceptCallback(), 1); ASSERT_EQ(connectionEventCallback.getConnectionDequeuedByAcceptCallback(), 0);
ASSERT_EQ(connectionEventCallback.getBackoffStarted(), 0); ASSERT_EQ(connectionEventCallback.getBackoffStarted(), 0);
ASSERT_EQ(connectionEventCallback.getBackoffEnded(), 0); ASSERT_EQ(connectionEventCallback.getBackoffEnded(), 0);
ASSERT_EQ(connectionEventCallback.getBackoffError(), 0); ASSERT_EQ(connectionEventCallback.getBackoffError(), 0);
...@@ -2219,6 +2221,59 @@ TEST(AsyncSocketTest, CallbackInPrimaryEventBase) { ...@@ -2219,6 +2221,59 @@ TEST(AsyncSocketTest, CallbackInPrimaryEventBase) {
ASSERT_EQ(connectionEventCallback.getBackoffError(), 0); ASSERT_EQ(connectionEventCallback.getBackoffError(), 0);
} }
TEST(AsyncSocketTest, CallbackInSecondaryEventBase) {
EventBase eventBase;
TestConnectionEventCallback connectionEventCallback;
// Create a server socket
std::shared_ptr<AsyncServerSocket> serverSocket(
AsyncServerSocket::newSocket(&eventBase));
serverSocket->setConnectionEventCallback(&connectionEventCallback);
serverSocket->bind(0);
serverSocket->listen(16);
SocketAddress serverAddress;
serverSocket->getAddress(&serverAddress);
// Add a callback to accept one connection then stop the loop
TestAcceptCallback acceptCallback;
ScopedEventBaseThread cobThread("ioworker_test");
acceptCallback.setConnectionAcceptedFn(
[&](int /* fd */, const SocketAddress& /* addr */) {
eventBase.runInEventBaseThread([&] {
serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
});
});
acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
eventBase.runInEventBaseThread(
[&] { serverSocket->removeAcceptCallback(&acceptCallback, nullptr); });
});
std::atomic<bool> acceptStartedFlag{false};
acceptCallback.setAcceptStartedFn([&]() { acceptStartedFlag = true; });
Baton<> acceptStoppedFlag;
acceptCallback.setAcceptStoppedFn([&]() { acceptStoppedFlag.post(); });
serverSocket->addAcceptCallback(&acceptCallback, cobThread.getEventBase());
serverSocket->startAccepting();
// Connect to the server socket
std::shared_ptr<AsyncSocket> socket(
AsyncSocket::newSocket(&eventBase, serverAddress));
eventBase.loop();
ASSERT_TRUE(acceptStoppedFlag.try_wait_for(std::chrono::seconds(1)));
ASSERT_TRUE(acceptStartedFlag);
// Validate the connection event counters
ASSERT_EQ(connectionEventCallback.getConnectionAccepted(), 1);
ASSERT_EQ(connectionEventCallback.getConnectionAcceptedError(), 0);
ASSERT_EQ(connectionEventCallback.getConnectionDropped(), 0);
ASSERT_EQ(
connectionEventCallback.getConnectionEnqueuedForAcceptCallback(), 1);
ASSERT_EQ(connectionEventCallback.getConnectionDequeuedByAcceptCallback(), 1);
ASSERT_EQ(connectionEventCallback.getBackoffStarted(), 0);
ASSERT_EQ(connectionEventCallback.getBackoffEnded(), 0);
ASSERT_EQ(connectionEventCallback.getBackoffError(), 0);
}
/** /**
* Test AsyncServerSocket::getNumPendingMessagesInQueue() * Test AsyncServerSocket::getNumPendingMessagesInQueue()
*/ */
...@@ -2236,21 +2291,25 @@ TEST(AsyncSocketTest, NumPendingMessagesInQueue) { ...@@ -2236,21 +2291,25 @@ TEST(AsyncSocketTest, NumPendingMessagesInQueue) {
serverSocket->getAddress(&serverAddress); serverSocket->getAddress(&serverAddress);
// Add a callback to accept connections // Add a callback to accept connections
folly::ScopedEventBaseThread cobThread("ioworker_test");
TestAcceptCallback acceptCallback; TestAcceptCallback acceptCallback;
acceptCallback.setConnectionAcceptedFn( acceptCallback.setConnectionAcceptedFn(
[&](int /* fd */, const folly::SocketAddress& /* addr */) { [&](int /* fd */, const folly::SocketAddress& /* addr */) {
count++; count++;
ASSERT_EQ(4 - count, serverSocket->getNumPendingMessagesInQueue()); eventBase.runInEventBaseThreadAndWait([&] {
ASSERT_EQ(4 - count, serverSocket->getNumPendingMessagesInQueue());
});
if (count == 4) { if (count == 4) {
// all messages are processed, remove accept callback eventBase.runInEventBaseThread([&] {
serverSocket->removeAcceptCallback(&acceptCallback, &eventBase); serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
});
} }
}); });
acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) { acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
serverSocket->removeAcceptCallback(&acceptCallback, &eventBase); eventBase.runInEventBaseThread(
[&] { serverSocket->removeAcceptCallback(&acceptCallback, nullptr); });
}); });
serverSocket->addAcceptCallback(&acceptCallback, &eventBase); serverSocket->addAcceptCallback(&acceptCallback, cobThread.getEventBase());
serverSocket->startAccepting(); serverSocket->startAccepting();
// Connect to the server socket, 4 clients, there are 4 connections // Connect to the server socket, 4 clients, there are 4 connections
...@@ -2260,6 +2319,90 @@ TEST(AsyncSocketTest, NumPendingMessagesInQueue) { ...@@ -2260,6 +2319,90 @@ TEST(AsyncSocketTest, NumPendingMessagesInQueue) {
auto socket4(AsyncSocket::newSocket(&eventBase, serverAddress)); auto socket4(AsyncSocket::newSocket(&eventBase, serverAddress));
eventBase.loop(); eventBase.loop();
ASSERT_EQ(4, count);
}
TEST(AsyncSocketTest, ConnectionsStorm) {
enum class AcceptCobLocation {
Default,
Primary,
Secondary,
};
auto testFunc = [](AcceptCobLocation mode) {
EventBase eventBase;
// Counter of how many connections have been accepted
std::atomic<size_t> count{0};
// Create a server socket
auto serverSocket(AsyncServerSocket::newSocket(&eventBase));
serverSocket->bind(0);
serverSocket->listen(100);
folly::SocketAddress serverAddress;
serverSocket->getAddress(&serverAddress);
TestConnectionEventCallback connectionEventCallback;
serverSocket->setConnectionEventCallback(&connectionEventCallback);
// Add a callback to accept connections
std::shared_ptr<ScopedEventBaseThread> thread;
TestAcceptCallback acceptCallback;
bool stopAccepting = false;
const size_t maxSockets = 2000;
acceptCallback.setConnectionAcceptedFn(
[&](int /* fd */, const folly::SocketAddress& /* addr */) {
count++;
if (!stopAccepting &&
(count == maxSockets ||
connectionEventCallback.getConnectionDropped() > 0)) {
stopAccepting = true;
eventBase.runInEventBaseThread([&] {
serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
});
}
});
acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
eventBase.runInEventBaseThread([&] {
stopAccepting = true;
serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
});
});
if (mode == AcceptCobLocation::Default) {
serverSocket->addAcceptCallback(&acceptCallback, nullptr);
} else if (mode == AcceptCobLocation::Primary) {
serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
} else if (mode == AcceptCobLocation::Secondary) {
thread = std::make_shared<ScopedEventBaseThread>();
serverSocket->addAcceptCallback(&acceptCallback, thread->getEventBase());
}
serverSocket->startAccepting();
// Create connection storm to create connections fast but
// also pace it to not overflow servers' listening queue.
vector<std::shared_ptr<AsyncSocket>> sockets;
folly::Function<void()> fnOpenSockets = [&]() {
// Counter of connections pending the invocation of accept callback.
auto pending = serverSocket->getNumPendingMessagesInQueue();
while (sockets.size() < std::min(maxSockets, pending + count + 30)) {
auto socket = folly::AsyncSocket::newSocket(&eventBase);
socket->connect(nullptr, serverAddress, 5000);
sockets.push_back(socket);
}
if (sockets.size() < maxSockets && !stopAccepting) {
eventBase.runInEventBaseThread([&] { fnOpenSockets(); });
}
};
eventBase.runInEventBaseThread([&] { fnOpenSockets(); });
eventBase.loop();
ASSERT_EQ(maxSockets, count);
};
testFunc(AcceptCobLocation::Default);
testFunc(AcceptCobLocation::Primary);
testFunc(AcceptCobLocation::Secondary);
} }
/** /**
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment