Commit 8ea0a28b authored by Brandon Schlinker's avatar Brandon Schlinker Committed by Facebook GitHub Bot

Track rawBytesWritten

Summary:
During the wangle accept process and in a few pieces of application code we transform one type of `AsyncSocket` to another, potentially *after* a write has occurred.

Since we do not currently carry `appBytesWritten` and `rawBytesWritten` during transformations, those values may not represent all bytes written. Likewise, if we transform from a `AsyncSSLSocket`, we lose count of the number of raw bytes written.

It's very difficult to reason about whether these problems will actually manifest, so I'd prefer to just guard against it.

With this change, we explicitly track the number of bytes written to the socket by incrementing a counter in `sendSocketMessage`, which is used by both `AsyncSocket` and `AsyncSSLSocket`. In addition, we copy the appBytesWritten and rawBytesWritten during socket moves.

Reviewed By: yfeldblum

Differential Revision: D24551958

fbshipit-source-id: 88416114b52931ff3ceef847401d556ccf0ab664
parent 4c2bc928
......@@ -445,7 +445,7 @@ size_t AsyncSSLSocket::getRawBytesWritten() const {
// get the write bytes of the last bio
BIO* b;
if (!ssl_ || !(b = SSL_get_wbio(ssl_.get()))) {
return 0;
return rawBytesWritten_;
}
BIO* next = BIO_next(b);
while (next != nullptr) {
......@@ -453,7 +453,12 @@ size_t AsyncSSLSocket::getRawBytesWritten() const {
next = BIO_next(b);
}
return BIO_number_written(b);
// Raw bytes written should be >= BIO_number_written(b)
// Verify no shadowing of rawBytesWritten_
DCHECK_GE(AsyncSocket::getRawBytesWritten(), BIO_number_written(b));
DCHECK_GE(rawBytesWritten_, BIO_number_written(b));
DCHECK_EQ(rawBytesWritten_, AsyncSocket::getRawBytesWritten());
return rawBytesWritten_;
}
size_t AsyncSSLSocket::getRawBytesReceived() const {
......
......@@ -360,6 +360,8 @@ AsyncSocket::AsyncSocket(AsyncSocket* oldAsyncSocket)
oldAsyncSocket->getEventBase(),
oldAsyncSocket->detachNetworkSocket(),
oldAsyncSocket->getZeroCopyBufId()) {
appBytesWritten_ = oldAsyncSocket->appBytesWritten_;
rawBytesWritten_ = oldAsyncSocket->rawBytesWritten_;
preReceivedData_ = std::move(oldAsyncSocket->preReceivedData_);
// inform lifecycle observers to give them an opportunity to unsubscribe from
......@@ -397,6 +399,7 @@ void AsyncSocket::init() {
wShutdownSocketSet_.reset();
appBytesWritten_ = 0;
appBytesReceived_ = 0;
rawBytesWritten_ = 0;
totalAppBytesScheduledForWrite_ = 0;
sendMsgParamCallback_ = &defaultSendMsgParamsCallback;
}
......@@ -2579,6 +2582,11 @@ AsyncSocket::WriteResult AsyncSocket::sendSocketMessage(
AsyncSocket::WriteResult AsyncSocket::sendSocketMessage(
NetworkSocket fd, struct msghdr* msg, int msg_flags) {
ssize_t totalWritten = 0;
SCOPE_EXIT {
if (totalWritten > 0) {
rawBytesWritten_ += totalWritten;
}
};
if (state_ == StateEnum::FAST_OPEN) {
sockaddr_storage addr;
auto len = addr_.getAddress(&addr);
......
......@@ -634,7 +634,7 @@ class AsyncSocket : public AsyncTransport {
size_t getAppBytesWritten() const override { return appBytesWritten_; }
size_t getRawBytesWritten() const override { return getAppBytesWritten(); }
size_t getRawBytesWritten() const override { return rawBytesWritten_; }
size_t getAppBytesReceived() const override { return appBytesReceived_; }
......@@ -1420,6 +1420,7 @@ class AsyncSocket : public AsyncTransport {
std::weak_ptr<ShutdownSocketSet> wShutdownSocketSet_;
size_t appBytesReceived_; ///< Num of bytes received from socket
size_t appBytesWritten_; ///< Num of bytes written to socket
size_t rawBytesWritten_; ///< Num of (raw) bytes written to socket
// The total num of bytes passed to AsyncSocket's write functions. It doesn't
// include failed writes, but it does include buffered writes.
size_t totalAppBytesScheduledForWrite_;
......
......@@ -3163,6 +3163,52 @@ TEST(AsyncSSLSocketTest, TestSNIClientHelloBehavior) {
EXPECT_EQ(sniStr, std::string("Baz"));
}
}
TEST(AsyncSSLSocketTest, BytesWrittenWithMove) {
WriteCallbackBase writeCallback;
ReadCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback);
SSLServerAcceptCallback acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback);
auto sslContext = std::make_shared<SSLContext>();
sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
auto socket1 =
std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
socket1->open(std::chrono::milliseconds(10000));
// write
std::vector<uint8_t> wbuf(128, 'a');
socket1->write(wbuf.data(), wbuf.size());
const auto socket1AppBytes = socket1->getSocket()->getAppBytesWritten();
const auto socket1RawBytes = socket1->getSocket()->getRawBytesWritten();
EXPECT_EQ(128, socket1AppBytes);
EXPECT_LT(128, socket1RawBytes);
// read reflection
std::vector<uint8_t> readbuf(wbuf.size());
uint32_t bytesRead = socket1->readAll(readbuf.data(), readbuf.size());
EXPECT_EQ(bytesRead, wbuf.size());
// additional sanity checks on virtuals
EXPECT_EQ(
socket1->getSSLSocket()->getRawBytesWritten(),
socket1->getSocket()->getRawBytesWritten());
EXPECT_EQ(128, socket1->getSocket()->getAppBytesWritten());
EXPECT_EQ(128, socket1->getSSLSocket()->getAppBytesWritten());
// move to another AsyncSSLSocket
AsyncSSLSocket::UniquePtr socket2(
new AsyncSSLSocket(sslContext, socket1->getSocket()));
EXPECT_EQ(socket1AppBytes, socket2->getAppBytesWritten());
EXPECT_EQ(socket1RawBytes, socket2->getRawBytesWritten());
// move to an AsyncSocket
AsyncSocket::UniquePtr socket3(new AsyncSocket(std::move(socket2)));
EXPECT_EQ(socket1AppBytes, socket3->getAppBytesWritten());
EXPECT_EQ(socket1RawBytes, socket3->getRawBytesWritten());
}
} // namespace folly
#ifdef SIGPIPE
......
......@@ -3032,6 +3032,29 @@ TEST(AsyncSocketTest, TestEvbDetachThenClose) {
socket.reset();
}
TEST(AsyncSocket, BytesWrittenWithMove) {
TestServer server;
EventBase evb;
auto socket1 = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
ConnCallback ccb;
socket1->connect(&ccb, server.getAddress(), 30);
std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
EXPECT_EQ(0, socket1->getRawBytesWritten());
std::vector<uint8_t> wbuf(128, 'a');
WriteCallback wcb;
socket1->write(&wcb, wbuf.data(), wbuf.size());
evb.loopOnce();
ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
EXPECT_EQ(wbuf.size(), socket1->getRawBytesWritten());
EXPECT_EQ(wbuf.size(), socket1->getAppBytesWritten());
auto socket2 = AsyncSocket::UniquePtr(new AsyncSocket(std::move(socket1)));
EXPECT_EQ(wbuf.size(), socket2->getRawBytesWritten());
EXPECT_EQ(wbuf.size(), socket2->getAppBytesWritten());
}
#ifdef FOLLY_HAVE_MSG_ERRQUEUE
/* copied from include/uapi/linux/net_tstamp.h */
/* SO_TIMESTAMPING gets an integer bit field comprised of these values */
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment