Commit ca7ce442 authored by Alex Zhu's avatar Alex Zhu Committed by Facebook GitHub Bot

Add AsyncSSLSocket::setSupportedProtocols

Summary: This diff adds AsyncSSLSocket::setSupportedProtocols, analagous to SSL_set_alpn_protos, which allows connection specific ALPNs to be set. Prior to this diff, there was no easy way to change the set of ALPNs to use other than creating a separate SSLContext or manually using the low level OpenSSL interface.

Reviewed By: mingtaoy

Differential Revision: D29247716

fbshipit-source-id: 6f378b4fc75f404e06fe0131ab520b7e4b8f33b6
parent 2f7fdc20
...@@ -437,6 +437,11 @@ std::string AsyncSSLSocket::getApplicationProtocol() const noexcept { ...@@ -437,6 +437,11 @@ std::string AsyncSSLSocket::getApplicationProtocol() const noexcept {
return ""; return "";
} }
void AsyncSSLSocket::setSupportedApplicationProtocols(
const std::vector<std::string>& supportedProtocols) {
encodedAlpn_ = OpenSSLUtils::encodeALPNString(supportedProtocols);
}
void AsyncSSLSocket::setEorTracking(bool track) { void AsyncSSLSocket::setEorTracking(bool track) {
AsyncSocket::setEorTracking(track); AsyncSocket::setEorTracking(track);
} }
...@@ -871,6 +876,19 @@ void AsyncSSLSocket::sslConn( ...@@ -871,6 +876,19 @@ void AsyncSSLSocket::sslConn(
return failHandshake(__func__, *ex); return failHandshake(__func__, *ex);
} }
if (!encodedAlpn_.empty()) {
int result = SSL_set_alpn_protos(
ssl_.get(),
reinterpret_cast<const unsigned char*>(encodedAlpn_.c_str()),
static_cast<unsigned int>(encodedAlpn_.size()));
if (result != 0) {
static const Indestructible<AsyncSocketException> ex(
AsyncSocketException::INTERNAL_ERROR,
"error setting SSL alpn protos");
return failHandshake(__func__, *ex);
}
}
if (!setupSSLBio()) { if (!setupSSLBio()) {
sslState_ = STATE_ERROR; sslState_ = STATE_ERROR;
static const Indestructible<AsyncSocketException> ex( static const Indestructible<AsyncSocketException> ex(
......
...@@ -382,6 +382,8 @@ class AsyncSSLSocket : public AsyncSocket { ...@@ -382,6 +382,8 @@ class AsyncSSLSocket : public AsyncSocket {
bool good() const override; bool good() const override;
bool connecting() const override; bool connecting() const override;
std::string getApplicationProtocol() const noexcept override; std::string getApplicationProtocol() const noexcept override;
void setSupportedApplicationProtocols(
const std::vector<std::string>& supportedProtocols);
std::string getSecurityProtocol() const override { std::string getSecurityProtocol() const override {
if (sslState_ == STATE_UNENCRYPTED) { if (sslState_ == STATE_UNENCRYPTED) {
...@@ -1005,6 +1007,8 @@ class AsyncSSLSocket : public AsyncSocket { ...@@ -1005,6 +1007,8 @@ class AsyncSSLSocket : public AsyncSocket {
std::string sslVerificationAlert_; std::string sslVerificationAlert_;
std::string encodedAlpn_;
bool sessionResumptionAttempted_{false}; bool sessionResumptionAttempted_{false};
// whether the SSL session was resumed using session ID or not // whether the SSL session was resumed using session ID or not
bool sessionIDResumed_{false}; bool sessionIDResumed_{false};
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
#include <folly/io/async/AsyncSocketException.h>
#include <folly/io/async/ssl/OpenSSLUtils.h> #include <folly/io/async/ssl/OpenSSLUtils.h>
#include <unordered_map> #include <unordered_map>
...@@ -326,6 +327,26 @@ std::string OpenSSLUtils::getCommonName(X509* x509) { ...@@ -326,6 +327,26 @@ std::string OpenSSLUtils::getCommonName(X509* x509) {
return std::string(buf, length); return std::string(buf, length);
} }
std::string OpenSSLUtils::encodeALPNString(
const std::vector<std::string>& supportedProtocols) {
unsigned int length = 0;
for (const auto& proto : supportedProtocols) {
if (proto.size() > std::numeric_limits<uint8_t>::max()) {
throw std::range_error("ALPN protocol string exceeds maximum length");
}
length += proto.size() + 1;
}
std::string encodedALPN;
encodedALPN.reserve(length);
for (const auto& proto : supportedProtocols) {
encodedALPN.append(1, static_cast<char>(proto.size()));
encodedALPN.append(proto);
}
return encodedALPN;
}
} // namespace ssl } // namespace ssl
} // namespace folly } // namespace folly
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#pragma once #pragma once
#include <folly/Range.h> #include <folly/Range.h>
#include <folly/io/async/AsyncSocketException.h>
#include <folly/net/NetworkSocket.h> #include <folly/net/NetworkSocket.h>
#include <folly/portability/OpenSSL.h> #include <folly/portability/OpenSSL.h>
#include <folly/portability/Sockets.h> #include <folly/portability/Sockets.h>
...@@ -125,6 +126,8 @@ class OpenSSLUtils { ...@@ -125,6 +126,8 @@ class OpenSSLUtils {
static void* getBioAppData(BIO* b); static void* getBioAppData(BIO* b);
static NetworkSocket getBioFd(BIO* b); static NetworkSocket getBioFd(BIO* b);
static void setBioFd(BIO* b, NetworkSocket fd, int flags); static void setBioFd(BIO* b, NetworkSocket fd, int flags);
static std::string encodeALPNString(
const std::vector<std::string>& supported_protocols);
}; };
} // namespace ssl } // namespace ssl
......
...@@ -81,4 +81,45 @@ TEST(OpenSSLUtilsTest, getCommonNameEmpty) { ...@@ -81,4 +81,45 @@ TEST(OpenSSLUtilsTest, getCommonNameEmpty) {
EXPECT_EQ(OpenSSLUtils::getCommonName(x509.get()), ""); EXPECT_EQ(OpenSSLUtils::getCommonName(x509.get()), "");
} }
// Tests that encodeALPNString returns a serialized version of the ALPN String
TEST(OpenSSLUtilsTest, encodeALPNString) {
EXPECT_EQ(OpenSSLUtils::encodeALPNString({"rs"}), "\x2rs");
EXPECT_EQ(OpenSSLUtils::encodeALPNString({"rs", "h2"}), "\x2rs\x2h2");
EXPECT_EQ(
OpenSSLUtils::encodeALPNString({"rs", "h2", "spdy/3.1"}),
"\x2rs\x2h2\x8spdy/3.1");
EXPECT_EQ(
OpenSSLUtils::encodeALPNString({"rs", "h2", "spdy/3.1", "http/1.1"}),
"\x2rs\x2h2\x8spdy/3.1\x8http/1.1");
std::string maxSizeProtocolString(std::numeric_limits<uint8_t>::max(), 'p');
EXPECT_EQ(
OpenSSLUtils::encodeALPNString({maxSizeProtocolString}),
"\xFF" + maxSizeProtocolString);
EXPECT_EQ(
OpenSSLUtils::encodeALPNString({maxSizeProtocolString, "rs"}),
"\xFF" + maxSizeProtocolString + "\x2rs");
EXPECT_EQ(
OpenSSLUtils::encodeALPNString({maxSizeProtocolString, "rs", "h2"}),
"\xFF" + maxSizeProtocolString + "\x2rs\x2h2");
EXPECT_EQ(
OpenSSLUtils::encodeALPNString(
{maxSizeProtocolString, "rs", "h2", "spdy/3.1"}),
"\xFF" + maxSizeProtocolString + "\x2rs\x2h2\x8spdy/3.1");
EXPECT_EQ(
OpenSSLUtils::encodeALPNString(
{maxSizeProtocolString, "rs", "h2", "spdy/3.1", "http/1.1"}),
"\xFF" + maxSizeProtocolString + "\x2rs\x2h2\x8spdy/3.1\x8http/1.1");
std::string exceedsMaxSizeProtocolString(
std::numeric_limits<uint8_t>::max() + 1, 'p');
try {
OpenSSLUtils::encodeALPNString({exceedsMaxSizeProtocolString});
} catch (std::range_error const& err) {
EXPECT_EQ(
err.what(), std::string("ALPN protocol string exceeds maximum length"));
}
}
} // namespace folly } // namespace folly
...@@ -757,6 +757,49 @@ TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) { ...@@ -757,6 +757,49 @@ TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
EXPECT_TRUE(!server.serverNameMatch); EXPECT_TRUE(!server.serverNameMatch);
} }
/**
* 1. Create an SSLContext that does not have an ALPN
* 2. Use AsyncSSLSocket::setSupportedApplicationProtocols on the client and
* server, and assert that a common ALPN was negotiated.
*/
TEST(AsyncSSLSocketTest, SetSupportedApplicationProtocols) {
EventBase eventBase;
auto clientCtx = std::make_shared<SSLContext>();
auto dfServerCtx = std::make_shared<SSLContext>();
// Use the same SSLContext to continue the handshake after
// tlsext_hostname match.
auto hskServerCtx = std::make_shared<SSLContext>();
const std::string serverExpectedServerName("xyz.newdev.facebook.com");
NetworkSocket fds[2];
getfds(fds);
getctx(clientCtx, dfServerCtx);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
std::vector<std::string> protocols;
protocols.push_back("rs");
clientSock->setSupportedApplicationProtocols(protocols);
serverSock->setSupportedApplicationProtocols(protocols);
SNIClient client(std::move(clientSock));
SNIServer server(
std::move(serverSock),
dfServerCtx,
hskServerCtx,
serverExpectedServerName);
eventBase.loop();
EXPECT_TRUE(
client.getApplicationProtocol().compare(
server.getApplicationProtocol()) == 0);
}
#endif #endif
/** /**
* Test SSL client socket * Test SSL client socket
......
...@@ -949,6 +949,10 @@ class SNIClient : private AsyncSSLSocket::HandshakeCB, ...@@ -949,6 +949,10 @@ class SNIClient : private AsyncSSLSocket::HandshakeCB,
socket_->sslConn(this); socket_->sslConn(this);
} }
std::string getApplicationProtocol() {
return socket_->getApplicationProtocol();
}
bool serverNameMatch; bool serverNameMatch;
private: private:
...@@ -986,6 +990,10 @@ class SNIServer : private AsyncSSLSocket::HandshakeCB, ...@@ -986,6 +990,10 @@ class SNIServer : private AsyncSSLSocket::HandshakeCB,
socket_->sslAccept(this); socket_->sslAccept(this);
} }
std::string getApplicationProtocol() {
return socket_->getApplicationProtocol();
}
bool serverNameMatch; bool serverNameMatch;
private: private:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment