Commit 7e6064c6 authored by Alan Frindell's avatar Alan Frindell Committed by facebook-github-bot-4

Add support for ALPN

Summary: With openssl-1.0.2 and later add support for ALPN.  Clients can request NPN only, but the default is to support either (client will send ALPN list, server will send NPN advertisement if ALPN is not negotiated).

Reviewed By: siyengar

Differential Revision: D2710441

fb-gh-sync-id: a8efe69e1869bbecb4ed9e0a513448fcfdb21ca6
parent 7137cffd
......@@ -55,6 +55,7 @@ using folly::AsyncSocket;
using folly::AsyncSocketException;
using folly::AsyncSSLSocket;
using folly::Optional;
using folly::SSLContext;
// We have one single dummy SSL context so that we can implement attach
// and detach methods in a thread safe fashion without modifying opnessl.
......@@ -765,9 +766,11 @@ void AsyncSSLSocket::setSSLSession(SSL_SESSION *session, bool takeOwnership) {
}
}
void AsyncSSLSocket::getSelectedNextProtocol(const unsigned char** protoName,
unsigned* protoLen) const {
if (!getSelectedNextProtocolNoThrow(protoName, protoLen)) {
void AsyncSSLSocket::getSelectedNextProtocol(
const unsigned char** protoName,
unsigned* protoLen,
SSLContext::NextProtocolType* protoType) const {
if (!getSelectedNextProtocolNoThrow(protoName, protoLen, protoType)) {
throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED,
"NPN not supported");
}
......@@ -775,11 +778,24 @@ void AsyncSSLSocket::getSelectedNextProtocol(const unsigned char** protoName,
bool AsyncSSLSocket::getSelectedNextProtocolNoThrow(
const unsigned char** protoName,
unsigned* protoLen) const {
unsigned* protoLen,
SSLContext::NextProtocolType* protoType) const {
*protoName = nullptr;
*protoLen = 0;
#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
SSL_get0_alpn_selected(ssl_, protoName, protoLen);
if (*protoLen > 0) {
if (protoType) {
*protoType = SSLContext::NextProtocolType::ALPN;
}
return true;
}
#endif
#ifdef OPENSSL_NPN_NEGOTIATED
SSL_get0_next_proto_negotiated(ssl_, protoName, protoLen);
if (protoType) {
*protoType = SSLContext::NextProtocolType::NPN;
}
return true;
#else
return false;
......
......@@ -376,7 +376,8 @@ class AsyncSSLSocket : public virtual AsyncSocket {
/**
* Get the name of the protocol selected by the client during
* Next Protocol Negotiation (NPN)
* Next Protocol Negotiation (NPN) or Application Layer Protocol Negotiation
* (ALPN)
*
* Throw an exception if openssl does not support NPN
*
......@@ -386,13 +387,17 @@ class AsyncSSLSocket : public virtual AsyncSocket {
* Note: the AsyncSSLSocket retains ownership
* of this string.
* @param protoNameLen Length of the name.
* @param protoType Whether this was an NPN or ALPN negotiation
*/
virtual void getSelectedNextProtocol(const unsigned char** protoName,
unsigned* protoLen) const;
virtual void getSelectedNextProtocol(
const unsigned char** protoName,
unsigned* protoLen,
SSLContext::NextProtocolType* protoType = nullptr) const;
/**
* Get the name of the protocol selected by the client during
* Next Protocol Negotiation (NPN)
* Next Protocol Negotiation (NPN) or Application Layer Protocol Negotiation
* (ALPN)
*
* @param protoName Name of the protocol (not guaranteed to be
* null terminated); will be set to nullptr if
......@@ -400,10 +405,13 @@ class AsyncSSLSocket : public virtual AsyncSocket {
* Note: the AsyncSSLSocket retains ownership
* of this string.
* @param protoNameLen Length of the name.
* @param protoType Whether this was an NPN or ALPN negotiation
* @return false if openssl does not support NPN
*/
virtual bool getSelectedNextProtocolNoThrow(const unsigned char** protoName,
unsigned* protoLen) const;
virtual bool getSelectedNextProtocolNoThrow(
const unsigned char** protoName,
unsigned* protoLen,
SSLContext::NextProtocolType* protoType = nullptr) const;
/**
* Determine if the session specified during setSSLSession was reused
......
......@@ -305,13 +305,43 @@ void SSLContext::switchCiphersIfTLS11(
}
#endif
#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
int SSLContext::alpnSelectCallback(SSL* ssl,
const unsigned char** out,
unsigned char* outlen,
const unsigned char* in,
unsigned int inlen,
void* data) {
SSLContext* context = (SSLContext*)data;
CHECK(context);
if (context->advertisedNextProtocols_.empty()) {
*out = nullptr;
*outlen = 0;
} else {
auto i = context->pickNextProtocols();
const auto& item = context->advertisedNextProtocols_[i];
if (SSL_select_next_proto((unsigned char**)out,
outlen,
item.protocols,
item.length,
in,
inlen) != OPENSSL_NPN_NEGOTIATED) {
return SSL_TLSEXT_ERR_NOACK;
}
}
return SSL_TLSEXT_ERR_OK;
}
#endif
#ifdef OPENSSL_NPN_NEGOTIATED
bool SSLContext::setAdvertisedNextProtocols(const std::list<std::string>& protocols) {
return setRandomizedAdvertisedNextProtocols({{1, protocols}});
bool SSLContext::setAdvertisedNextProtocols(
const std::list<std::string>& protocols, NextProtocolType protocolType) {
return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType);
}
bool SSLContext::setRandomizedAdvertisedNextProtocols(
const std::list<NextProtocolsItem>& items) {
const std::list<NextProtocolsItem>& items, NextProtocolType protocolType) {
unsetNextProtocols();
if (items.size() == 0) {
return false;
......@@ -354,10 +384,20 @@ bool SSLContext::setRandomizedAdvertisedNextProtocols(
for (auto &advertised_item : advertisedNextProtocols_) {
advertised_item.probability /= total_weight;
}
if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
SSL_CTX_set_next_protos_advertised_cb(
ctx_, advertisedNextProtocolCallback, this);
SSL_CTX_set_next_proto_select_cb(
ctx_, selectNextProtocolCallback, this);
SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this);
}
#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) {
SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this);
// Client cannot really use randomized alpn
SSL_CTX_set_alpn_protos(ctx_,
advertisedNextProtocols_[0].protocols,
advertisedNextProtocols_[0].length);
}
#endif
return true;
}
......@@ -372,6 +412,25 @@ void SSLContext::unsetNextProtocols() {
deleteNextProtocolsStrings();
SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr);
SSL_CTX_set_alpn_protos(ctx_, nullptr, 0);
#endif
}
size_t SSLContext::pickNextProtocols() {
unsigned char random_byte;
RAND_bytes(&random_byte, 1);
double random_value = random_byte / 255.0;
double sum = 0;
for (size_t i = 0; i < advertisedNextProtocols_.size(); ++i) {
sum += advertisedNextProtocols_[i].probability;
if (sum < random_value && i + 1 < advertisedNextProtocols_.size()) {
continue;
}
return i;
}
CHECK(false) << "Failed to pickNextProtocols";
}
int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
......@@ -391,22 +450,11 @@ int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
*out = context->advertisedNextProtocols_[selected_index].protocols;
*outlen = context->advertisedNextProtocols_[selected_index].length;
} else {
unsigned char random_byte;
RAND_bytes(&random_byte, 1);
double random_value = random_byte / 255.0;
double sum = 0;
for (size_t i = 0; i < context->advertisedNextProtocols_.size(); ++i) {
sum += context->advertisedNextProtocols_[i].probability;
if (sum < random_value &&
i + 1 < context->advertisedNextProtocols_.size()) {
continue;
}
auto i = context->pickNextProtocols();
uintptr_t selected = i + 1;
SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void *)selected);
SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected);
*out = context->advertisedNextProtocols_[i].protocols;
*outlen = context->advertisedNextProtocols_[i].length;
break;
}
}
}
return SSL_TLSEXT_ERR_OK;
......
......@@ -290,33 +290,42 @@ class SSLContext {
*/
void setOptions(long options);
enum class NextProtocolType : uint8_t {
NPN = 0x1,
ALPN = 0x2,
ANY = NPN | ALPN
};
#ifdef OPENSSL_NPN_NEGOTIATED
/**
* Set the list of protocols that this SSL context supports. In server
* mode, this is the list of protocols that will be advertised for Next
* Protocol Negotiation (NPN). In client mode, the first protocol
* advertised by the server that is also on this list is
* chosen. Invoking this function with a list of length zero causes NPN
* to be disabled.
* Protocol Negotiation (NPN) or Application Layer Protocol Negotiation
* (ALPN). In client mode, the first protocol advertised by the server
* that is also on this list is chosen. Invoking this function with a list
* of length zero causes NPN to be disabled.
*
* @param protocols List of protocol names. This method makes a copy,
* so the caller needn't keep the list in scope after
* the call completes. The list must have at least
* one element to enable NPN. Each element must have
* a string length < 256.
* @return true if NPN has been activated. False if NPN is disabled.
* @param protocolType What type of protocol negotiation to support.
* @return true if NPN/ALPN has been activated. False if NPN/ALPN is disabled.
*/
bool setAdvertisedNextProtocols(const std::list<std::string>& protocols);
bool setAdvertisedNextProtocols(
const std::list<std::string>& protocols,
NextProtocolType protocolType = NextProtocolType::ANY);
/**
* Set weighted list of lists of protocols that this SSL context supports.
* In server mode, each element of the list contains a list of protocols that
* could be advertised for Next Protocol Negotiation (NPN). The list of
* protocols that will be advertised to a client is selected randomly, based
* on weights of elements. Client mode doesn't support randomized NPN, so
* this list should contain only 1 element. The first protocol advertised
* by the server that is also on the list of protocols of this element is
* chosen. Invoking this function with a list of length zero causes NPN
* to be disabled.
* could be advertised for Next Protocol Negotiation (NPN) or Application
* Layer Protocol Negotiation (ALPN). The list of protocols that will be
* advertised to a client is selected randomly, based on weights of elements.
* Client mode doesn't support randomized NPN/ALPN, so this list should
* contain only 1 element. The first protocol advertised by the server that
* is also on the list of protocols of this element is chosen. Invoking this
* function with a list of length zero causes NPN/ALPN to be disabled.
*
* @param items List of NextProtocolsItems, Each item contains a list of
* protocol names and weight. After the call of this fucntion
......@@ -326,10 +335,12 @@ class SSLContext {
* completes. The list must have at least one element with
* non-zero weight and non-empty protocols list to enable NPN.
* Each name of the protocol must have a string length < 256.
* @return true if NPN has been activated. False if NPN is disabled.
* @param protocolType What type of protocol negotiation to support.
* @return true if NPN/ALPN has been activated. False if NPN/ALPN is disabled.
*/
bool setRandomizedAdvertisedNextProtocols(
const std::list<NextProtocolsItem>& items);
const std::list<NextProtocolsItem>& items,
NextProtocolType protocolType = NextProtocolType::ANY);
void setClientProtocolFilterCallback(ClientProtocolFilterCallback cb) {
clientProtoFilter_ = cb;
......@@ -459,6 +470,16 @@ class SSLContext {
SSL* ssl, unsigned char **out, unsigned char *outlen,
const unsigned char *server, unsigned int server_len, void *args);
#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
static int alpnSelectCallback(SSL* ssl,
const unsigned char** out,
unsigned char* outlen,
const unsigned char* in,
unsigned int inlen,
void* data);
#endif
size_t pickNextProtocols();
#if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \
FOLLY_SSLCONTEXT_USE_TLS_FALSE_START
// This class contains all allowed ciphers for SSL false start. Call its
......
This diff is collapsed.
......@@ -804,10 +804,12 @@ class NpnClient :
const unsigned char* nextProto;
unsigned nextProtoLength;
SSLContext::NextProtocolType protocolType;
private:
void handshakeSuc(AsyncSSLSocket*) noexcept override {
socket_->getSelectedNextProtocol(&nextProto,
&nextProtoLength);
socket_->getSelectedNextProtocol(
&nextProto, &nextProtoLength, &protocolType);
}
void handshakeErr(
AsyncSSLSocket*,
......@@ -838,10 +840,12 @@ class NpnServer :
const unsigned char* nextProto;
unsigned nextProtoLength;
SSLContext::NextProtocolType protocolType;
private:
void handshakeSuc(AsyncSSLSocket*) noexcept override {
socket_->getSelectedNextProtocol(&nextProto,
&nextProtoLength);
socket_->getSelectedNextProtocol(
&nextProto, &nextProtoLength, &protocolType);
}
void handshakeErr(
AsyncSSLSocket*,
......
......@@ -42,12 +42,14 @@ class MockAsyncSSLSocket : public AsyncSSLSocket {
MOCK_CONST_METHOD0(good, bool());
MOCK_CONST_METHOD0(readable, bool());
MOCK_CONST_METHOD0(hangup, bool());
MOCK_CONST_METHOD2(
getSelectedNextProtocol,
void(const unsigned char**, unsigned*));
MOCK_CONST_METHOD2(
getSelectedNextProtocolNoThrow,
bool(const unsigned char**, unsigned*));
MOCK_CONST_METHOD3(getSelectedNextProtocol,
void(const unsigned char**,
unsigned*,
SSLContext::NextProtocolType*));
MOCK_CONST_METHOD3(getSelectedNextProtocolNoThrow,
bool(const unsigned char**,
unsigned*,
SSLContext::NextProtocolType*));
MOCK_METHOD1(setPeek, void(bool));
MOCK_METHOD1(setReadCB, void(ReadCallback*));
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment