Commit 72e652d1 authored by Anirudh Ramachandran's avatar Anirudh Ramachandran Committed by Facebook Github Bot

Make AsyncSSLSocket aware of OpenSSL 1.1.0's async API

Summary:
OpenSSL 1.1.0 uses a fiber-based (makecontext/swapcontext) API to do
asynchronous operations. When some operation deep inside the stack calls
ASYNC_pause_job, SSL_accept returns -1 with error SSL_ERROR_WANT_ASYNC.
OpenSSL chose to use fds to wait on, so after SSL_accept returns, we create an
AsyncPipeReader to restart SSL_accept when the pipe becomes readable, which is our
indication that the async job processing has finished.

Also implemented a test to kick off an async job in a different thread that creates a pipe
and gives the read end back to the SSL* before calling ASYNC_pause_job

Reviewed By: yfeldblum

Differential Revision: D5977514

fbshipit-source-id: 3aba2e45b9357dc28cf7cf785654072f8ba8dd65
parent 03b404df
......@@ -975,6 +975,9 @@ bool AsyncSSLSocket::willBlock(int ret,
#endif
#ifdef SSL_ERROR_WANT_ECDSA_ASYNC_PENDING
|| error == SSL_ERROR_WANT_ECDSA_ASYNC_PENDING
#endif
#ifdef SSL_ERROR_WANT_ASYNC // OpenSSL 1.1.0 Async API
|| error == SSL_ERROR_WANT_ASYNC
#endif
)) {
// Our custom openssl function has kicked off an async request to do
......@@ -988,6 +991,35 @@ bool AsyncSSLSocket::willBlock(int ret,
EventHandler::READ | EventHandler::WRITE
);
#ifdef SSL_ERROR_WANT_ASYNC
if (error == SSL_ERROR_WANT_ASYNC) {
size_t numfds;
if (SSL_get_all_async_fds(ssl_, NULL, &numfds) <= 0) {
VLOG(4) << "SSL_ERROR_WANT_ASYNC but no async FDs set!";
return false;
}
if (numfds != 1) {
VLOG(4) << "SSL_ERROR_WANT_ASYNC expected exactly 1 async fd, got "
<< numfds;
return false;
}
OSSL_ASYNC_FD ofd; // This should just be an int in POSIX
if (SSL_get_all_async_fds(ssl_, &ofd, &numfds) <= 0) {
VLOG(4) << "SSL_ERROR_WANT_ASYNC cant get async fd";
return false;
}
auto asyncPipeReader = AsyncPipeReader::newReader(eventBase_, ofd);
auto asyncPipeReaderPtr = asyncPipeReader.get();
if (!asyncOperationFinishCallback_) {
asyncOperationFinishCallback_.reset(
new DefaultOpenSSLAsyncFinishCallback(
std::move(asyncPipeReader), this));
}
asyncPipeReaderPtr->setReadCB(asyncOperationFinishCallback_.get());
}
#endif
// The timeout (if set) keeps running here
return true;
} else {
......@@ -1086,6 +1118,7 @@ AsyncSSLSocket::handleAccept() noexcept {
int ret = SSL_accept(ssl_);
if (ret <= 0) {
VLOG(3) << "SSL_accept returned: " << ret;
int sslError;
unsigned long errError;
int errnoCopy = errno;
......
......@@ -22,6 +22,7 @@
#include <folly/String.h>
#include <folly/io/Cursor.h>
#include <folly/io/IOBuf.h>
#include <folly/io/async/AsyncPipe.h>
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/AsyncTimeout.h>
#include <folly/io/async/SSLContext.h>
......@@ -149,11 +150,52 @@ class AsyncSSLSocket : public virtual AsyncSocket {
TimeoutManager::timeout_type timeout_;
};
/**
* A class to wait for asynchronous operations with OpenSSL 1.1.0
*/
class DefaultOpenSSLAsyncFinishCallback : public ReadCallback {
public:
DefaultOpenSSLAsyncFinishCallback(
AsyncPipeReader::UniquePtr reader,
AsyncSSLSocket* sslSocket)
: pipeReader_(std::move(reader)),
sslSocket_(sslSocket) {}
void readDataAvailable(size_t len) noexcept override {
CHECK_EQ(len, 1);
if (byte_ > 0) {
sslSocket_->restartSSLAccept();
} else {
AsyncSocketException ex(
AsyncSocketException::INTERNAL_ERROR,
"Error with asynchronous crypto operation");
sslSocket_->failHandshake(__func__, ex);
}
pipeReader_->setReadCB(nullptr);
}
void getReadBuffer(void** bufReturn, size_t* lenReturn) noexcept override {
*bufReturn = &byte_;
*lenReturn = 1;
}
void readEOF() noexcept override {}
void readErr(const folly::AsyncSocketException&) noexcept override {}
private:
uint8_t byte_{0};
AsyncPipeReader::UniquePtr pipeReader_;
AsyncSSLSocket* sslSocket_{nullptr};
};
/**
* Create a client AsyncSSLSocket
*/
AsyncSSLSocket(const std::shared_ptr<folly::SSLContext> &ctx,
EventBase* evb, bool deferSecurityNegotiation = false);
AsyncSSLSocket(
const std::shared_ptr<folly::SSLContext>& ctx,
EventBase* evb,
bool deferSecurityNegotiation = false);
/**
* Create a server/client AsyncSSLSocket from an already connected
......@@ -769,6 +811,11 @@ class AsyncSSLSocket : public virtual AsyncSocket {
return totalConnectTimeout_;
}
// This can be called for OpenSSL 1.1.0 async operation finishes
void setAsyncOperationFinishCallback(std::unique_ptr<ReadCallback> cb) {
asyncOperationFinishCallback_ = std::move(cb);
}
private:
void init();
......@@ -925,6 +972,8 @@ class AsyncSSLSocket : public virtual AsyncSocket {
bool sessionResumptionAttempted_{false};
// whether the SSL session was resumed using session ID or not
bool sessionIDResumed_{false};
// This can be called for OpenSSL 1.1.0 async operation finishes
std::unique_ptr<ReadCallback> asyncOperationFinishCallback_;
};
} // namespace folly
......@@ -17,13 +17,16 @@
#include <folly/SocketAddress.h>
#include <folly/io/Cursor.h>
#include <folly/io/async/AsyncPipe.h>
#include <folly/io/async/AsyncSSLSocket.h>
#include <folly/io/async/EventBase.h>
#include <folly/io/async/ScopedEventBaseThread.h>
#include <folly/portability/GMock.h>
#include <folly/portability/GTest.h>
#include <folly/portability/OpenSSL.h>
#include <folly/portability/Sockets.h>
#include <folly/portability/Unistd.h>
#include <folly/ssl/Init.h>
#include <folly/io/async/test/BlockingSocket.h>
......@@ -37,6 +40,10 @@
#include <set>
#include <thread>
#if FOLLY_OPENSSL_IS_110
#include <openssl/async.h>
#endif
#ifdef FOLLY_HAVE_MSG_ERRQUEUE
#include <sys/utsname.h>
#endif
......@@ -1594,6 +1601,258 @@ TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
EXPECT_LE(0, server.handshakeTime.count());
}
/**
* Test OpenSSL 1.1.0's async functionality
*/
#if FOLLY_OPENSSL_IS_110
static void makeNonBlockingPipe(int pipefds[2]) {
if (pipe(pipefds) != 0) {
throw std::runtime_error("Cannot create pipe");
}
if (::fcntl(pipefds[0], F_SETFL, O_NONBLOCK) != 0) {
throw std::runtime_error("Cannot set pipe to nonblocking");
}
if (::fcntl(pipefds[1], F_SETFL, O_NONBLOCK) != 0) {
throw std::runtime_error("Cannot set pipe to nonblocking");
}
}
// Custom RSA private key encryption method
static int kRSAExIndex = -1;
static int kRSAEvbExIndex = -1;
static constexpr StringPiece kEngineId = "AsyncSSLSocketTest";
static int customRsaPrivEnc(
int flen,
const unsigned char* from,
unsigned char* to,
RSA* rsa,
int padding) {
LOG(INFO) << "rsa_priv_enc";
EventBase* asyncJobEvb =
reinterpret_cast<EventBase*>(RSA_get_ex_data(rsa, kRSAEvbExIndex));
CHECK(asyncJobEvb);
RSA* actualRSA = reinterpret_cast<RSA*>(RSA_get_ex_data(rsa, kRSAExIndex));
CHECK(actualRSA);
ASYNC_JOB* job = ASYNC_get_current_job();
if (job == nullptr) {
throw std::runtime_error("Expected call in job context");
}
ASYNC_WAIT_CTX* waitctx = ASYNC_get_wait_ctx(job);
OSSL_ASYNC_FD pipefds[2] = {0, 0};
makeNonBlockingPipe(pipefds);
if (!ASYNC_WAIT_CTX_set_wait_fd(
waitctx, kEngineId.data(), pipefds[0], nullptr, nullptr)) {
throw std::runtime_error("Cannot set wait fd");
}
int ret = 0;
int* retptr = &ret;
auto asyncPipeWriter =
folly::AsyncPipeWriter::newWriter(asyncJobEvb, pipefds[1]);
asyncJobEvb->runInEventBaseThread([retptr = retptr,
flen = flen,
from = from,
to = to,
padding = padding,
actualRSA = actualRSA,
writer = asyncPipeWriter.get()]() {
LOG(INFO) << "Running job";
*retptr = RSA_meth_get_priv_enc(RSA_PKCS1_OpenSSL())(
flen, from, to, actualRSA, padding);
LOG(INFO) << "Finished job, writing to pipe";
uint8_t byte = *retptr > 0 ? 1 : 0;
writer->write(nullptr, &byte, 1);
});
LOG(INFO) << "About to pause job";
ASYNC_pause_job();
LOG(INFO) << "Resumed job with ret: " << ret;
return ret;
}
void rsaFree(void*, void* ptr, CRYPTO_EX_DATA*, int, long, void*) {
LOG(INFO) << "RSA_free is called with ptr " << std::hex << ptr;
if (ptr == nullptr) {
LOG(INFO) << "Returning early from rsaFree because ptr is null";
return;
}
RSA* rsa = (RSA*)ptr;
auto meth = RSA_get_method(rsa);
if (meth != RSA_get_default_method()) {
auto nonconst = const_cast<RSA_METHOD*>(meth);
RSA_meth_free(nonconst);
RSA_set_method(rsa, RSA_get_default_method());
}
RSA_free(rsa);
}
struct RSAPointers {
RSA* actualrsa{nullptr};
RSA* dummyrsa{nullptr};
RSA_METHOD* meth{nullptr};
};
inline void RSAPointersFree(RSAPointers* p) {
if (p->meth && p->dummyrsa && RSA_get_method(p->dummyrsa) == p->meth) {
RSA_set_method(p->dummyrsa, RSA_get_default_method());
}
if (p->meth) {
LOG(INFO) << "Freeing meth";
RSA_meth_free(p->meth);
}
if (p->actualrsa) {
LOG(INFO) << "Freeing actualrsa";
RSA_free(p->actualrsa);
}
if (p->dummyrsa) {
LOG(INFO) << "Freeing dummyrsa";
RSA_free(p->dummyrsa);
}
delete p;
}
using RSAPointersDeleter =
folly::static_function_deleter<RSAPointers, RSAPointersFree>;
std::unique_ptr<RSAPointers, RSAPointersDeleter>
setupCustomRSA(const char* certPath, const char* keyPath, EventBase* jobEvb) {
auto certPEM = getFileAsBuf(certPath);
auto keyPEM = getFileAsBuf(keyPath);
ssl::BioUniquePtr certBio(
BIO_new_mem_buf((void*)certPEM.data(), certPEM.size()));
ssl::BioUniquePtr keyBio(
BIO_new_mem_buf((void*)keyPEM.data(), keyPEM.size()));
ssl::X509UniquePtr cert(
PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
ssl::EvpPkeyUniquePtr evpPkey(
PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
ssl::EvpPkeyUniquePtr publicEvpPkey(X509_get_pubkey(cert.get()));
std::unique_ptr<RSAPointers, RSAPointersDeleter> ret(new RSAPointers());
RSA* actualrsa = EVP_PKEY_get1_RSA(evpPkey.get());
LOG(INFO) << "actualrsa ptr " << std::hex << (void*)actualrsa;
RSA* dummyrsa = EVP_PKEY_get1_RSA(publicEvpPkey.get());
if (dummyrsa == nullptr) {
throw std::runtime_error("Couldn't get RSA cert public factors");
}
RSA_METHOD* meth = RSA_meth_dup(RSA_get_default_method());
if (meth == nullptr || RSA_meth_set1_name(meth, "Async RSA method") == 0 ||
RSA_meth_set_priv_enc(meth, customRsaPrivEnc) == 0 ||
RSA_meth_set_flags(meth, RSA_METHOD_FLAG_NO_CHECK) == 0) {
throw std::runtime_error("Cannot create async RSA_METHOD");
}
RSA_set_method(dummyrsa, meth);
RSA_set_flags(dummyrsa, RSA_FLAG_EXT_PKEY);
kRSAExIndex = RSA_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
kRSAEvbExIndex = RSA_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
CHECK_NE(kRSAExIndex, -1);
CHECK_NE(kRSAEvbExIndex, -1);
RSA_set_ex_data(dummyrsa, kRSAExIndex, actualrsa);
RSA_set_ex_data(dummyrsa, kRSAEvbExIndex, jobEvb);
ret->actualrsa = actualrsa;
ret->dummyrsa = dummyrsa;
ret->meth = meth;
return ret;
}
// TODO: disabled with ASAN doesn't play nice with ASYNC for some reason
#ifndef FOLLY_SANITIZE_ADDRESS
TEST(AsyncSSLSocketTest, OpenSSL110AsyncTest) {
ASYNC_init_thread(1, 1);
EventBase eventBase;
ScopedEventBaseThread jobEvbThread;
auto clientCtx = std::make_shared<SSLContext>();
auto serverCtx = std::make_shared<SSLContext>();
serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
serverCtx->loadCertificate(kTestCert);
serverCtx->loadTrustedCertificates(kTestCA);
serverCtx->loadClientCAList(kTestCA);
auto rsaPointers =
setupCustomRSA(kTestCert, kTestKey, jobEvbThread.getEventBase());
CHECK(rsaPointers->dummyrsa);
// up-refs dummyrsa
SSL_CTX_use_RSAPrivateKey(serverCtx->getSSLCtx(), rsaPointers->dummyrsa);
SSL_CTX_set_mode(serverCtx->getSSLCtx(), SSL_MODE_ASYNC);
clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
int fds[2];
getfds(fds);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
SSLHandshakeClient client(std::move(clientSock), false, false);
SSLHandshakeServer server(std::move(serverSock), false, false);
eventBase.loop();
EXPECT_TRUE(server.handshakeSuccess_);
EXPECT_TRUE(client.handshakeSuccess_);
ASYNC_cleanup_thread();
}
TEST(AsyncSSLSocketTest, OpenSSL110AsyncTestFailure) {
ASYNC_init_thread(1, 1);
EventBase eventBase;
ScopedEventBaseThread jobEvbThread;
auto clientCtx = std::make_shared<SSLContext>();
auto serverCtx = std::make_shared<SSLContext>();
serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
serverCtx->loadCertificate(kTestCert);
serverCtx->loadTrustedCertificates(kTestCA);
serverCtx->loadClientCAList(kTestCA);
// Set the wrong key for the cert
auto rsaPointers =
setupCustomRSA(kTestCert, kClientTestKey, jobEvbThread.getEventBase());
CHECK(rsaPointers->dummyrsa);
SSL_CTX_use_RSAPrivateKey(serverCtx->getSSLCtx(), rsaPointers->dummyrsa);
SSL_CTX_set_mode(serverCtx->getSSLCtx(), SSL_MODE_ASYNC);
clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
int fds[2];
getfds(fds);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
SSLHandshakeClient client(std::move(clientSock), false, false);
SSLHandshakeServer server(std::move(serverSock), false, false);
eventBase.loop();
EXPECT_TRUE(server.handshakeError_);
EXPECT_TRUE(client.handshakeError_);
ASYNC_cleanup_thread();
}
#endif // FOLLY_SANITIZE_ADDRESS
#endif // FOLLY_OPENSSL_IS_110
TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
auto cert = getFileAsBuf(kTestCert);
auto key = getFileAsBuf(kTestKey);
......
......@@ -23,6 +23,7 @@
#include <folly/io/async/ScopedEventBaseThread.h>
#include <folly/portability/GTest.h>
#include <folly/portability/PThread.h>
#include <folly/ssl/Init.h>
using std::string;
using std::vector;
......@@ -280,10 +281,12 @@ TEST(AsyncSSLSocketTest2, TestTLS12BadClient) {
} // namespace folly
int main(int argc, char *argv[]) {
folly::ssl::init();
#ifdef SIGPIPE
signal(SIGPIPE, SIG_IGN);
#endif
testing::InitGoogleTest(&argc, argv);
folly::init(&argc, &argv);
return RUN_ALL_TESTS();
OPENSSL_cleanup();
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment