Commit ff508994 authored by Subodh Iyengar's avatar Subodh Iyengar Committed by Facebook Github Bot 2

Simplify TFO write path

Summary:
We currently call handleInitialReadWrite.
The reason for this is that if the read callback
was set before TFO was done with connecting, then
we need to call handleinitialreadwrite to setup the
read callback similar to how connect invokes handleInitialReadWrite
after it's done.

However handleinitalreadwrite may also call handleWrite
if writeReqHead_ is non null.
Practically this will not happen since TFO will happen on
the first write only where writeReqHead_ will be null.

The current code path though is a little bit complicated.
This simplfies the code so that we dont need to potentially
call handleWrite within a write call.

We schedule the initial readwrite call asynchrously.
The reason for this is that handleReadWrite can actually fail if updating
events fails. This might cause weird state issues once it returns and we
have no mechanism of processing it.

Reviewed By: djwatson

Differential Revision: D3695925

fbshipit-source-id: 72e19a9e1802caa14e872e05a5cd9bf4e34c5e7d
parent 2ed41baf
...@@ -174,19 +174,19 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest { ...@@ -174,19 +174,19 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
}; };
AsyncSocket::AsyncSocket() AsyncSocket::AsyncSocket()
: eventBase_(nullptr) : eventBase_(nullptr),
, writeTimeout_(this, nullptr) writeTimeout_(this, nullptr),
, ioHandler_(this, nullptr) ioHandler_(this, nullptr),
, immediateReadHandler_(this) { immediateReadHandler_(this) {
VLOG(5) << "new AsyncSocket()"; VLOG(5) << "new AsyncSocket()";
init(); init();
} }
AsyncSocket::AsyncSocket(EventBase* evb) AsyncSocket::AsyncSocket(EventBase* evb)
: eventBase_(evb) : eventBase_(evb),
, writeTimeout_(this, evb) writeTimeout_(this, evb),
, ioHandler_(this, evb) ioHandler_(this, evb),
, immediateReadHandler_(this) { immediateReadHandler_(this) {
VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ")"; VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ")";
init(); init();
} }
...@@ -207,10 +207,10 @@ AsyncSocket::AsyncSocket(EventBase* evb, ...@@ -207,10 +207,10 @@ AsyncSocket::AsyncSocket(EventBase* evb,
} }
AsyncSocket::AsyncSocket(EventBase* evb, int fd) AsyncSocket::AsyncSocket(EventBase* evb, int fd)
: eventBase_(evb) : eventBase_(evb),
, writeTimeout_(this, evb) writeTimeout_(this, evb),
, ioHandler_(this, evb, fd) ioHandler_(this, evb, fd),
, immediateReadHandler_(this) { immediateReadHandler_(this) {
VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ", fd=" VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ", fd="
<< fd << ")"; << fd << ")";
init(); init();
...@@ -1615,7 +1615,6 @@ void AsyncSocket::handleInitialReadWrite() noexcept { ...@@ -1615,7 +1615,6 @@ void AsyncSocket::handleInitialReadWrite() noexcept {
// one here just to make sure, in case one of our calling code paths ever // one here just to make sure, in case one of our calling code paths ever
// changes. // changes.
DestructorGuard dg(this); DestructorGuard dg(this);
// If we have a readCallback_, make sure we enable read events. We // If we have a readCallback_, make sure we enable read events. We
// may already be registered for reads if connectSuccess() set // may already be registered for reads if connectSuccess() set
// the read calback. // the read calback.
...@@ -1772,7 +1771,9 @@ AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) { ...@@ -1772,7 +1771,9 @@ AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) {
if (totalWritten >= 0) { if (totalWritten >= 0) {
tfoFinished_ = true; tfoFinished_ = true;
state_ = StateEnum::ESTABLISHED; state_ = StateEnum::ESTABLISHED;
handleInitialReadWrite(); // We schedule this asynchrously so that we don't end up
// invoking initial read or write while a write is in progress.
scheduleInitialReadWrite();
} else if (errno == EINPROGRESS) { } else if (errno == EINPROGRESS) {
VLOG(4) << "TFO falling back to connecting"; VLOG(4) << "TFO falling back to connecting";
// A normal sendmsg doesn't return EINPROGRESS, however // A normal sendmsg doesn't return EINPROGRESS, however
...@@ -1798,7 +1799,7 @@ AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) { ...@@ -1798,7 +1799,7 @@ AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) {
// connect succeeded immediately // connect succeeded immediately
// Treat this like no data was written. // Treat this like no data was written.
state_ = StateEnum::ESTABLISHED; state_ = StateEnum::ESTABLISHED;
handleInitialReadWrite(); scheduleInitialReadWrite();
} }
// If there was no exception during connections, // If there was no exception during connections,
// we would return that no bytes were written. // we would return that no bytes were written.
......
...@@ -735,6 +735,20 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -735,6 +735,20 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
} }
} }
/**
* Schedule handleInitalReadWrite to run in the next iteration.
*/
void scheduleInitialReadWrite() noexcept {
if (good()) {
DestructorGuard dg(this);
eventBase_->runInLoop([this, dg] {
if (good()) {
handleInitialReadWrite();
}
});
}
}
// event notification methods // event notification methods
void ioReady(uint16_t events) noexcept; void ioReady(uint16_t events) noexcept;
virtual void checkForImmediateRead() noexcept; virtual void checkForImmediateRead() noexcept;
......
...@@ -2410,6 +2410,55 @@ TEST(AsyncSocketTest, ConnectTFO) { ...@@ -2410,6 +2410,55 @@ TEST(AsyncSocketTest, ConnectTFO) {
EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size())); EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
} }
TEST(AsyncSocketTest, ConnectTFOSupplyEarlyReadCB) {
// Start listening on a local port
TestServer server(true);
// Connect using a AsyncSocket
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
socket->enableTFO();
ConnCallback cb;
socket->connect(&cb, server.getAddress(), 30);
ReadCallback rcb;
socket->setReadCB(&rcb);
std::array<uint8_t, 128> buf;
memset(buf.data(), 'a', buf.size());
std::array<uint8_t, 3> readBuf;
auto sendBuf = IOBuf::copyBuffer("hey");
std::thread t([&] {
auto acceptedSocket = server.accept();
acceptedSocket->write(buf.data(), buf.size());
acceptedSocket->flush();
acceptedSocket->readAll(readBuf.data(), readBuf.size());
acceptedSocket->close();
});
evb.loop();
CHECK_EQ(cb.state, STATE_SUCCEEDED);
EXPECT_LE(0, socket->getConnectTime().count());
EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
EXPECT_TRUE(socket->getTFOAttempted());
// Should trigger the connect
WriteCallback write;
socket->writeChain(&write, sendBuf->clone());
evb.loop();
t.join();
EXPECT_EQ(STATE_SUCCEEDED, write.state);
EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
ASSERT_EQ(1, rcb.buffers.size());
ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
}
/** /**
* Test connecting to a server that isn't listening * Test connecting to a server that isn't listening
*/ */
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment