Commit d96d9550 authored by Alex Guzman's avatar Alex Guzman Committed by Facebook Github Bot

Let AsyncSSLSocket run accept via a runner.

Summary: Allows for a runner to execute the accept function and return the result via a callback. If no runner is specified, it simply does the accept as usual.

Reviewed By: yfeldblum

Differential Revision: D9849138

fbshipit-source-id: ef43ccc8668bcf1fe7f75b0b6fdcdba7adc891da
parent fb28875b
......@@ -1144,7 +1144,20 @@ void AsyncSSLSocket::handleAccept() noexcept {
SSL_set_msg_callback_arg(ssl_, this);
}
int ret = SSL_accept(ssl_);
DCHECK(ctx_->sslAcceptRunner());
updateEventRegistration(
EventHandler::NONE, EventHandler::READ | EventHandler::WRITE);
DelayedDestruction::DestructorGuard dg(this);
ctx_->sslAcceptRunner()->run(
[this, dg]() { return SSL_accept(ssl_); },
[this, dg](int ret) { handleReturnFromSSLAccept(ret); });
}
void AsyncSSLSocket::handleReturnFromSSLAccept(int ret) {
if (sslState_ != STATE_ACCEPTING) {
return;
}
if (ret <= 0) {
VLOG(3) << "SSL_accept returned: " << ret;
int sslError;
......
......@@ -799,6 +799,11 @@ class AsyncSSLSocket : public virtual AsyncSocket {
}
private:
/**
* Handle the return from invoking SSL_accept
*/
void handleReturnFromSSLAccept(int ret);
void init();
protected:
......
......@@ -66,6 +66,8 @@ SSLContext::SSLContext(SSLVersion version) {
SSL_CTX_set_options(ctx_, SSL_OP_NO_COMPRESSION);
sslAcceptRunner_ = std::make_unique<SSLAcceptRunner>();
#if FOLLY_OPENSSL_HAS_SNI
SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
SSL_CTX_set_tlsext_servername_arg(ctx_, this);
......
......@@ -30,6 +30,7 @@
#include <folly/folly-config.h>
#endif
#include <folly/Function.h>
#include <folly/Portability.h>
#include <folly/Range.h>
#include <folly/String.h>
......@@ -64,6 +65,24 @@ class PasswordCollector {
virtual std::string describe() const = 0;
};
/**
* Run SSL_accept via a runner
*/
class SSLAcceptRunner {
public:
virtual ~SSLAcceptRunner() = default;
/**
* This is expected to run the first function and provide its return
* value to the second function. This can be used to run the SSL_accept
* in different contexts.
*/
virtual void run(Function<int()> acceptFunc, Function<void(int)> finallyFunc)
const {
finallyFunc(acceptFunc());
}
};
/**
* Wrap OpenSSL SSL_CTX into a class.
*/
......@@ -509,6 +528,22 @@ class SSLContext {
void enableFalseStart();
#endif
/**
* Sets the runner used for SSL_accept. If none is given, the accept will be
* done directly.
*/
void sslAcceptRunner(std::unique_ptr<SSLAcceptRunner> runner) {
if (nullptr == runner) {
LOG(ERROR) << "Ignore invalid runner";
return;
}
sslAcceptRunner_ = std::move(runner);
}
const SSLAcceptRunner* sslAcceptRunner() {
return sslAcceptRunner_.get();
}
/**
* Helper to match a hostname versus a pattern.
*/
......@@ -534,6 +569,8 @@ class SSLContext {
static bool initialized_;
std::unique_ptr<SSLAcceptRunner> sslAcceptRunner_;
#if FOLLY_OPENSSL_HAS_ALPN
struct AdvertisedNextProtocolsItem {
......
......@@ -1897,6 +1897,146 @@ TEST(AsyncSSLSocketTest, ConnectUnencryptedTest) {
socket->close();
}
/**
* Test acceptrunner in various situations
*/
TEST(AsyncSSLSocketTest, SSLAcceptRunnerBasic) {
EventBase eventBase;
auto clientCtx = std::make_shared<SSLContext>();
auto serverCtx = std::make_shared<SSLContext>();
serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
serverCtx->loadPrivateKey(kTestKey);
serverCtx->loadCertificate(kTestCert);
clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
clientCtx->loadTrustedCertificates(kTestCA);
int fds[2];
getfds(fds);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
serverCtx->sslAcceptRunner(std::make_unique<SSLAcceptEvbRunner>(&eventBase));
SSLHandshakeClient client(std::move(clientSock), true, true);
SSLHandshakeServer server(std::move(serverSock), true, true);
eventBase.loop();
EXPECT_TRUE(client.handshakeSuccess_);
EXPECT_FALSE(client.handshakeError_);
EXPECT_LE(0, client.handshakeTime.count());
EXPECT_TRUE(server.handshakeSuccess_);
EXPECT_FALSE(server.handshakeError_);
EXPECT_LE(0, server.handshakeTime.count());
}
TEST(AsyncSSLSocketTest, SSLAcceptRunnerAcceptError) {
EventBase eventBase;
auto clientCtx = std::make_shared<SSLContext>();
auto serverCtx = std::make_shared<SSLContext>();
serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
serverCtx->loadPrivateKey(kTestKey);
serverCtx->loadCertificate(kTestCert);
clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
clientCtx->loadTrustedCertificates(kTestCA);
int fds[2];
getfds(fds);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
serverCtx->sslAcceptRunner(
std::make_unique<SSLAcceptErrorRunner>(&eventBase));
SSLHandshakeClient client(std::move(clientSock), true, true);
SSLHandshakeServer server(std::move(serverSock), true, true);
eventBase.loop();
EXPECT_FALSE(client.handshakeSuccess_);
EXPECT_TRUE(client.handshakeError_);
EXPECT_FALSE(server.handshakeSuccess_);
EXPECT_TRUE(server.handshakeError_);
}
TEST(AsyncSSLSocketTest, SSLAcceptRunnerAcceptClose) {
EventBase eventBase;
auto clientCtx = std::make_shared<SSLContext>();
auto serverCtx = std::make_shared<SSLContext>();
serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
serverCtx->loadPrivateKey(kTestKey);
serverCtx->loadCertificate(kTestCert);
clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
clientCtx->loadTrustedCertificates(kTestCA);
int fds[2];
getfds(fds);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
serverCtx->sslAcceptRunner(
std::make_unique<SSLAcceptCloseRunner>(&eventBase, serverSock.get()));
SSLHandshakeClient client(std::move(clientSock), true, true);
SSLHandshakeServer server(std::move(serverSock), true, true);
eventBase.loop();
EXPECT_FALSE(client.handshakeSuccess_);
EXPECT_TRUE(client.handshakeError_);
EXPECT_FALSE(server.handshakeSuccess_);
EXPECT_TRUE(server.handshakeError_);
}
TEST(AsyncSSLSocketTest, SSLAcceptRunnerAcceptDestroy) {
EventBase eventBase;
auto clientCtx = std::make_shared<SSLContext>();
auto serverCtx = std::make_shared<SSLContext>();
serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
serverCtx->loadPrivateKey(kTestKey);
serverCtx->loadCertificate(kTestCert);
clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
clientCtx->loadTrustedCertificates(kTestCA);
int fds[2];
getfds(fds);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
SSLHandshakeClient client(std::move(clientSock), true, true);
SSLHandshakeServer server(std::move(serverSock), true, true);
serverCtx->sslAcceptRunner(
std::make_unique<SSLAcceptDestroyRunner>(&eventBase, &server));
eventBase.loop();
EXPECT_FALSE(client.handshakeSuccess_);
EXPECT_TRUE(client.handshakeError_);
EXPECT_FALSE(server.handshakeSuccess_);
EXPECT_TRUE(server.handshakeError_);
}
TEST(AsyncSSLSocketTest, ConnResetErrorString) {
// Start listening on a local port
WriteCallbackBase writeCallback;
......
......@@ -1309,7 +1309,9 @@ class SSLHandshakeBase : public AsyncSSLSocket::HandshakeCB,
void handshakeSuc(AsyncSSLSocket*) noexcept override {
LOG(INFO) << "Handshake success";
handshakeSuccess_ = true;
handshakeTime = socket_->getHandshakeTime();
if (socket_) {
handshakeTime = socket_->getHandshakeTime();
}
}
void handshakeErr(
......@@ -1317,12 +1319,16 @@ class SSLHandshakeBase : public AsyncSSLSocket::HandshakeCB,
const AsyncSocketException& ex) noexcept override {
LOG(INFO) << "Handshake error " << ex.what();
handshakeError_ = true;
handshakeTime = socket_->getHandshakeTime();
if (socket_) {
handshakeTime = socket_->getHandshakeTime();
}
}
// WriteCallback
void writeSuccess() noexcept override {
socket_->close();
if (socket_) {
socket_->close();
}
}
void writeErr(
......@@ -1451,4 +1457,75 @@ class EventBaseAborter : public AsyncTimeout {
EventBase* eventBase_;
};
class SSLAcceptEvbRunner : public SSLAcceptRunner {
public:
explicit SSLAcceptEvbRunner(EventBase* evb) : evb_(evb) {}
~SSLAcceptEvbRunner() override = default;
void run(Function<int()> acceptFunc, Function<void(int)> finallyFunc)
const override {
evb_->runInLoop([acceptFunc = std::move(acceptFunc),
finallyFunc = std::move(finallyFunc)]() mutable {
finallyFunc(acceptFunc());
});
}
protected:
EventBase* evb_;
};
class SSLAcceptErrorRunner : public SSLAcceptEvbRunner {
public:
explicit SSLAcceptErrorRunner(EventBase* evb) : SSLAcceptEvbRunner(evb) {}
~SSLAcceptErrorRunner() override = default;
void run(Function<int()> /*acceptFunc*/, Function<void(int)> finallyFunc)
const override {
evb_->runInLoop(
[finallyFunc = std::move(finallyFunc)]() mutable { finallyFunc(-1); });
}
};
class SSLAcceptCloseRunner : public SSLAcceptEvbRunner {
public:
explicit SSLAcceptCloseRunner(EventBase* evb, folly::AsyncSSLSocket* sock)
: SSLAcceptEvbRunner(evb), socket_(sock) {}
~SSLAcceptCloseRunner() override = default;
void run(Function<int()> acceptFunc, Function<void(int)> finallyFunc)
const override {
evb_->runInLoop([acceptFunc = std::move(acceptFunc),
finallyFunc = std::move(finallyFunc),
sock = socket_]() mutable {
auto ret = acceptFunc();
sock->closeNow();
finallyFunc(ret);
});
}
private:
folly::AsyncSSLSocket* socket_;
};
class SSLAcceptDestroyRunner : public SSLAcceptEvbRunner {
public:
explicit SSLAcceptDestroyRunner(EventBase* evb, SSLHandshakeBase* base)
: SSLAcceptEvbRunner(evb), sslBase_(base) {}
~SSLAcceptDestroyRunner() override = default;
void run(Function<int()> acceptFunc, Function<void(int)> finallyFunc)
const override {
evb_->runInLoop([acceptFunc = std::move(acceptFunc),
finallyFunc = std::move(finallyFunc),
sslBase = sslBase_]() mutable {
auto ret = acceptFunc();
std::move(*sslBase).moveSocket();
finallyFunc(ret);
});
}
private:
SSLHandshakeBase* sslBase_;
};
} // namespace folly
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment