Commit 9b15aded authored by Brandon Schlinker's avatar Brandon Schlinker Committed by Facebook GitHub Bot

Fix EOR bug, always pass timestamp flags

Summary:
Socket timestamps (ACK / TX) and EoR tracking currently break for `AsyncSSLSocket` if SSL renegotiation occurs while a timestamped write / EoR write is in progress.

- If EoR tracking is enabled, the EoR flag and any timestamp flags are not included until `AsyncSSLSocket` writes a buffer containing the final byte to the socket. This is to avoid these flags from being set on a partial write of the passed in buffer.
- If a write occurs while an SSL renegotiation is in progress, OpenSSL will return `SSL_ERROR_WANT_WRITE`. When this happens, we need to call write again, passing back in the same buffer.
- The current logic for deciding whether to include the EoR and timestamping flags (`eorAwareSSLWrite`) adds the number of bytes pending to the value returned by `AsyncSSLSocket::getRawBytesWritten` to determine the minimum byte offset when the flags should be added.
  - However, when a write fails due to SSL renegotiation, `getRawBytesWritten` may include some of the bytes that were passed in the last call, despite how they have not actually been written to the transport yet. This is because `getRawBytesWritten` is calculated based on the BIO chain length.
  - As a result, the current logic for calculating the offset on which to add the flags overshoots -- it returns an offset that will never be written. This causes the flags to not be added, and timestamps to timeout.
- This results in one of two things:
  - Timestamp timeouts, where the timestamps are never received
  - If a subsequent write is timestamped, the timestamps from that write may be used instead. This will cause the timestamps to be inflated, and leads to higher TX / ACK times at upper percentiles.

Fix is as follows:
- Change the logic that determines whether the EoR is included in the buffer to no longer rely on `getRawBytesWritten`. In addition, simplify logic so that it is no longer a separate function and easier to make sense of.
- Even if EoR tracking is enabled, always write timestamp flags (TX, ACK, etc.) on every write. This reduces the amount of coordination required between different components. The socket error message handler will end up with more cases of timestamps being delivered for earlier bytes than the last body byte, but they already need to deal with that today due to partial writes.

I considered just outright removing support for EoR tracking (EoR was previously used for timestamping) but decided against this as there's still some value in setting the EoR flag for debugging; see notes in code.

Reviewed By: yfeldblum

Differential Revision: D21969420

fbshipit-source-id: db8e7e5fbd70d627f88f2c43199387f5112b5f9e
parent 7f1bda25
......@@ -390,12 +390,7 @@ std::string AsyncSSLSocket::getApplicationProtocol() const noexcept {
}
void AsyncSSLSocket::setEorTracking(bool track) {
if (isEorTrackingEnabled() != track) {
AsyncSocket::setEorTracking(track);
appEorByteNo_ = 0;
appEorByteWriteFlags_ = {};
minEorRawByteNo_ = 0;
}
AsyncSocket::setEorTracking(track);
}
size_t AsyncSSLSocket::getRawBytesWritten() const {
......@@ -1008,7 +1003,7 @@ bool AsyncSSLSocket::willBlock(
int* sslErrorOut,
unsigned long* errErrorOut) noexcept {
*errErrorOut = 0;
int error = *sslErrorOut = SSL_get_error(ssl_.get(), ret);
int error = *sslErrorOut = sslGetErrorImpl(ssl_.get(), ret);
if (error == SSL_ERROR_WANT_READ) {
// Register for read event if not already.
updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
......@@ -1408,7 +1403,7 @@ AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
std::make_unique<SSLException>(SSLError::CLIENT_RENEGOTIATION));
}
if (bytes <= 0) {
int error = SSL_get_error(ssl_.get(), bytes);
int error = sslGetErrorImpl(ssl_.get(), bytes);
if (error == SSL_ERROR_WANT_READ) {
// The caller will register for read event if not already.
if (errno == EWOULDBLOCK || errno == EAGAIN) {
......@@ -1603,22 +1598,61 @@ AsyncSocket::WriteResult AsyncSSLSocket::performWrite(
}
}
// cork the current write if the original flags included CORK or if there
// are remaining iovec to write
corkCurrentWrite_ =
isSet(flags, WriteFlags::CORK) || (i + buffersStolen + 1 < count);
// track the EoR if:
// (1) there are write flags that require EoR tracking (EOR / TIMESTAMP_TX)
// (2) if the buffer includes the EOR byte
appEorByteWriteFlags_ = flags & kEorRelevantWriteFlags;
bool trackEor = appEorByteWriteFlags_ != folly::WriteFlags::NONE &&
(i + buffersStolen + 1 == count);
bytes = eorAwareSSLWrite(ssl_, sslWriteBuf, int(len), trackEor);
// From here, the write flow is as follows:
// - sslWriteImpl calls SSL_write, which encrypts the passed buffer.
// - SSL_write calls AsyncSSLSocket::bioWrite with the encrypted buffer.
// - AsyncSSLSocket::bioWrite calls AsyncSocket::sendSocketMessage(...).
//
// When sendSocketMessage calls sendMsg, WriteFlags are transformed into
// ancillary data and/or sendMsg flags. If WriteFlag::EOR is in flags and
// trackEor_ is set, then we should ensure that MSG_EOR is only passed to
// sendmsg when the final byte of the orginally passed in buffer is being
// written. Since the buffer originally passed to performWrite may be split
// up and written over multiple calls to sendmsg, we have to take care to
// unset the EOR flag if it was included in the WriteFlags passed in and
// we're writing a buffer that does _not_ contain the final byte of the
// orignally passed buffer.
//
// We handle EOR as follows:
// - We set currWriteFlags_ to the passed in WriteFlags.
// - If sslWriteBuf does NOT contain the last byte of the passed in iovec,
// then we set currBytesToFinalByte_ to folly::none. In bioWrite, we
// unset WriteFlags::EOR if it is set in currWriteFlags_.
// - If sslWriteBuf DOES contain the last byte of the passed in iovec,
// then we set bytesToFinalByte_ to int(len). In bioWrite, if the length
// of the passed in buffer >= currBytesToFinalByte_, then we leave the
// flags in currWriteFlags_ alone.
//
// What about timestamp flags?
// - We don't do any special handling for timestamping flags.
// - This may mean that more timestamps than necessary get generated, but
// that's OK; you already have to deal with that for timestamping due to
// the possibility of partial writes.
// - MSG_EOR used to be used for timestamping, but hasn't been for years.
//
// Finally, why even care about MSG_EOR, if not for timestamping?
// - If set, it is marked in the corresponding tcp_skb_cb; this can be
// useful when debugging.
// - The kernel uses it to decide whether socket buffers can be collapsed
// together (see tcp_skb_can_collapse_to).
currWriteFlags_ = flags;
uint32_t iovecWrittenToSslWriteBuf = i + buffersStolen + 1;
CHECK_LE(iovecWrittenToSslWriteBuf, count);
if (iovecWrittenToSslWriteBuf == count) { // last byte is in sslWriteBuf
currBytesToFinalByte_ = len; // length of current buffer
} else { // there are still remaining buffers / iovec to write
currBytesToFinalByte_ = folly::none;
currWriteFlags_ |= WriteFlags::CORK;
}
bytes = sslWriteImpl(ssl_.get(), sslWriteBuf, int(len));
if (bytes <= 0) {
int error = SSL_get_error(ssl_.get(), int(bytes));
int error = sslGetErrorImpl(ssl_.get(), int(bytes));
if (error == SSL_ERROR_WANT_WRITE) {
// The entire buffer needs to be passed in again, so *partialWritten
// is set to the original offset where we started for this call to
// performWrite(); see SSL_ERROR_WANT_WRITE documentation for details.
//
// The caller will register for write event if not already.
*partialWritten = uint32_t(offset);
return WriteResult(totalWritten);
......@@ -1651,42 +1685,6 @@ AsyncSocket::WriteResult AsyncSSLSocket::performWrite(
return WriteResult(totalWritten);
}
int AsyncSSLSocket::eorAwareSSLWrite(
const ssl::SSLUniquePtr& ssl,
const void* buf,
int n,
bool eor) {
if (eor && isEorTrackingEnabled()) {
if (appEorByteNo_) {
// cannot track for more than one app byte EOR
CHECK(appEorByteNo_ == appBytesWritten_ + n);
} else {
appEorByteNo_ = appBytesWritten_ + n;
}
// 1. It is fine to keep updating minEorRawByteNo_.
// 2. It is _min_ in the sense that SSL record will add some overhead.
minEorRawByteNo_ = getRawBytesWritten() + n;
}
n = sslWriteImpl(ssl.get(), buf, n);
if (n > 0) {
appBytesWritten_ += n;
if (appEorByteNo_) {
if (getRawBytesWritten() >= minEorRawByteNo_) {
minEorRawByteNo_ = 0;
}
if (appBytesWritten_ == appEorByteNo_) {
appEorByteNo_ = 0;
appEorByteWriteFlags_ = {};
} else {
CHECK(appBytesWritten_ < appEorByteNo_);
}
}
}
return n;
}
void AsyncSSLSocket::sslInfoCallback(const SSL* ssl, int where, int ret) {
AsyncSSLSocket* sslSocket = AsyncSSLSocket::getFromSSL(ssl);
if (sslSocket->handshakeComplete_ && (where & SSL_CB_HANDSHAKE_START)) {
......@@ -1709,46 +1707,42 @@ void AsyncSSLSocket::sslInfoCallback(const SSL* ssl, int where, int ret) {
}
int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
// get pointer to AsyncSSLSocket from BioAppData
auto appData = OpenSSLUtils::getBioAppData(b);
CHECK(appData);
AsyncSSLSocket* sslSock = reinterpret_cast<AsyncSSLSocket*>(appData);
CHECK(sslSock);
// if EOR is tracked, correct if needed
WriteFlags flags = sslSock->currWriteFlags_;
if (sslSock->trackEor_ &&
(!sslSock->currBytesToFinalByte_.has_value() ||
*(sslSock->currBytesToFinalByte_) > (size_t)inl)) {
// unset EOR if set, since we're not writing the last byte yet
flags = unSet(flags, folly::WriteFlags::EOR);
}
struct msghdr msg;
struct iovec iov;
AsyncSSLSocket* tsslSock;
iov.iov_base = const_cast<char*>(in);
iov.iov_len = size_t(inl);
memset(&msg, 0, sizeof(msg));
msg.msg_iov = &iov;
msg.msg_iovlen = 1;
auto appData = OpenSSLUtils::getBioAppData(b);
CHECK(appData);
tsslSock = reinterpret_cast<AsyncSSLSocket*>(appData);
CHECK(tsslSock);
WriteFlags flags = WriteFlags::NONE;
if (tsslSock->isEorTrackingEnabled() && tsslSock->minEorRawByteNo_ &&
tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) {
flags |= tsslSock->appEorByteWriteFlags_;
}
if (tsslSock->corkCurrentWrite_) {
flags |= WriteFlags::CORK;
}
int msg_flags = tsslSock->getSendMsgParamsCB()->getFlags(
flags, false /*zeroCopyEnabled*/);
int msg_flags =
sslSock->getSendMsgParamsCB()->getFlags(flags, false /*zeroCopyEnabled*/);
msg.msg_controllen =
tsslSock->getSendMsgParamsCB()->getAncillaryDataSize(flags);
sslSock->getSendMsgParamsCB()->getAncillaryDataSize(flags);
CHECK_GE(
AsyncSocket::SendMsgParamsCallback::maxAncillaryDataSize,
msg.msg_controllen);
if (msg.msg_controllen != 0) {
msg.msg_control = reinterpret_cast<char*>(alloca(msg.msg_controllen));
tsslSock->getSendMsgParamsCB()->getAncillaryData(flags, msg.msg_control);
sslSock->getSendMsgParamsCB()->getAncillaryData(flags, msg.msg_control);
}
auto result =
tsslSock->sendSocketMessage(OpenSSLUtils::getBioFd(b), &msg, msg_flags);
sslSock->sendSocketMessage(OpenSSLUtils::getBioFd(b), &msg, msg_flags);
BIO_clear_retry_flags(b);
if (!result.exception && result.writeReturn <= 0) {
if (OpenSSLUtils::getBioShouldRetryWrite(int(result.writeReturn))) {
......
......@@ -862,11 +862,16 @@ class AsyncSSLSocket : public AsyncSocket {
uint32_t* countWritten,
uint32_t* partialWritten);
// This virtual wrapper around SSL_write exists solely for testing/mockability
// Virtual wrapper around SSL_write, solely for testing/mockability
virtual int sslWriteImpl(SSL* ssl, const void* buf, int n) {
return SSL_write(ssl, buf, n);
}
// Virtual wrapper around SSL_get_error, solely for testing/mockability
virtual int sslGetErrorImpl(const SSL* s, int ret_code) {
return SSL_get_error(s, ret_code);
}
/**
* Apply verification options passed to sslConn/sslAccept or those set
* in the underlying SSLContext object.
......@@ -884,21 +889,6 @@ class AsyncSSLSocket : public AsyncSocket {
*/
bool setupSSLBio();
/**
* A SSL_write wrapper that understand EOR
*
* @param ssl: SSL pointer
* @param buf: Buffer to be written
* @param n: Number of bytes to be written
* @param eor: Does the last byte (buf[n-1]) have the app-last-byte?
* @return: The number of app bytes successfully written to the socket
*/
int eorAwareSSLWrite(
const ssl::SSLUniquePtr& ssl,
const void* buf,
int n,
bool eor);
// Inherit error handling methods from AsyncSocket, plus the following.
void failHandshake(const char* fn, const AsyncSocketException& ex);
......@@ -931,32 +921,17 @@ class AsyncSSLSocket : public AsyncSocket {
Timeout handshakeTimeout_;
Timeout connectionTimeout_;
// The app byte num that we are tracking for EOR.
//
// Only one app EOR byte can be tracked.
// See appEorByteWriteFlags_ for details.
size_t appEorByteNo_{0};
// The WriteFlags to pass for the app byte num that is tracked for EOR.
//
// When openssl is about to send appEorByteNo_, these flags will be passed to
// the application via the getAncillaryData callback. The application can then
// generate a control message containing socket timestamping flags or other
// commands that will be included when the corresponding buffer is passed to
// the kernel via sendmsg().
//
// See AsyncSSLSocket::bioWrite (which overrides OpenSSL biowrite).
WriteFlags appEorByteWriteFlags_{};
// WriteFlags last passed to performWrite
WriteFlags currWriteFlags_{};
// Number of bytes to write before final byte
// See AsyncSSLSocket::performWrite for details
folly::Optional<size_t> currBytesToFinalByte_;
// Try to avoid calling SSL_write() for buffers smaller than this.
// It doesn't take effect when it is 0.
size_t minWriteSize_{1500};
// When openssl is about to sendmsg() across the minEorRawBytesNo_,
// it will trigger logic to include an application defined control message.
//
// See appEorByteWriteFlags_ for details.
size_t minEorRawByteNo_{0};
#if FOLLY_OPENSSL_HAS_SNI
std::shared_ptr<folly::SSLContext> handshakeCtx_;
std::string tlsextHostname_;
......
......@@ -1169,7 +1169,7 @@ class AsyncSocket : public AsyncTransport {
* @param msg Message to send
* @param msg_flags Flags to pass to sendmsg
*/
AsyncSocket::WriteResult
virtual AsyncSocket::WriteResult
sendSocketMessage(NetworkSocket fd, struct msghdr* msg, int msg_flags);
virtual ssize_t
......
......@@ -120,19 +120,6 @@ constexpr bool isSet(WriteFlags a, WriteFlags b) {
return (a & b) == b;
}
/**
* Write flags that are specifically for the final write call of a buffer.
*
* In some cases, buffers passed to send may be coalesced or split by the socket
* write handling logic. For instance, a buffer passed to AsyncSSLSocket may be
* split across multiple TLS records (and therefore multiple calls to write).
*
* When a buffer is split up, these flags will only be applied for the final
* call to write for that buffer.
*/
constexpr WriteFlags kEorRelevantWriteFlags =
WriteFlags::EOR | WriteFlags::TIMESTAMP_TX;
class AsyncReader {
public:
class ReadCallback {
......
......@@ -37,6 +37,7 @@ class MockAsyncSSLSocket : public AsyncSSLSocket {
new MockAsyncSSLSocket(ctx, evb), Destructor());
sock->ssl_.reset(SSL_new(ctx->getSSLCtx()));
SSL_set_fd(sock->ssl_.get(), -1);
sock->setupSSLBio();
return sock;
}
......@@ -51,6 +52,17 @@ class MockAsyncSSLSocket : public AsyncSSLSocket {
// mock the calls to SSL_write to see the buffer length and contents
MOCK_METHOD3(sslWriteImpl, int(SSL* ssl, const void* buf, int n));
// mock the calls to SSL_get_error to insert errors
MOCK_METHOD2(sslGetErrorImpl, int(const SSL* s, int ret_code));
// mock the calls to sendSocketMessage to see the msg_flags
MOCK_METHOD3(
sendSocketMessage,
AsyncSocket::WriteResult(
NetworkSocket fd,
struct msghdr* msg,
int msg_flags));
// mock the calls to getRawBytesWritten()
MOCK_CONST_METHOD0(getRawBytesWritten, size_t());
......@@ -64,14 +76,14 @@ class MockAsyncSSLSocket : public AsyncSSLSocket {
return performWrite(vec, count, flags, countWritten, partialWritten);
}
void checkEor(size_t appEor, size_t rawEor) {
EXPECT_EQ(appEor, appEorByteNo_);
EXPECT_EQ(rawEor, minEorRawByteNo_);
}
void setAppBytesWritten(size_t n) {
appBytesWritten_ = n;
}
// public wrapper for protected member
folly::Optional<size_t> getCurrBytesToFinalByte() const {
return currBytesToFinalByte_;
}
};
class AsyncSSLSocketWriteTest : public testing::Test {
......@@ -116,50 +128,132 @@ class AsyncSSLSocketWriteTest : public testing::Test {
char source_[26 * 500];
};
// SSL_ERROR_WANT_WRITE occurs on first write
TEST_F(AsyncSSLSocketWriteTest, SslErrorWantWrite) {
int n = 1;
auto vec = makeVec({1500});
int pos = 0;
// first time we try to write, SSL_ERROR_WANT_WRITE will be returned
//
// this means no bytes were actually written to the socket,
// but getRawBytesWritten will still be incremented by the write size as
// the bytes were appended to the BIO
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
.WillOnce(Invoke([=, &pos](SSL* ssl, const void* buf, int m) {
EXPECT_EQ(m, sock_->getCurrBytesToFinalByte().value_or(0));
verifyVec(buf, m, pos);
BIO* b = SSL_get_wbio(ssl);
auto result = AsyncSSLSocket::bioWrite(b, (const char*)buf, m);
pos += result;
return result;
}));
EXPECT_CALL(
*(sock_.get()), sendSocketMessage(_, _, MSG_DONTWAIT | MSG_NOSIGNAL))
.WillOnce(Return(ByMove(AsyncSocket::WriteResult(0))));
EXPECT_CALL(*(sock_.get()), sslGetErrorImpl(_, _))
.WillOnce(Return(SSL_ERROR_WANT_WRITE));
ON_CALL( // should not be called, unless implementation changes to use it
*(sock_.get()),
getRawBytesWritten())
.WillByDefault(Return(1500));
uint32_t countWritten = 0;
uint32_t partialWritten = 0;
sock_->testPerformWrite(
vec.get(), n, WriteFlags::NONE, &countWritten, &partialWritten);
Mock::VerifyAndClearExpectations(sock_.get());
EXPECT_EQ(countWritten, 0);
EXPECT_EQ(partialWritten, 0);
// second time we try to write, same buffer should be passed in
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
.WillOnce(Invoke([=, &pos](SSL* ssl, const void* buf, int m) {
EXPECT_EQ(m, sock_->getCurrBytesToFinalByte().value_or(0));
verifyVec(buf, m, pos);
BIO* b = SSL_get_wbio(ssl);
auto result = AsyncSSLSocket::bioWrite(b, (const char*)buf, m);
pos += result;
return result;
}));
EXPECT_CALL(
*(sock_.get()), sendSocketMessage(_, _, MSG_DONTWAIT | MSG_NOSIGNAL))
.WillOnce(Return(ByMove(AsyncSocket::WriteResult(1500))));
sock_->testPerformWrite(
vec.get(), n, WriteFlags::NONE, &countWritten, &partialWritten);
Mock::VerifyAndClearExpectations(sock_.get());
EXPECT_EQ(countWritten, n);
EXPECT_EQ(partialWritten, 0);
}
// The entire vec fits in one packet
TEST_F(AsyncSSLSocketWriteTest, write_coalescing1) {
TEST_F(AsyncSSLSocketWriteTest, WriteCoalescing1) {
int n = 3;
auto vec = makeVec({3, 3, 3});
int pos = 0;
InSequence s;
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 9))
.WillOnce(Invoke([this](SSL*, const void* buf, int m) {
verifyVec(buf, m, 0);
return 9;
.WillOnce(Invoke([=, &pos](SSL* ssl, const void* buf, int m) {
verifyVec(buf, m, pos);
BIO* b = SSL_get_wbio(ssl);
auto result = AsyncSSLSocket::bioWrite(b, (const char*)buf, m);
pos += result;
return result;
}));
EXPECT_CALL(
*(sock_.get()),
sendSocketMessage(_, _, MSG_DONTWAIT | MSG_NOSIGNAL)) // no MSG_MORE
.WillOnce(Return(ByMove(AsyncSocket::WriteResult(9))));
uint32_t countWritten = 0;
uint32_t partialWritten = 0;
sock_->testPerformWrite(
vec.get(), n, WriteFlags::NONE, &countWritten, &partialWritten);
Mock::VerifyAndClearExpectations(sock_.get());
EXPECT_EQ(countWritten, n);
EXPECT_EQ(partialWritten, 0);
}
// First packet is full, second two go in one packet
TEST_F(AsyncSSLSocketWriteTest, write_coalescing2) {
TEST_F(AsyncSSLSocketWriteTest, WriteCoalescing2) {
int n = 3;
auto vec = makeVec({1500, 3, 3});
int pos = 0;
InSequence s;
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
.WillOnce(Invoke([this, &pos](SSL*, const void* buf, int m) {
.WillOnce(Invoke([=, &pos](SSL* ssl, const void* buf, int m) {
verifyVec(buf, m, pos);
pos += m;
return m;
BIO* b = SSL_get_wbio(ssl);
auto result = AsyncSSLSocket::bioWrite(b, (const char*)buf, m);
pos += result;
return result;
}));
EXPECT_CALL(
*(sock_.get()),
sendSocketMessage(_, _, MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
.WillOnce(Return(ByMove(AsyncSocket::WriteResult(1500))));
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 6))
.WillOnce(Invoke([this, &pos](SSL*, const void* buf, int m) {
.WillOnce(Invoke([=, &pos](SSL* ssl, const void* buf, int m) {
verifyVec(buf, m, pos);
pos += m;
return m;
BIO* b = SSL_get_wbio(ssl);
auto result = AsyncSSLSocket::bioWrite(b, (const char*)buf, m);
pos += result;
return result;
}));
EXPECT_CALL(
*(sock_.get()),
sendSocketMessage(_, _, MSG_DONTWAIT | MSG_NOSIGNAL)) // no MSG_MORE
.WillOnce(Return(ByMove(AsyncSocket::WriteResult(6))));
uint32_t countWritten = 0;
uint32_t partialWritten = 0;
sock_->testPerformWrite(
vec.get(), n, WriteFlags::NONE, &countWritten, &partialWritten);
Mock::VerifyAndClearExpectations(sock_.get());
EXPECT_EQ(countWritten, n);
EXPECT_EQ(partialWritten, 0);
}
// Two exactly full packets (coalesce ends midway through second chunk)
TEST_F(AsyncSSLSocketWriteTest, write_coalescing3) {
TEST_F(AsyncSSLSocketWriteTest, WriteCoalescing3) {
int n = 3;
auto vec = makeVec({1000, 1000, 1000});
int pos = 0;
......@@ -174,138 +268,187 @@ TEST_F(AsyncSSLSocketWriteTest, write_coalescing3) {
uint32_t partialWritten = 0;
sock_->testPerformWrite(
vec.get(), n, WriteFlags::NONE, &countWritten, &partialWritten);
Mock::VerifyAndClearExpectations(sock_.get());
EXPECT_EQ(countWritten, n);
EXPECT_EQ(partialWritten, 0);
}
// Partial write success midway through a coalesced vec
TEST_F(AsyncSSLSocketWriteTest, write_coalescing4) {
TEST_F(AsyncSSLSocketWriteTest, WriteCoalescing4) {
int n = 5;
auto vec = makeVec({300, 300, 300, 300, 300});
int pos = 0;
InSequence s1;
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
.WillOnce(Invoke([this, &pos](SSL*, const void* buf, int m) {
.WillOnce(Invoke([=, &pos](SSL* ssl, const void* buf, int m) {
verifyVec(buf, m, pos);
pos += 1000;
return 1000; /* 500 bytes "pending" */
BIO* b = SSL_get_wbio(ssl);
auto result = AsyncSSLSocket::bioWrite(b, (const char*)buf, m);
pos += result;
return result;
}));
EXPECT_CALL(
*(sock_.get()),
sendSocketMessage(_, _, MSG_DONTWAIT | MSG_NOSIGNAL)) // no MSG_MORE
.WillOnce(Return(ByMove(AsyncSocket::WriteResult(1000))));
uint32_t countWritten = 0;
uint32_t partialWritten = 0;
sock_->testPerformWrite(
vec.get(), n, WriteFlags::NONE, &countWritten, &partialWritten);
Mock::VerifyAndClearExpectations(sock_.get());
EXPECT_EQ(countWritten, 3);
EXPECT_EQ(partialWritten, 100);
consumeVec(vec.get(), countWritten, partialWritten);
InSequence s2;
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 500))
.WillOnce(Invoke([this, &pos](SSL*, const void* buf, int m) {
.WillOnce(Invoke([=, &pos](SSL* ssl, const void* buf, int m) {
verifyVec(buf, m, pos);
pos += m;
return 500;
BIO* b = SSL_get_wbio(ssl);
auto result = AsyncSSLSocket::bioWrite(b, (const char*)buf, m);
pos += result;
return result;
}));
EXPECT_CALL(
*(sock_.get()), sendSocketMessage(_, _, MSG_DONTWAIT | MSG_NOSIGNAL))
.WillOnce(Return(ByMove(AsyncSocket::WriteResult(500))));
sock_->testPerformWrite(
vec.get() + countWritten,
n - countWritten,
WriteFlags::NONE,
&countWritten,
&partialWritten);
Mock::VerifyAndClearExpectations(sock_.get());
EXPECT_EQ(countWritten, 2);
EXPECT_EQ(partialWritten, 0);
}
// coalesce ends exactly on a buffer boundary
TEST_F(AsyncSSLSocketWriteTest, write_coalescing5) {
TEST_F(AsyncSSLSocketWriteTest, WriteCoalescing5) {
int n = 3;
auto vec = makeVec({1000, 500, 500});
int pos = 0;
InSequence s;
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
.WillOnce(Invoke([this, &pos](SSL*, const void* buf, int m) {
.WillOnce(Invoke([=, &pos](SSL* ssl, const void* buf, int m) {
verifyVec(buf, m, pos);
pos += m;
return m;
BIO* b = SSL_get_wbio(ssl);
auto result = AsyncSSLSocket::bioWrite(b, (const char*)buf, m);
pos += result;
return result;
}));
EXPECT_CALL(
*(sock_.get()),
sendSocketMessage(_, _, MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
.WillOnce(Return(ByMove(AsyncSocket::WriteResult(1500))));
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 500))
.WillOnce(Invoke([this, &pos](SSL*, const void* buf, int m) {
.WillOnce(Invoke([=, &pos](SSL* ssl, const void* buf, int m) {
verifyVec(buf, m, pos);
pos += m;
return m;
BIO* b = SSL_get_wbio(ssl);
auto result = AsyncSSLSocket::bioWrite(b, (const char*)buf, m);
pos += result;
return result;
}));
EXPECT_CALL(
*(sock_.get()), sendSocketMessage(_, _, MSG_DONTWAIT | MSG_NOSIGNAL))
.WillOnce(Return(ByMove(AsyncSocket::WriteResult(500))));
uint32_t countWritten = 0;
uint32_t partialWritten = 0;
sock_->testPerformWrite(
vec.get(), n, WriteFlags::NONE, &countWritten, &partialWritten);
Mock::VerifyAndClearExpectations(sock_.get());
EXPECT_EQ(countWritten, 3);
EXPECT_EQ(partialWritten, 0);
}
// partial write midway through first chunk
TEST_F(AsyncSSLSocketWriteTest, write_coalescing6) {
TEST_F(AsyncSSLSocketWriteTest, WriteCoalescing6) {
int n = 2;
auto vec = makeVec({1000, 500});
int pos = 0;
InSequence s1;
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
.WillOnce(Invoke([this, &pos](SSL*, const void* buf, int m) {
.WillOnce(Invoke([=, &pos](SSL* ssl, const void* buf, int m) {
verifyVec(buf, m, pos);
pos += 700;
return 700;
BIO* b = SSL_get_wbio(ssl);
auto result = AsyncSSLSocket::bioWrite(b, (const char*)buf, m);
pos += result;
return result;
}));
EXPECT_CALL(
*(sock_.get()),
sendSocketMessage(_, _, MSG_DONTWAIT | MSG_NOSIGNAL)) // no MSG_MORE
.WillOnce(Return(ByMove(AsyncSocket::WriteResult(700))));
uint32_t countWritten = 0;
uint32_t partialWritten = 0;
sock_->testPerformWrite(
vec.get(), n, WriteFlags::NONE, &countWritten, &partialWritten);
Mock::VerifyAndClearExpectations(sock_.get());
EXPECT_EQ(countWritten, 0);
EXPECT_EQ(partialWritten, 700);
consumeVec(vec.get(), countWritten, partialWritten);
InSequence s2;
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 800))
.WillOnce(Invoke([this, &pos](SSL*, const void* buf, int m) {
.WillOnce(Invoke([=, &pos](SSL* ssl, const void* buf, int m) {
verifyVec(buf, m, pos);
pos += m;
return m;
BIO* b = SSL_get_wbio(ssl);
auto result = AsyncSSLSocket::bioWrite(b, (const char*)buf, m);
pos += result;
return result;
}));
EXPECT_CALL(
*(sock_.get()), sendSocketMessage(_, _, MSG_DONTWAIT | MSG_NOSIGNAL))
.WillOnce(Return(ByMove(AsyncSocket::WriteResult(800))));
sock_->testPerformWrite(
vec.get() + countWritten,
n - countWritten,
WriteFlags::NONE,
&countWritten,
&partialWritten);
Mock::VerifyAndClearExpectations(sock_.get());
EXPECT_EQ(countWritten, 2);
EXPECT_EQ(partialWritten, 0);
}
// Repeat coalescing2 with WriteFlags::EOR
TEST_F(AsyncSSLSocketWriteTest, write_with_eor1) {
TEST_F(AsyncSSLSocketWriteTest, WriteCoalescingWithEoRTracking1) {
int n = 3;
auto vec = makeVec({1500, 3, 3});
int pos = 0;
const size_t initAppBytesWritten = 500;
const size_t appEor = initAppBytesWritten + 1506;
sock_->setAppBytesWritten(initAppBytesWritten);
EXPECT_FALSE(sock_->isEorTrackingEnabled());
sock_->setEorTracking(true);
EXPECT_TRUE(sock_->isEorTrackingEnabled());
EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
// rawBytesWritten after writing initAppBytesWritten + 1500
// + some random SSL overhead
.WillOnce(Return(3600u))
// rawBytesWritten after writing last 6 bytes
// + some random SSL overhead
.WillOnce(Return(3728u));
InSequence s;
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
.WillOnce(Invoke([=, &pos](SSL*, const void* buf, int m) {
.WillOnce(Invoke([=, &pos](SSL* ssl, const void* buf, int m) {
// the first 1500 does not have the EOR byte
sock_->checkEor(0, 0);
EXPECT_EQ(folly::none, sock_->getCurrBytesToFinalByte());
verifyVec(buf, m, pos);
pos += m;
return m;
BIO* b = SSL_get_wbio(ssl);
auto result = AsyncSSLSocket::bioWrite(b, (const char*)buf, m);
pos += result;
return result;
}));
EXPECT_CALL(
*(sock_.get()),
sendSocketMessage(_, _, MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
.WillOnce(Return(ByMove(AsyncSocket::WriteResult(1500))));
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 6))
.WillOnce(Invoke([=, &pos](SSL*, const void* buf, int m) {
sock_->checkEor(appEor, 3600 + m);
.WillOnce(Invoke([=, &pos](SSL* ssl, const void* buf, int m) {
EXPECT_EQ(m, sock_->getCurrBytesToFinalByte().value_or(0));
verifyVec(buf, m, pos);
pos += m;
return m;
BIO* b = SSL_get_wbio(ssl);
auto result = AsyncSSLSocket::bioWrite(b, (const char*)buf, m);
pos += result;
return result;
}));
EXPECT_CALL(
*(sock_.get()),
sendSocketMessage(_, _, MSG_EOR | MSG_DONTWAIT | MSG_NOSIGNAL))
.WillOnce(Return(ByMove(AsyncSocket::WriteResult(6))));
uint32_t countWritten = 0;
uint32_t partialWritten = 0;
......@@ -313,43 +456,44 @@ TEST_F(AsyncSSLSocketWriteTest, write_with_eor1) {
vec.get(), n, WriteFlags::EOR, &countWritten, &partialWritten);
EXPECT_EQ(countWritten, n);
EXPECT_EQ(partialWritten, 0);
sock_->checkEor(0, 0);
}
// coalescing with left over at the last chunk
// WriteFlags::EOR turned on
TEST_F(AsyncSSLSocketWriteTest, write_with_eor2) {
TEST_F(AsyncSSLSocketWriteTest, WriteCoalescingWithEoRTracking2) {
int n = 3;
auto vec = makeVec({600, 600, 600});
int pos = 0;
const size_t initAppBytesWritten = 500;
const size_t appEor = initAppBytesWritten + 1800;
sock_->setAppBytesWritten(initAppBytesWritten);
sock_->setEorTracking(true);
EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
// rawBytesWritten after writing initAppBytesWritten + 1500 bytes
// + some random SSL overhead
.WillOnce(Return(3600))
// rawBytesWritten after writing last 300 bytes
// + some random SSL overhead
.WillOnce(Return(4100));
InSequence s;
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
.WillOnce(Invoke([=, &pos](SSL*, const void* buf, int m) {
.WillOnce(Invoke([=, &pos](SSL* ssl, const void* buf, int m) {
// the first 1500 does not have the EOR byte
sock_->checkEor(0, 0);
EXPECT_EQ(folly::none, sock_->getCurrBytesToFinalByte());
verifyVec(buf, m, pos);
pos += m;
return m;
BIO* b = SSL_get_wbio(ssl);
auto result = AsyncSSLSocket::bioWrite(b, (const char*)buf, m);
pos += result;
return result;
}));
EXPECT_CALL(
*(sock_.get()),
sendSocketMessage(_, _, MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
.WillOnce(Return(ByMove(AsyncSocket::WriteResult(1500))));
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 300))
.WillOnce(Invoke([=, &pos](SSL*, const void* buf, int m) {
sock_->checkEor(appEor, 3600 + m);
.WillOnce(Invoke([=, &pos](SSL* ssl, const void* buf, int m) {
EXPECT_EQ(m, sock_->getCurrBytesToFinalByte().value_or(0));
verifyVec(buf, m, pos);
pos += m;
return m;
BIO* b = SSL_get_wbio(ssl);
auto result = AsyncSSLSocket::bioWrite(b, (const char*)buf, m);
pos += result;
return result;
}));
EXPECT_CALL(
*(sock_.get()),
sendSocketMessage(_, _, MSG_EOR | MSG_DONTWAIT | MSG_NOSIGNAL))
.WillOnce(Return(ByMove(AsyncSocket::WriteResult(300))));
uint32_t countWritten = 0;
uint32_t partialWritten = 0;
......@@ -357,66 +501,127 @@ TEST_F(AsyncSSLSocketWriteTest, write_with_eor2) {
vec.get(), n, WriteFlags::EOR, &countWritten, &partialWritten);
EXPECT_EQ(countWritten, n);
EXPECT_EQ(partialWritten, 0);
sock_->checkEor(0, 0);
}
// WriteFlags::EOR set
// One buf in iovec
// Partial write at 1000-th byte
TEST_F(AsyncSSLSocketWriteTest, write_with_eor3) {
TEST_F(AsyncSSLSocketWriteTest, WriteCoalescingWithEoRTracking3) {
int n = 1;
auto vec = makeVec({1600});
int pos = 0;
static constexpr size_t initAppBytesWritten = 500;
static constexpr size_t appEor = initAppBytesWritten + 1600;
sock_->setAppBytesWritten(initAppBytesWritten);
sock_->setEorTracking(true);
EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
// rawBytesWritten after the initAppBytesWritten
// + some random SSL overhead
.WillOnce(Return(2000))
// rawBytesWritten after the initAppBytesWritten + 1000 (with 100
// overhead)
// + some random SSL overhead
.WillOnce(Return(3100));
InSequence s;
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1600))
.WillOnce(Invoke([this, &pos](SSL*, const void* buf, int m) {
sock_->checkEor(appEor, 2000 + m);
.WillOnce(Invoke([=, &pos](SSL* ssl, const void* buf, int m) {
// partial write of 1000 bytes
// currBytesToFinalByte should be 1600 at this point; expect full write
EXPECT_EQ(1600, sock_->getCurrBytesToFinalByte().value_or(0));
verifyVec(buf, m, pos);
pos += 1000;
return 1000;
BIO* b = SSL_get_wbio(ssl);
auto result = AsyncSSLSocket::bioWrite(b, (const char*)buf, m);
pos += result;
return result;
}));
EXPECT_CALL(
*(sock_.get()),
sendSocketMessage(_, _, MSG_EOR | MSG_DONTWAIT | MSG_NOSIGNAL))
.WillOnce(Return(ByMove(AsyncSocket::WriteResult(1000))));
uint32_t countWritten = 0;
uint32_t partialWritten = 0;
sock_->testPerformWrite(
vec.get(), n, WriteFlags::EOR, &countWritten, &partialWritten);
Mock::VerifyAndClearExpectations(sock_.get());
EXPECT_EQ(countWritten, 0);
EXPECT_EQ(partialWritten, 1000);
sock_->checkEor(appEor, 2000 + 1600);
consumeVec(vec.get(), countWritten, partialWritten);
EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
.WillOnce(Return(3100))
.WillOnce(Return(3800));
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 600))
.WillOnce(Invoke([this, &pos](SSL*, const void* buf, int m) {
sock_->checkEor(appEor, 3100 + m);
.WillOnce(Invoke([=, &pos](SSL* ssl, const void* buf, int m) {
EXPECT_EQ(m, sock_->getCurrBytesToFinalByte().value_or(0));
verifyVec(buf, m, pos);
pos += m;
return m;
BIO* b = SSL_get_wbio(ssl);
auto result = AsyncSSLSocket::bioWrite(b, (const char*)buf, m);
pos += result;
return result;
}));
EXPECT_CALL(
*(sock_.get()),
sendSocketMessage(_, _, MSG_EOR | MSG_DONTWAIT | MSG_NOSIGNAL))
.WillOnce(Return(ByMove(AsyncSocket::WriteResult(600))));
sock_->testPerformWrite(
vec.get() + countWritten,
n - countWritten,
WriteFlags::EOR,
&countWritten,
&partialWritten);
Mock::VerifyAndClearExpectations(sock_.get());
EXPECT_EQ(countWritten, n);
EXPECT_EQ(partialWritten, 0);
}
// WriteFlags::EOR set
// SSL_ERROR_WANT_WRITE occurs on first write
TEST_F(AsyncSSLSocketWriteTest, WriteCoalescingWithEoRTrackingErrorWantWrite) {
int n = 1;
auto vec = makeVec({1500});
int pos = 0;
sock_->setEorTracking(true);
// first time we try to write, SSL_ERROR_WANT_WRITE will be returned
//
// this means no bytes were actually written to the socket,
// but getRawBytesWritten will still be incremented by the write size as
// the bytes were appended to the BIO
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
.WillOnce(Invoke([=, &pos](SSL* ssl, const void* buf, int m) {
EXPECT_EQ(m, sock_->getCurrBytesToFinalByte().value_or(0));
verifyVec(buf, m, pos);
BIO* b = SSL_get_wbio(ssl);
auto result = AsyncSSLSocket::bioWrite(b, (const char*)buf, m);
pos += result;
return result;
}));
EXPECT_CALL(
*(sock_.get()),
sendSocketMessage(_, _, MSG_EOR | MSG_DONTWAIT | MSG_NOSIGNAL))
.WillOnce(Return(ByMove(AsyncSocket::WriteResult(0))));
EXPECT_CALL(*(sock_.get()), sslGetErrorImpl(_, _))
.WillOnce(Return(SSL_ERROR_WANT_WRITE));
ON_CALL( // should not be called, unless implementation changes to use it
*(sock_.get()),
getRawBytesWritten())
.WillByDefault(Return(1500));
uint32_t countWritten = 0;
uint32_t partialWritten = 0;
sock_->testPerformWrite(
vec.get(), n, WriteFlags::EOR, &countWritten, &partialWritten);
Mock::VerifyAndClearExpectations(sock_.get());
EXPECT_EQ(countWritten, 0);
EXPECT_EQ(partialWritten, 0);
// second time we try to write, no error
// EOR should still be set
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
.WillOnce(Invoke([=, &pos](SSL* ssl, const void* buf, int m) {
EXPECT_EQ(m, sock_->getCurrBytesToFinalByte().value_or(0));
verifyVec(buf, m, pos);
BIO* b = SSL_get_wbio(ssl);
auto result = AsyncSSLSocket::bioWrite(b, (const char*)buf, m);
pos += result;
return result;
}));
EXPECT_CALL(
*(sock_.get()),
sendSocketMessage(_, _, MSG_EOR | MSG_DONTWAIT | MSG_NOSIGNAL))
.WillOnce(Return(ByMove(AsyncSocket::WriteResult(1500))));
sock_->testPerformWrite(
vec.get(), n, WriteFlags::EOR, &countWritten, &partialWritten);
Mock::VerifyAndClearExpectations(sock_.get());
EXPECT_EQ(countWritten, n);
EXPECT_EQ(partialWritten, 0);
sock_->checkEor(0, 0);
}
} // namespace folly
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment