Commit b6a67aee authored by Xiangyu Bu's avatar Xiangyu Bu Committed by Facebook Github Bot

Clang-format AsyncSSLSocketTest.cpp.

Summary: ... as titled.

Reviewed By: yfeldblum

Differential Revision: D5558742

fbshipit-source-id: b63b121cde8db93de4cabc80563539297611d600
parent 647dba2f
...@@ -69,8 +69,8 @@ void getfds(int fds[2]) { ...@@ -69,8 +69,8 @@ void getfds(int fds[2]) {
<< strerror(errno); << strerror(errno);
} }
if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) { if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
FAIL() << "failed to put socket " << idx << " in non-blocking mode: " FAIL() << "failed to put socket " << idx
<< strerror(errno); << " in non-blocking mode: " << strerror(errno);
} }
} }
} }
...@@ -94,28 +94,32 @@ void sslsocketpair( ...@@ -94,28 +94,32 @@ void sslsocketpair(
int fds[2]; int fds[2];
getfds(fds); getfds(fds);
getctx(clientCtx, serverCtx); getctx(clientCtx, serverCtx);
clientSock->reset(new AsyncSSLSocket( clientSock->reset(new AsyncSSLSocket(clientCtx, eventBase, fds[0], false));
clientCtx, eventBase, fds[0], false)); serverSock->reset(new AsyncSSLSocket(serverCtx, eventBase, fds[1], true));
serverSock->reset(new AsyncSSLSocket(
serverCtx, eventBase, fds[1], true));
// (*clientSock)->setSendTimeout(100); // (*clientSock)->setSendTimeout(100);
// (*serverSock)->setSendTimeout(100); // (*serverSock)->setSendTimeout(100);
} }
// client protocol filters // client protocol filters
bool clientProtoFilterPickPony(unsigned char** client, bool clientProtoFilterPickPony(
unsigned int* client_len, const unsigned char*, unsigned int ) { unsigned char** client,
//the protocol string in length prefixed byte string. the unsigned int* client_len,
//length byte is not included in the length const unsigned char*,
static unsigned char p[7] = {6,'p','o','n','i','e','s'}; unsigned int) {
// the protocol string in length prefixed byte string. the
// length byte is not included in the length
static unsigned char p[7] = {6, 'p', 'o', 'n', 'i', 'e', 's'};
*client = p; *client = p;
*client_len = 7; *client_len = 7;
return true; return true;
} }
bool clientProtoFilterPickNone(unsigned char**, unsigned int*, bool clientProtoFilterPickNone(
const unsigned char*, unsigned int) { unsigned char**,
unsigned int*,
const unsigned char*,
unsigned int) {
return false; return false;
} }
...@@ -149,12 +153,12 @@ TEST(AsyncSSLSocketTest, ConnectWriteReadClose) { ...@@ -149,12 +153,12 @@ TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
// Set up SSL context. // Set up SSL context.
std::shared_ptr<SSLContext> sslContext(new SSLContext()); std::shared_ptr<SSLContext> sslContext(new SSLContext());
sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
//sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem"); // sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
//sslContext->authenticate(true, false); // sslContext->authenticate(true, false);
// connect // connect
auto socket = std::make_shared<BlockingSocket>(server.getAddress(), auto socket =
sslContext); std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
socket->open(std::chrono::milliseconds(10000)); socket->open(std::chrono::milliseconds(10000));
// write() // write()
...@@ -269,8 +273,8 @@ TEST(AsyncSSLSocketTest, HandshakeError) { ...@@ -269,8 +273,8 @@ TEST(AsyncSSLSocketTest, HandshakeError) {
sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
// connect // connect
auto socket = std::make_shared<BlockingSocket>(server.getAddress(), auto socket =
sslContext); std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
// read() // read()
bool ex = false; bool ex = false;
try { try {
...@@ -305,8 +309,8 @@ TEST(AsyncSSLSocketTest, ReadError) { ...@@ -305,8 +309,8 @@ TEST(AsyncSSLSocketTest, ReadError) {
sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
// connect // connect
auto socket = std::make_shared<BlockingSocket>(server.getAddress(), auto socket =
sslContext); std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
socket->open(); socket->open();
// write something to trigger ssl handshake // write something to trigger ssl handshake
...@@ -334,8 +338,8 @@ TEST(AsyncSSLSocketTest, WriteError) { ...@@ -334,8 +338,8 @@ TEST(AsyncSSLSocketTest, WriteError) {
sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
// connect // connect
auto socket = std::make_shared<BlockingSocket>(server.getAddress(), auto socket =
sslContext); std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
socket->open(); socket->open();
// write something to trigger ssl handshake // write something to trigger ssl handshake
...@@ -363,8 +367,8 @@ TEST(AsyncSSLSocketTest, SocketWithDelay) { ...@@ -363,8 +367,8 @@ TEST(AsyncSSLSocketTest, SocketWithDelay) {
sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
// connect // connect
auto socket = std::make_shared<BlockingSocket>(server.getAddress(), auto socket =
sslContext); std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
socket->open(); socket->open();
// write() // write()
...@@ -390,7 +394,9 @@ using NextProtocolTypePair = ...@@ -390,7 +394,9 @@ using NextProtocolTypePair =
class NextProtocolTest : public testing::TestWithParam<NextProtocolTypePair> { class NextProtocolTest : public testing::TestWithParam<NextProtocolTypePair> {
// For matching protos // For matching protos
public: public:
void SetUp() override { getctx(clientCtx, serverCtx); } void SetUp() override {
getctx(clientCtx, serverCtx);
}
void connect(bool unset = false) { void connect(bool unset = false) {
getfds(fds); getfds(fds);
...@@ -485,8 +491,8 @@ class NextProtocolMismatchTest : public NextProtocolTest { ...@@ -485,8 +491,8 @@ class NextProtocolMismatchTest : public NextProtocolTest {
TEST_P(NextProtocolTest, NpnTestOverlap) { TEST_P(NextProtocolTest, NpnTestOverlap) {
clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first); clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"}, serverCtx->setAdvertisedNextProtocols(
GetParam().second); {"foo", "bar", "baz"}, GetParam().second);
connect(); connect();
...@@ -498,8 +504,8 @@ TEST_P(NextProtocolTest, NpnTestUnset) { ...@@ -498,8 +504,8 @@ TEST_P(NextProtocolTest, NpnTestUnset) {
// Identical to above test, except that we want unset NPN before // Identical to above test, except that we want unset NPN before
// looping. // looping.
clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first); clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"}, serverCtx->setAdvertisedNextProtocols(
GetParam().second); {"foo", "bar", "baz"}, GetParam().second);
connect(true /* unset */); connect(true /* unset */);
...@@ -510,8 +516,8 @@ TEST_P(NextProtocolTest, NpnTestUnset) { ...@@ -510,8 +516,8 @@ TEST_P(NextProtocolTest, NpnTestUnset) {
TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) { TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) {
clientCtx->setAdvertisedNextProtocols({"foo"}, GetParam().first); clientCtx->setAdvertisedNextProtocols({"foo"}, GetParam().first);
serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"}, serverCtx->setAdvertisedNextProtocols(
GetParam().second); {"foo", "bar", "baz"}, GetParam().second);
connect(); connect();
...@@ -524,8 +530,8 @@ TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) { ...@@ -524,8 +530,8 @@ TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) {
// will fail on 1.0.2 before that. // will fail on 1.0.2 before that.
TEST_P(NextProtocolTest, NpnTestNoOverlap) { TEST_P(NextProtocolTest, NpnTestNoOverlap) {
clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first); clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"}, serverCtx->setAdvertisedNextProtocols(
GetParam().second); {"foo", "bar", "baz"}, GetParam().second);
connect(); connect();
if (GetParam().first == SSLContext::NextProtocolType::ALPN || if (GetParam().first == SSLContext::NextProtocolType::ALPN ||
...@@ -539,16 +545,16 @@ TEST_P(NextProtocolTest, NpnTestNoOverlap) { ...@@ -539,16 +545,16 @@ TEST_P(NextProtocolTest, NpnTestNoOverlap) {
else if ( else if (
GetParam().first == SSLContext::NextProtocolType::ANY && GetParam().first == SSLContext::NextProtocolType::ANY &&
GetParam().second == SSLContext::NextProtocolType::ANY) { GetParam().second == SSLContext::NextProtocolType::ANY) {
# if FOLLY_OPENSSL_IS_110 #if FOLLY_OPENSSL_IS_110
// OpenSSL 1.1.0 sends a fatal alert on mismatch, which is probavbly the // OpenSSL 1.1.0 sends a fatal alert on mismatch, which is probavbly the
// correct behavior per RFC7301 // correct behavior per RFC7301
expectHandshakeError(); expectHandshakeError();
# else #else
// BoringSSL also doesn't fatal on mismatch but behaves slightly differently // BoringSSL also doesn't fatal on mismatch but behaves slightly differently
// from OpenSSL 1.0.2h+ - it doesn't select a protocol if both ends support // from OpenSSL 1.0.2h+ - it doesn't select a protocol if both ends support
// NPN *and* ALPN // NPN *and* ALPN
expectNoProtocol(); expectNoProtocol();
# endif #endif
} }
#endif #endif
else { else {
...@@ -561,8 +567,8 @@ TEST_P(NextProtocolTest, NpnTestNoOverlap) { ...@@ -561,8 +567,8 @@ TEST_P(NextProtocolTest, NpnTestNoOverlap) {
TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) { TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) {
clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first); clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony); clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"}, serverCtx->setAdvertisedNextProtocols(
GetParam().second); {"foo", "bar", "baz"}, GetParam().second);
connect(); connect();
...@@ -573,8 +579,8 @@ TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) { ...@@ -573,8 +579,8 @@ TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) {
TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterMiss) { TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterMiss) {
clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first); clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone); clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"}, serverCtx->setAdvertisedNextProtocols(
GetParam().second); {"foo", "bar", "baz"}, GetParam().second);
connect(); connect();
...@@ -587,10 +593,10 @@ TEST_P(NextProtocolTest, RandomizedNpnTest) { ...@@ -587,10 +593,10 @@ TEST_P(NextProtocolTest, RandomizedNpnTest) {
// as negligible. // as negligible.
const int kTries = 64; const int kTries = 64;
clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"}, clientCtx->setAdvertisedNextProtocols(
GetParam().first); {"foo", "bar", "baz"}, GetParam().first);
serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}}, serverCtx->setRandomizedAdvertisedNextProtocols(
GetParam().second); {{1, {"foo"}}, {1, {"bar"}}}, GetParam().second);
std::set<string> selectedProtocols; std::set<string> selectedProtocols;
for (int i = 0; i < kTries; ++i) { for (int i = 0; i < kTries; ++i) {
...@@ -641,16 +647,20 @@ INSTANTIATE_TEST_CASE_P( ...@@ -641,16 +647,20 @@ INSTANTIATE_TEST_CASE_P(
INSTANTIATE_TEST_CASE_P( INSTANTIATE_TEST_CASE_P(
AsyncSSLSocketTest, AsyncSSLSocketTest,
NextProtocolNPNOnlyTest, NextProtocolNPNOnlyTest,
::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN, ::testing::Values(NextProtocolTypePair(
SSLContext::NextProtocolType::NPN,
SSLContext::NextProtocolType::NPN))); SSLContext::NextProtocolType::NPN)));
#if FOLLY_OPENSSL_HAS_ALPN #if FOLLY_OPENSSL_HAS_ALPN
INSTANTIATE_TEST_CASE_P( INSTANTIATE_TEST_CASE_P(
AsyncSSLSocketTest, AsyncSSLSocketTest,
NextProtocolMismatchTest, NextProtocolMismatchTest,
::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN, ::testing::Values(
NextProtocolTypePair(
SSLContext::NextProtocolType::NPN,
SSLContext::NextProtocolType::ALPN), SSLContext::NextProtocolType::ALPN),
NextProtocolTypePair(SSLContext::NextProtocolType::ALPN, NextProtocolTypePair(
SSLContext::NextProtocolType::ALPN,
SSLContext::NextProtocolType::NPN))); SSLContext::NextProtocolType::NPN)));
#endif #endif
...@@ -678,10 +688,8 @@ TEST(AsyncSSLSocketTest, SNITestMatch) { ...@@ -678,10 +688,8 @@ TEST(AsyncSSLSocketTest, SNITestMatch) {
AsyncSSLSocket::UniquePtr serverSock( AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true)); new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SNIClient client(std::move(clientSock)); SNIClient client(std::move(clientSock));
SNIServer server(std::move(serverSock), SNIServer server(
dfServerCtx, std::move(serverSock), dfServerCtx, hskServerCtx, serverName);
hskServerCtx,
serverName);
eventBase.loop(); eventBase.loop();
...@@ -709,15 +717,13 @@ TEST(AsyncSSLSocketTest, SNITestNotMatch) { ...@@ -709,15 +717,13 @@ TEST(AsyncSSLSocketTest, SNITestNotMatch) {
getfds(fds); getfds(fds);
getctx(clientCtx, dfServerCtx); getctx(clientCtx, dfServerCtx);
AsyncSSLSocket::UniquePtr clientSock( AsyncSSLSocket::UniquePtr clientSock(new AsyncSSLSocket(
new AsyncSSLSocket(clientCtx, clientCtx, &eventBase, fds[0], clientRequestingServerName));
&eventBase,
fds[0],
clientRequestingServerName));
AsyncSSLSocket::UniquePtr serverSock( AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true)); new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SNIClient client(std::move(clientSock)); SNIClient client(std::move(clientSock));
SNIServer server(std::move(serverSock), SNIServer server(
std::move(serverSock),
dfServerCtx, dfServerCtx,
hskServerCtx, hskServerCtx,
serverExpectedServerName); serverExpectedServerName);
...@@ -747,16 +753,14 @@ TEST(AsyncSSLSocketTest, SNITestChangeServerName) { ...@@ -747,16 +753,14 @@ TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
AsyncSSLSocket::UniquePtr clientSock( AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName)); new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
//Change the server name // Change the server name
std::string newName("new.com"); std::string newName("new.com");
clientSock->setServerName(newName); clientSock->setServerName(newName);
AsyncSSLSocket::UniquePtr serverSock( AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true)); new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SNIClient client(std::move(clientSock)); SNIClient client(std::move(clientSock));
SNIServer server(std::move(serverSock), SNIServer server(
dfServerCtx, std::move(serverSock), dfServerCtx, hskServerCtx, serverName);
hskServerCtx,
serverName);
eventBase.loop(); eventBase.loop();
...@@ -785,7 +789,8 @@ TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) { ...@@ -785,7 +789,8 @@ TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
AsyncSSLSocket::UniquePtr serverSock( AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true)); new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SNIClient client(std::move(clientSock)); SNIClient client(std::move(clientSock));
SNIServer server(std::move(serverSock), SNIServer server(
std::move(serverSock),
dfServerCtx, dfServerCtx,
hskServerCtx, hskServerCtx,
serverExpectedServerName); serverExpectedServerName);
...@@ -822,7 +827,6 @@ TEST(AsyncSSLSocketTest, SSLClientTest) { ...@@ -822,7 +827,6 @@ TEST(AsyncSSLSocketTest, SSLClientTest) {
cerr << "SSLClientTest test completed" << endl; cerr << "SSLClientTest test completed" << endl;
} }
/** /**
* Test SSL client socket session re-use * Test SSL client socket session re-use
*/ */
...@@ -855,8 +859,8 @@ TEST(AsyncSSLSocketTest, SSLClientTestReuse) { ...@@ -855,8 +859,8 @@ TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) { TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
// Start listening on a local port // Start listening on a local port
EmptyReadCallback readCallback; EmptyReadCallback readCallback;
HandshakeCallback handshakeCallback(&readCallback, HandshakeCallback handshakeCallback(
HandshakeCallback::EXPECT_ERROR); &readCallback, HandshakeCallback::EXPECT_ERROR);
HandshakeTimeoutCallback acceptCallback(&handshakeCallback); HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback); TestSSLServer server(&acceptCallback);
...@@ -973,8 +977,8 @@ TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) { ...@@ -973,8 +977,8 @@ TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
// Start listening on a local port // Start listening on a local port
WriteCallbackBase writeCallback; WriteCallbackBase writeCallback;
ReadCallback readCallback(&writeCallback); ReadCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback, HandshakeCallback handshakeCallback(
HandshakeCallback::EXPECT_ERROR); &readCallback, HandshakeCallback::EXPECT_ERROR);
SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback); SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
TestSSLAsyncCacheServer server(&acceptCallback, 500); TestSSLAsyncCacheServer server(&acceptCallback, 500);
...@@ -987,8 +991,8 @@ TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) { ...@@ -987,8 +991,8 @@ TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
EventBaseAborter eba(&eventBase, 3000); EventBaseAborter eba(&eventBase, 3000);
eventBase.loop(); eventBase.loop();
server.getEventBase().runInEventBaseThread([&handshakeCallback]{ server.getEventBase().runInEventBaseThread(
handshakeCallback.closeSocket();}); [&handshakeCallback] { handshakeCallback.closeSocket(); });
// give time for the cache lookup to come back and find it closed // give time for the cache lookup to come back and find it closed
handshakeCallback.waitForHandshake(); handshakeCallback.waitForHandshake();
...@@ -1073,7 +1077,9 @@ TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) { ...@@ -1073,7 +1077,9 @@ TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
cursor.write<uint32_t>(0); cursor.write<uint32_t>(0);
SSL* ssl = ctx->createSSL(); SSL* ssl = ctx->createSSL();
SCOPE_EXIT { SSL_free(ssl); }; SCOPE_EXIT {
SSL_free(ssl);
};
AsyncSSLSocket::UniquePtr sock( AsyncSSLSocket::UniquePtr sock(
new AsyncSSLSocket(ctx, &eventBase, fds[0], true)); new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
sock->enableClientHelloParsing(); sock->enableClientHelloParsing();
...@@ -1113,7 +1119,9 @@ TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) { ...@@ -1113,7 +1119,9 @@ TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
cursor.write<uint32_t>(0); cursor.write<uint32_t>(0);
SSL* ssl = ctx->createSSL(); SSL* ssl = ctx->createSSL();
SCOPE_EXIT { SSL_free(ssl); }; SCOPE_EXIT {
SSL_free(ssl);
};
AsyncSSLSocket::UniquePtr sock( AsyncSSLSocket::UniquePtr sock(
new AsyncSSLSocket(ctx, &eventBase, fds[0], true)); new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
sock->enableClientHelloParsing(); sock->enableClientHelloParsing();
...@@ -1121,13 +1129,23 @@ TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) { ...@@ -1121,13 +1129,23 @@ TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
// Test parsing with two packets with first packet size < 3 // Test parsing with two packets with first packet size < 3
auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2); auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
AsyncSSLSocket::clientHelloParsingCallback( AsyncSSLSocket::clientHelloParsingCallback(
0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(), 0,
ssl, sock.get()); 0,
SSL3_RT_HANDSHAKE,
bufCopy->data(),
bufCopy->length(),
ssl,
sock.get());
bufCopy.reset(); bufCopy.reset();
bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2); bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
AsyncSSLSocket::clientHelloParsingCallback( AsyncSSLSocket::clientHelloParsingCallback(
0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(), 0,
ssl, sock.get()); 0,
SSL3_RT_HANDSHAKE,
bufCopy->data(),
bufCopy->length(),
ssl,
sock.get());
bufCopy.reset(); bufCopy.reset();
auto parsedClientHello = sock->getClientHelloInfo(); auto parsedClientHello = sock->getClientHelloInfo();
...@@ -1160,7 +1178,9 @@ TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) { ...@@ -1160,7 +1178,9 @@ TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
cursor.write<uint32_t>(0); cursor.write<uint32_t>(0);
SSL* ssl = ctx->createSSL(); SSL* ssl = ctx->createSSL();
SCOPE_EXIT { SSL_free(ssl); }; SCOPE_EXIT {
SSL_free(ssl);
};
AsyncSSLSocket::UniquePtr sock( AsyncSSLSocket::UniquePtr sock(
new AsyncSSLSocket(ctx, &eventBase, fds[0], true)); new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
sock->enableClientHelloParsing(); sock->enableClientHelloParsing();
...@@ -1170,8 +1190,13 @@ TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) { ...@@ -1170,8 +1190,13 @@ TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
auto bufCopy = folly::IOBuf::copyBuffer( auto bufCopy = folly::IOBuf::copyBuffer(
buf->data() + i, std::min((uint64_t)3, buf->length() - i)); buf->data() + i, std::min((uint64_t)3, buf->length() - i));
AsyncSSLSocket::clientHelloParsingCallback( AsyncSSLSocket::clientHelloParsingCallback(
0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(), 0,
ssl, sock.get()); 0,
SSL3_RT_HANDSHAKE,
bufCopy->data(),
bufCopy->length(),
ssl,
sock.get());
bufCopy.reset(); bufCopy.reset();
} }
...@@ -1458,7 +1483,6 @@ TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) { ...@@ -1458,7 +1483,6 @@ TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
EXPECT_LE(0, server.handshakeTime.count()); EXPECT_LE(0, server.handshakeTime.count());
} }
/** /**
* Test requireClientCert with no client cert * Test requireClientCert with no client cert
*/ */
...@@ -1557,9 +1581,8 @@ TEST(AsyncSSLSocketTest, MinWriteSizeTest) { ...@@ -1557,9 +1581,8 @@ TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
class ReadCallbackTerminator : public ReadCallback { class ReadCallbackTerminator : public ReadCallback {
public: public:
ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb) ReadCallbackTerminator(EventBase* base, WriteCallbackBase* wcb)
: ReadCallback(wcb) : ReadCallback(wcb), base_(base) {}
, base_(base) {}
// Do not write data back, terminate the loop. // Do not write data back, terminate the loop.
void readDataAvailable(size_t len) noexcept override { void readDataAvailable(size_t len) noexcept override {
...@@ -1574,11 +1597,11 @@ class ReadCallbackTerminator : public ReadCallback { ...@@ -1574,11 +1597,11 @@ class ReadCallbackTerminator : public ReadCallback {
socket_->setReadCB(nullptr); socket_->setReadCB(nullptr);
base_->terminateLoopSoon(); base_->terminateLoopSoon();
} }
private: private:
EventBase* base_; EventBase* base_;
}; };
/** /**
* Test a full unencrypted codepath * Test a full unencrypted codepath
*/ */
...@@ -1590,10 +1613,9 @@ TEST(AsyncSSLSocketTest, UnencryptedTest) { ...@@ -1590,10 +1613,9 @@ TEST(AsyncSSLSocketTest, UnencryptedTest) {
int fds[2]; int fds[2];
getfds(fds); getfds(fds);
getctx(clientCtx, serverCtx); getctx(clientCtx, serverCtx);
auto client = AsyncSSLSocket::newSocket( auto client =
clientCtx, &base, fds[0], false, true); AsyncSSLSocket::newSocket(clientCtx, &base, fds[0], false, true);
auto server = AsyncSSLSocket::newSocket( auto server = AsyncSSLSocket::newSocket(serverCtx, &base, fds[1], true, true);
serverCtx, &base, fds[1], true, true);
ReadCallbackTerminator readCallback(&base, nullptr); ReadCallbackTerminator readCallback(&base, nullptr);
server->setReadCB(&readCallback); server->setReadCB(&readCallback);
...@@ -1629,7 +1651,6 @@ TEST(AsyncSSLSocketTest, UnencryptedTest) { ...@@ -1629,7 +1651,6 @@ TEST(AsyncSSLSocketTest, UnencryptedTest) {
EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK)); EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
EXPECT_NE('a', c2); EXPECT_NE('a', c2);
base.loop(); base.loop();
EXPECT_EQ(2, readCallback.buffers.size()); EXPECT_EQ(2, readCallback.buffers.size());
...@@ -1671,8 +1692,8 @@ TEST(AsyncSSLSocketTest, ConnResetErrorString) { ...@@ -1671,8 +1692,8 @@ TEST(AsyncSSLSocketTest, ConnResetErrorString) {
// Start listening on a local port // Start listening on a local port
WriteCallbackBase writeCallback; WriteCallbackBase writeCallback;
WriteErrorCallback readCallback(&writeCallback); WriteErrorCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback, HandshakeCallback handshakeCallback(
HandshakeCallback::EXPECT_ERROR); &readCallback, HandshakeCallback::EXPECT_ERROR);
SSLServerAcceptCallback acceptCallback(&handshakeCallback); SSLServerAcceptCallback acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback); TestSSLServer server(&acceptCallback);
...@@ -1692,8 +1713,8 @@ TEST(AsyncSSLSocketTest, ConnEOFErrorString) { ...@@ -1692,8 +1713,8 @@ TEST(AsyncSSLSocketTest, ConnEOFErrorString) {
// Start listening on a local port // Start listening on a local port
WriteCallbackBase writeCallback; WriteCallbackBase writeCallback;
WriteErrorCallback readCallback(&writeCallback); WriteErrorCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback, HandshakeCallback handshakeCallback(
HandshakeCallback::EXPECT_ERROR); &readCallback, HandshakeCallback::EXPECT_ERROR);
SSLServerAcceptCallback acceptCallback(&handshakeCallback); SSLServerAcceptCallback acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback); TestSSLServer server(&acceptCallback);
...@@ -1717,8 +1738,8 @@ TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) { ...@@ -1717,8 +1738,8 @@ TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
// Start listening on a local port // Start listening on a local port
WriteCallbackBase writeCallback; WriteCallbackBase writeCallback;
WriteErrorCallback readCallback(&writeCallback); WriteErrorCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback, HandshakeCallback handshakeCallback(
HandshakeCallback::EXPECT_ERROR); &readCallback, HandshakeCallback::EXPECT_ERROR);
SSLServerAcceptCallback acceptCallback(&handshakeCallback); SSLServerAcceptCallback acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback); TestSSLServer server(&acceptCallback);
...@@ -1730,17 +1751,19 @@ TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) { ...@@ -1730,17 +1751,19 @@ TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
socket->close(); socket->close();
handshakeCallback.waitForHandshake(); handshakeCallback.waitForHandshake();
EXPECT_NE(handshakeCallback.errorString_.find("SSL routines"), EXPECT_NE(
std::string::npos); handshakeCallback.errorString_.find("SSL routines"), std::string::npos);
#if defined(OPENSSL_IS_BORINGSSL) #if defined(OPENSSL_IS_BORINGSSL)
EXPECT_NE( EXPECT_NE(
handshakeCallback.errorString_.find("ENCRYPTED_LENGTH_TOO_LONG"), handshakeCallback.errorString_.find("ENCRYPTED_LENGTH_TOO_LONG"),
std::string::npos); std::string::npos);
#elif FOLLY_OPENSSL_IS_110 #elif FOLLY_OPENSSL_IS_110
EXPECT_NE(handshakeCallback.errorString_.find("packet length too long"), EXPECT_NE(
handshakeCallback.errorString_.find("packet length too long"),
std::string::npos); std::string::npos);
#else #else
EXPECT_NE(handshakeCallback.errorString_.find("unknown protocol"), EXPECT_NE(
handshakeCallback.errorString_.find("unknown protocol"),
std::string::npos); std::string::npos);
#endif #endif
} }
...@@ -2077,8 +2100,8 @@ TEST(AsyncSSLSocketTest, SendMsgParamsCallback) { ...@@ -2077,8 +2100,8 @@ TEST(AsyncSSLSocketTest, SendMsgParamsCallback) {
sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
// connect // connect
auto socket = std::make_shared<BlockingSocket>(server.getAddress(), auto socket =
sslContext); std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
socket->open(); socket->open();
// Setting flags to "-1" to trigger "Invalid argument" error // Setting flags to "-1" to trigger "Invalid argument" error
...@@ -2129,13 +2152,13 @@ TEST(AsyncSSLSocketTest, SendMsgDataCallback) { ...@@ -2129,13 +2152,13 @@ TEST(AsyncSSLSocketTest, SendMsgDataCallback) {
sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
// connect // connect
auto socket = std::make_shared<BlockingSocket>(server.getAddress(), auto socket =
sslContext); std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
socket->open(); socket->open();
// Adding MSG_EOR flag to the message flags - it'll trigger // Adding MSG_EOR flag to the message flags - it'll trigger
// timestamp generation for the last byte of the message. // timestamp generation for the last byte of the message.
msgCallback.resetFlags(MSG_DONTWAIT|MSG_NOSIGNAL|MSG_EOR); msgCallback.resetFlags(MSG_DONTWAIT | MSG_NOSIGNAL | MSG_EOR);
// Init ancillary data buffer to trigger timestamp notification // Init ancillary data buffer to trigger timestamp notification
union { union {
...@@ -2145,9 +2168,7 @@ TEST(AsyncSSLSocketTest, SendMsgDataCallback) { ...@@ -2145,9 +2168,7 @@ TEST(AsyncSSLSocketTest, SendMsgDataCallback) {
u.cmsg.cmsg_level = SOL_SOCKET; u.cmsg.cmsg_level = SOL_SOCKET;
u.cmsg.cmsg_type = SO_TIMESTAMPING; u.cmsg.cmsg_type = SO_TIMESTAMPING;
u.cmsg.cmsg_len = CMSG_LEN(sizeof(uint32_t)); u.cmsg.cmsg_len = CMSG_LEN(sizeof(uint32_t));
uint32_t flags = uint32_t flags = SOF_TIMESTAMPING_TX_SCHED | SOF_TIMESTAMPING_TX_SOFTWARE |
SOF_TIMESTAMPING_TX_SCHED |
SOF_TIMESTAMPING_TX_SOFTWARE |
SOF_TIMESTAMPING_TX_ACK; SOF_TIMESTAMPING_TX_ACK;
memcpy(CMSG_DATA(&u.cmsg), &flags, sizeof(uint32_t)); memcpy(CMSG_DATA(&u.cmsg), &flags, sizeof(uint32_t));
std::vector<char> ctrl(CMSG_LEN(sizeof(uint32_t))); std::vector<char> ctrl(CMSG_LEN(sizeof(uint32_t)));
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment