Commit 6687bf85 authored by Alex Guzman's avatar Alex Guzman Committed by Facebook Github Bot

Pend free of SSL in AsyncSSLSocket until async callback completion.

Summary: Pends the freeing of the internal SSL until the socket is finally destroyed. This ensures that the async job can write out the result and call the socket's callback. This also always calls restartSSLAccept in order to let it handle errors and cleaning up of async jobs.

Reviewed By: knekritz

Differential Revision: D9599917

fbshipit-source-id: 8c4ce8b762fe59f08c2a40e76a0bebe59cd2929e
parent d96d9550
This diff is collapsed.
...@@ -157,20 +157,22 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -157,20 +157,22 @@ class AsyncSSLSocket : public virtual AsyncSocket {
public: public:
DefaultOpenSSLAsyncFinishCallback( DefaultOpenSSLAsyncFinishCallback(
AsyncPipeReader::UniquePtr reader, AsyncPipeReader::UniquePtr reader,
AsyncSSLSocket* sslSocket) AsyncSSLSocket* sslSocket,
: pipeReader_(std::move(reader)), sslSocket_(sslSocket) {} DestructorGuard dg)
: pipeReader_(std::move(reader)),
sslSocket_(sslSocket),
dg_(std::move(dg)) {}
~DefaultOpenSSLAsyncFinishCallback() {
pipeReader_->setReadCB(nullptr);
sslSocket_->setAsyncOperationFinishCallback(nullptr);
}
void readDataAvailable(size_t len) noexcept override { void readDataAvailable(size_t len) noexcept override {
CHECK_EQ(len, 1); CHECK_EQ(len, 1);
if (byte_ > 0) {
sslSocket_->restartSSLAccept(); sslSocket_->restartSSLAccept();
} else {
AsyncSocketException ex(
AsyncSocketException::INTERNAL_ERROR,
"Error with asynchronous crypto operation");
sslSocket_->failHandshake(__func__, ex);
}
pipeReader_->setReadCB(nullptr); pipeReader_->setReadCB(nullptr);
sslSocket_->setAsyncOperationFinishCallback(nullptr);
} }
void getReadBuffer(void** bufReturn, size_t* lenReturn) noexcept override { void getReadBuffer(void** bufReturn, size_t* lenReturn) noexcept override {
...@@ -186,6 +188,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -186,6 +188,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
uint8_t byte_{0}; uint8_t byte_{0};
AsyncPipeReader::UniquePtr pipeReader_; AsyncPipeReader::UniquePtr pipeReader_;
AsyncSSLSocket* sslSocket_{nullptr}; AsyncSSLSocket* sslSocket_{nullptr};
DestructorGuard dg_;
}; };
/** /**
...@@ -861,7 +864,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -861,7 +864,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
* applied. If verifyPeer_ was explicitly set either via sslConn/sslAccept, * applied. If verifyPeer_ was explicitly set either via sslConn/sslAccept,
* those options override the settings in the underlying SSLContext. * those options override the settings in the underlying SSLContext.
*/ */
void applyVerificationOptions(SSL* ssl); void applyVerificationOptions(const ssl::SSLUniquePtr& ssl);
/** /**
* Sets up SSL with a custom write bio which intercepts all writes. * Sets up SSL with a custom write bio which intercepts all writes.
...@@ -873,13 +876,17 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -873,13 +876,17 @@ class AsyncSSLSocket : public virtual AsyncSocket {
/** /**
* A SSL_write wrapper that understand EOR * A SSL_write wrapper that understand EOR
* *
* @param ssl: SSL* object * @param ssl: SSL pointer
* @param buf: Buffer to be written * @param buf: Buffer to be written
* @param n: Number of bytes to be written * @param n: Number of bytes to be written
* @param eor: Does the last byte (buf[n-1]) have the app-last-byte? * @param eor: Does the last byte (buf[n-1]) have the app-last-byte?
* @return: The number of app bytes successfully written to the socket * @return: The number of app bytes successfully written to the socket
*/ */
int eorAwareSSLWrite(SSL* ssl, const void* buf, int n, bool eor); int eorAwareSSLWrite(
const ssl::SSLUniquePtr& ssl,
const void* buf,
int n,
bool eor);
// Inherit error handling methods from AsyncSocket, plus the following. // Inherit error handling methods from AsyncSocket, plus the following.
void failHandshake(const char* fn, const AsyncSocketException& ex); void failHandshake(const char* fn, const AsyncSocketException& ex);
...@@ -909,7 +916,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -909,7 +916,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
std::shared_ptr<folly::SSLContext> ctx_; std::shared_ptr<folly::SSLContext> ctx_;
// Callback for SSL_accept() or SSL_connect() // Callback for SSL_accept() or SSL_connect()
HandshakeCB* handshakeCallback_{nullptr}; HandshakeCB* handshakeCallback_{nullptr};
SSL* ssl_{nullptr}; ssl::SSLUniquePtr ssl_;
SSL_SESSION* sslSession_{nullptr}; SSL_SESSION* sslSession_{nullptr};
Timeout handshakeTimeout_; Timeout handshakeTimeout_;
Timeout connectionTimeout_; Timeout connectionTimeout_;
......
...@@ -1496,6 +1496,7 @@ static void makeNonBlockingPipe(int pipefds[2]) { ...@@ -1496,6 +1496,7 @@ static void makeNonBlockingPipe(int pipefds[2]) {
// Custom RSA private key encryption method // Custom RSA private key encryption method
static int kRSAExIndex = -1; static int kRSAExIndex = -1;
static int kRSAEvbExIndex = -1; static int kRSAEvbExIndex = -1;
static int kRSASocketExIndex = -1;
static constexpr StringPiece kEngineId = "AsyncSSLSocketTest"; static constexpr StringPiece kEngineId = "AsyncSSLSocketTest";
static int customRsaPrivEnc( static int customRsaPrivEnc(
...@@ -1512,6 +1513,9 @@ static int customRsaPrivEnc( ...@@ -1512,6 +1513,9 @@ static int customRsaPrivEnc(
RSA* actualRSA = reinterpret_cast<RSA*>(RSA_get_ex_data(rsa, kRSAExIndex)); RSA* actualRSA = reinterpret_cast<RSA*>(RSA_get_ex_data(rsa, kRSAExIndex));
CHECK(actualRSA); CHECK(actualRSA);
AsyncSSLSocket* socket = reinterpret_cast<AsyncSSLSocket*>(
RSA_get_ex_data(rsa, kRSASocketExIndex));
ASYNC_JOB* job = ASYNC_get_current_job(); ASYNC_JOB* job = ASYNC_get_current_job();
if (job == nullptr) { if (job == nullptr) {
throw std::runtime_error("Expected call in job context"); throw std::runtime_error("Expected call in job context");
...@@ -1535,8 +1539,13 @@ static int customRsaPrivEnc( ...@@ -1535,8 +1539,13 @@ static int customRsaPrivEnc(
to = to, to = to,
padding = padding, padding = padding,
actualRSA = actualRSA, actualRSA = actualRSA,
writer = asyncPipeWriter.get()]() { writer = std::move(asyncPipeWriter),
socket = socket]() {
LOG(INFO) << "Running job"; LOG(INFO) << "Running job";
if (socket) {
LOG(INFO) << "Got a socket passed in, closing it...";
socket->closeNow();
}
*retptr = RSA_meth_get_priv_enc(RSA_PKCS1_OpenSSL())( *retptr = RSA_meth_get_priv_enc(RSA_PKCS1_OpenSSL())(
flen, from, to, actualRSA, padding); flen, from, to, actualRSA, padding);
LOG(INFO) << "Finished job, writing to pipe"; LOG(INFO) << "Finished job, writing to pipe";
...@@ -1634,8 +1643,11 @@ setupCustomRSA(const char* certPath, const char* keyPath, EventBase* jobEvb) { ...@@ -1634,8 +1643,11 @@ setupCustomRSA(const char* certPath, const char* keyPath, EventBase* jobEvb) {
kRSAExIndex = RSA_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); kRSAExIndex = RSA_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
kRSAEvbExIndex = RSA_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); kRSAEvbExIndex = RSA_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
kRSASocketExIndex =
RSA_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
CHECK_NE(kRSAExIndex, -1); CHECK_NE(kRSAExIndex, -1);
CHECK_NE(kRSAEvbExIndex, -1); CHECK_NE(kRSAEvbExIndex, -1);
CHECK_NE(kRSASocketExIndex, -1);
RSA_set_ex_data(dummyrsa, kRSAExIndex, actualrsa); RSA_set_ex_data(dummyrsa, kRSAExIndex, actualrsa);
RSA_set_ex_data(dummyrsa, kRSAEvbExIndex, jobEvb); RSA_set_ex_data(dummyrsa, kRSAEvbExIndex, jobEvb);
...@@ -1724,6 +1736,47 @@ TEST(AsyncSSLSocketTest, OpenSSL110AsyncTestFailure) { ...@@ -1724,6 +1736,47 @@ TEST(AsyncSSLSocketTest, OpenSSL110AsyncTestFailure) {
EXPECT_TRUE(client.handshakeError_); EXPECT_TRUE(client.handshakeError_);
ASYNC_cleanup_thread(); ASYNC_cleanup_thread();
} }
TEST(AsyncSSLSocketTest, OpenSSL110AsyncTestClosedWithCallbackPending) {
ASYNC_init_thread(1, 1);
EventBase eventBase;
ScopedEventBaseThread jobEvbThread;
auto clientCtx = std::make_shared<SSLContext>();
auto serverCtx = std::make_shared<SSLContext>();
serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
serverCtx->loadCertificate(kTestCert);
serverCtx->loadTrustedCertificates(kTestCA);
serverCtx->loadClientCAList(kTestCA);
auto rsaPointers =
setupCustomRSA(kTestCert, kTestKey, jobEvbThread.getEventBase());
CHECK(rsaPointers->dummyrsa);
// up-refs dummyrsa
SSL_CTX_use_RSAPrivateKey(serverCtx->getSSLCtx(), rsaPointers->dummyrsa);
SSL_CTX_set_mode(serverCtx->getSSLCtx(), SSL_MODE_ASYNC);
clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
int fds[2];
getfds(fds);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
RSA_set_ex_data(rsaPointers->dummyrsa, kRSASocketExIndex, serverSock.get());
SSLHandshakeClient client(std::move(clientSock), false, false);
SSLHandshakeServer server(std::move(serverSock), false, false);
eventBase.loop();
EXPECT_TRUE(server.handshakeError_);
EXPECT_TRUE(client.handshakeError_);
ASYNC_cleanup_thread();
}
#endif // FOLLY_SANITIZE_ADDRESS #endif // FOLLY_SANITIZE_ADDRESS
#endif // FOLLY_OPENSSL_IS_110 #endif // FOLLY_OPENSSL_IS_110
......
...@@ -35,8 +35,8 @@ class MockAsyncSSLSocket : public AsyncSSLSocket { ...@@ -35,8 +35,8 @@ class MockAsyncSSLSocket : public AsyncSSLSocket {
EventBase* evb) { EventBase* evb) {
auto sock = std::shared_ptr<MockAsyncSSLSocket>( auto sock = std::shared_ptr<MockAsyncSSLSocket>(
new MockAsyncSSLSocket(ctx, evb), Destructor()); new MockAsyncSSLSocket(ctx, evb), Destructor());
sock->ssl_ = SSL_new(ctx->getSSLCtx()); sock->ssl_.reset(SSL_new(ctx->getSSLCtx()));
SSL_set_fd(sock->ssl_, -1); SSL_set_fd(sock->ssl_.get(), -1);
return sock; return sock;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment