Commit d65b7736 authored by Aaron Balsara's avatar Aaron Balsara Committed by facebook-github-bot-0

Allow SSLContext to read certificates and keys from memory

Summary: Added the ability for SSLContext to load X509 Certificates and private keys from memory

Reviewed By: yfeldblum

Differential Revision: D2800746

fb-gh-sync-id: 14cad74f8d761b9b0f07e2827b155cec9ba27f50
parent 138b7236
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include <openssl/x509v3.h> #include <openssl/x509v3.h>
#include <folly/Format.h> #include <folly/Format.h>
#include <folly/Memory.h>
#include <folly/SpinLock.h> #include <folly/SpinLock.h>
// --------------------------------------------------------------------- // ---------------------------------------------------------------------
...@@ -43,6 +44,12 @@ std::mutex& initMutex() { ...@@ -43,6 +44,12 @@ std::mutex& initMutex() {
return m; return m;
} }
inline void BIO_free_fb(BIO* bio) { CHECK_EQ(1, BIO_free(bio)); }
using BIO_deleter = folly::static_function_deleter<BIO, &BIO_free_fb>;
using X509_deleter = folly::static_function_deleter<X509, &X509_free>;
using EVP_PKEY_deleter =
folly::static_function_deleter<EVP_PKEY, &EVP_PKEY_free>;
} // anonymous namespace } // anonymous namespace
#ifdef OPENSSL_NPN_NEGOTIATED #ifdef OPENSSL_NPN_NEGOTIATED
...@@ -186,6 +193,32 @@ void SSLContext::loadCertificate(const char* path, const char* format) { ...@@ -186,6 +193,32 @@ void SSLContext::loadCertificate(const char* path, const char* format) {
} }
} }
void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) {
if (cert.data() == nullptr) {
throw std::invalid_argument("loadCertificate: <cert> is nullptr");
}
std::unique_ptr<BIO, BIO_deleter> bio(BIO_new(BIO_s_mem()));
if (bio == nullptr) {
throw std::runtime_error("BIO_new: " + getErrors());
}
int written = BIO_write(bio.get(), cert.data(), cert.size());
if (written != cert.size()) {
throw std::runtime_error("BIO_write: " + getErrors());
}
std::unique_ptr<X509, X509_deleter> x509(
PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
if (x509 == nullptr) {
throw std::runtime_error("PEM_read_bio_X509: " + getErrors());
}
if (SSL_CTX_use_certificate(ctx_, x509.get()) == 0) {
throw std::runtime_error("SSL_CTX_use_certificate: " + getErrors());
}
}
void SSLContext::loadPrivateKey(const char* path, const char* format) { void SSLContext::loadPrivateKey(const char* path, const char* format) {
if (path == nullptr || format == nullptr) { if (path == nullptr || format == nullptr) {
throw std::invalid_argument( throw std::invalid_argument(
...@@ -200,10 +233,35 @@ void SSLContext::loadPrivateKey(const char* path, const char* format) { ...@@ -200,10 +233,35 @@ void SSLContext::loadPrivateKey(const char* path, const char* format) {
} }
} }
void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) {
if (pkey.data() == nullptr) {
throw std::invalid_argument("loadPrivateKey: <pkey> is nullptr");
}
std::unique_ptr<BIO, BIO_deleter> bio(BIO_new(BIO_s_mem()));
if (bio == nullptr) {
throw std::runtime_error("BIO_new: " + getErrors());
}
int written = BIO_write(bio.get(), pkey.data(), pkey.size());
if (written != pkey.size()) {
throw std::runtime_error("BIO_write: " + getErrors());
}
std::unique_ptr<EVP_PKEY, EVP_PKEY_deleter> key(
PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
if (key == nullptr) {
throw std::runtime_error("PEM_read_bio_PrivateKey: " + getErrors());
}
if (SSL_CTX_use_PrivateKey(ctx_, key.get()) == 0) {
throw std::runtime_error("SSL_CTX_use_PrivateKey: " + getErrors());
}
}
void SSLContext::loadTrustedCertificates(const char* path) { void SSLContext::loadTrustedCertificates(const char* path) {
if (path == nullptr) { if (path == nullptr) {
throw std::invalid_argument( throw std::invalid_argument("loadTrustedCertificates: <path> is nullptr");
"loadTrustedCertificates: <path> is nullptr");
} }
if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) { if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors()); throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
......
...@@ -37,6 +37,7 @@ ...@@ -37,6 +37,7 @@
#endif #endif
#include <folly/Random.h> #include <folly/Random.h>
#include <folly/Range.h>
namespace folly { namespace folly {
...@@ -185,6 +186,12 @@ class SSLContext { ...@@ -185,6 +186,12 @@ class SSLContext {
* @param format Certificate file format * @param format Certificate file format
*/ */
virtual void loadCertificate(const char* path, const char* format = "PEM"); virtual void loadCertificate(const char* path, const char* format = "PEM");
/**
* Load server certificate from memory.
*
* @param cert A PEM formatted certificate
*/
virtual void loadCertificateFromBufferPEM(folly::StringPiece cert);
/** /**
* Load private key. * Load private key.
* *
...@@ -192,6 +199,12 @@ class SSLContext { ...@@ -192,6 +199,12 @@ class SSLContext {
* @param format Private key file format * @param format Private key file format
*/ */
virtual void loadPrivateKey(const char* path, const char* format = "PEM"); virtual void loadPrivateKey(const char* path, const char* format = "PEM");
/**
* Load private key from memory.
*
* @param pkey A PEM formatted key
*/
virtual void loadPrivateKeyFromBufferPEM(folly::StringPiece pkey);
/** /**
* Load trusted certificates from specified file. * Load trusted certificates from specified file.
* *
......
...@@ -24,12 +24,14 @@ ...@@ -24,12 +24,14 @@
#include <folly/io/async/test/BlockingSocket.h> #include <folly/io/async/test/BlockingSocket.h>
#include <fstream>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <iostream> #include <iostream>
#include <list> #include <list>
#include <set> #include <set>
#include <unistd.h> #include <unistd.h>
#include <fcntl.h> #include <fcntl.h>
#include <openssl/bio.h>
#include <poll.h> #include <poll.h>
#include <sys/types.h> #include <sys/types.h>
#include <sys/socket.h> #include <sys/socket.h>
...@@ -55,8 +57,15 @@ const char* testCA = "folly/io/async/test/certs/ca-cert.pem"; ...@@ -55,8 +57,15 @@ const char* testCA = "folly/io/async/test/certs/ca-cert.pem";
constexpr size_t SSLClient::kMaxReadBufferSz; constexpr size_t SSLClient::kMaxReadBufferSz;
constexpr size_t SSLClient::kMaxReadsPerEvent; constexpr size_t SSLClient::kMaxReadsPerEvent;
TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase *acb) : inline void BIO_free_fb(BIO* bio) { CHECK_EQ(1, BIO_free(bio)); }
ctx_(new folly::SSLContext), using BIO_deleter = folly::static_function_deleter<BIO, &BIO_free_fb>;
using X509_deleter = folly::static_function_deleter<X509, &X509_free>;
using SSL_deleter = folly::static_function_deleter<SSL, &SSL_free>;
using EVP_PKEY_deleter =
folly::static_function_deleter<EVP_PKEY, &EVP_PKEY_free>;
TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb)
: ctx_(new folly::SSLContext),
acb_(acb), acb_(acb),
socket_(folly::AsyncServerSocket::newSocket(&evb_)) { socket_(folly::AsyncServerSocket::newSocket(&evb_)) {
// Set up the SSL context // Set up the SSL context
...@@ -144,6 +153,21 @@ bool clientProtoFilterPickNone(unsigned char**, unsigned int*, ...@@ -144,6 +153,21 @@ bool clientProtoFilterPickNone(unsigned char**, unsigned int*,
return false; return false;
} }
std::string getFileAsBuf(const char* fileName) {
std::string buffer;
folly::readFile(fileName, buffer);
return buffer;
}
std::string getCommonName(X509* cert) {
X509_NAME* subject = X509_get_subject_name(cert);
std::string cn;
cn.resize(ub_common_name);
X509_NAME_get_text_by_NID(
subject, NID_commonName, const_cast<char*>(cn.data()), ub_common_name);
return cn;
}
/** /**
* Test connecting to, writing to, reading from, and closing the * Test connecting to, writing to, reading from, and closing the
* connection to the SSL server. * connection to the SSL server.
...@@ -1360,6 +1384,47 @@ TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) { ...@@ -1360,6 +1384,47 @@ TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
EXPECT_LE(0, server.handshakeTime.count()); EXPECT_LE(0, server.handshakeTime.count());
} }
TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
auto cert = getFileAsBuf(testCert);
auto key = getFileAsBuf(testKey);
std::unique_ptr<BIO, BIO_deleter> certBio(BIO_new(BIO_s_mem()));
BIO_write(certBio.get(), cert.data(), cert.size());
std::unique_ptr<BIO, BIO_deleter> keyBio(BIO_new(BIO_s_mem()));
BIO_write(keyBio.get(), key.data(), key.size());
// Create SSL structs from buffers to get properties
std::unique_ptr<X509, X509_deleter> certStruct(
PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
std::unique_ptr<EVP_PKEY, EVP_PKEY_deleter> keyStruct(
PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
certBio = nullptr;
keyBio = nullptr;
auto origCommonName = getCommonName(certStruct.get());
auto origKeySize = EVP_PKEY_bits(keyStruct.get());
certStruct = nullptr;
keyStruct = nullptr;
auto ctx = std::make_shared<SSLContext>();
ctx->loadPrivateKeyFromBufferPEM(key);
ctx->loadCertificateFromBufferPEM(cert);
ctx->loadTrustedCertificates(testCA);
std::unique_ptr<SSL, SSL_deleter> ssl(ctx->createSSL());
auto newCert = SSL_get_certificate(ssl.get());
auto newKey = SSL_get_privatekey(ssl.get());
// Get properties from SSL struct
auto newCommonName = getCommonName(newCert);
auto newKeySize = EVP_PKEY_bits(newKey);
// Check that the key and cert have the expected properties
EXPECT_EQ(origCommonName, newCommonName);
EXPECT_EQ(origKeySize, newKeySize);
}
TEST(AsyncSSLSocketTest, MinWriteSizeTest) { TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
EventBase eb; EventBase eb;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment