Commit 14a19db2 authored by Ranjeeth Dasineni's avatar Ranjeeth Dasineni Committed by facebook-github-bot-9

add callback to specify a client next protocol filter

Summary: From the client perspective, we set the list in order of
preference once and call into openssl to do the selection. This adds
a little more flexibility in that client optionally can customize the
selection for each negotiation. added tests for the no-op case and the
customized case. Feel free to suggest improvements.

Reviewed By: @afrind

Differential Revision: D2489142
parent 1f46d8c5
...@@ -513,13 +513,21 @@ int SSLContext::selectNextProtocolCallback( ...@@ -513,13 +513,21 @@ int SSLContext::selectNextProtocolCallback(
} }
unsigned char *client; unsigned char *client;
int client_len; unsigned int client_len;
if (ctx->advertisedNextProtocols_.empty()) { bool filtered = false;
client = (unsigned char *) ""; auto cpf = ctx->getClientProtocolFilterCallback();
client_len = 0; if (cpf) {
} else { filtered = (*cpf)(&client, &client_len, server, server_len);
client = ctx->advertisedNextProtocols_[0].protocols; }
client_len = ctx->advertisedNextProtocols_[0].length;
if (!filtered) {
if (ctx->advertisedNextProtocols_.empty()) {
client = (unsigned char *) "";
client_len = 0;
} else {
client = ctx->advertisedNextProtocols_[0].protocols;
client_len = ctx->advertisedNextProtocols_[0].length;
}
} }
int retval = SSL_select_next_proto(out, outlen, server, server_len, int retval = SSL_select_next_proto(out, outlen, server, server_len,
......
...@@ -93,6 +93,10 @@ class SSLContext { ...@@ -93,6 +93,10 @@ class SSLContext {
double probability; double probability;
}; };
// Function that selects a client protocol given the server's list
using ClientProtocolFilterCallback = bool (*)(unsigned char**, unsigned int*,
const unsigned char*, unsigned int);
/** /**
* Convenience function to call getErrors() with the current errno value. * Convenience function to call getErrors() with the current errno value.
* *
...@@ -327,6 +331,13 @@ class SSLContext { ...@@ -327,6 +331,13 @@ class SSLContext {
bool setRandomizedAdvertisedNextProtocols( bool setRandomizedAdvertisedNextProtocols(
const std::list<NextProtocolsItem>& items); const std::list<NextProtocolsItem>& items);
void setClientProtocolFilterCallback(ClientProtocolFilterCallback cb) {
clientProtoFilter_ = cb;
}
ClientProtocolFilterCallback getClientProtocolFilterCallback() {
return clientProtoFilter_;
}
/** /**
* Disables NPN on this SSL context. * Disables NPN on this SSL context.
*/ */
...@@ -431,6 +442,8 @@ class SSLContext { ...@@ -431,6 +442,8 @@ class SSLContext {
std::vector<ClientHelloCallback> clientHelloCbs_; std::vector<ClientHelloCallback> clientHelloCbs_;
#endif #endif
ClientProtocolFilterCallback clientProtoFilter_{nullptr};
static bool initialized_; static bool initialized_;
#ifdef OPENSSL_NPN_NEGOTIATED #ifdef OPENSSL_NPN_NEGOTIATED
......
...@@ -127,6 +127,21 @@ void sslsocketpair( ...@@ -127,6 +127,21 @@ void sslsocketpair(
// (*serverSock)->setSendTimeout(100); // (*serverSock)->setSendTimeout(100);
} }
// client protocol filters
bool clientProtoFilterPickPony(unsigned char** client,
unsigned int* client_len, const unsigned char*, unsigned int ) {
//the protocol string in length prefixed byte string. the
//length byte is not included in the length
static unsigned char p[7] = {6,'p','o','n','i','e','s'};
*client = p;
*client_len = 7;
return true;
}
bool clientProtoFilterPickNone(unsigned char**, unsigned int*,
const unsigned char*, unsigned int) {
return false;
}
/** /**
* Test connecting to, writing to, reading from, and closing the * Test connecting to, writing to, reading from, and closing the
...@@ -387,6 +402,64 @@ TEST(AsyncSSLSocketTest, NpnTestNoOverlap) { ...@@ -387,6 +402,64 @@ TEST(AsyncSSLSocketTest, NpnTestNoOverlap) {
EXPECT_EQ(selected.compare("blub"), 0); EXPECT_EQ(selected.compare("blub"), 0);
} }
TEST(AsyncSSLSocketTest, NpnTestClientProtoFilterHit) {
EventBase eventBase;
auto clientCtx = std::make_shared<SSLContext>();
auto serverCtx = std::make_shared<SSLContext>();
int fds[2];
getfds(fds);
getctx(clientCtx, serverCtx);
clientCtx->setAdvertisedNextProtocols({"blub"});
clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
NpnClient client(std::move(clientSock));
NpnServer server(std::move(serverSock));
eventBase.loop();
EXPECT_TRUE(client.nextProtoLength != 0);
EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
server.nextProtoLength), 0);
string selected((const char*)client.nextProto, client.nextProtoLength);
EXPECT_EQ(selected.compare("ponies"), 0);
}
TEST(AsyncSSLSocketTest, NpnTestClientProtoFilterMiss) {
EventBase eventBase;
auto clientCtx = std::make_shared<SSLContext>();
auto serverCtx = std::make_shared<SSLContext>();
int fds[2];
getfds(fds);
getctx(clientCtx, serverCtx);
clientCtx->setAdvertisedNextProtocols({"blub"});
clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
NpnClient client(std::move(clientSock));
NpnServer server(std::move(serverSock));
eventBase.loop();
EXPECT_TRUE(client.nextProtoLength != 0);
EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
server.nextProtoLength), 0);
string selected((const char*)client.nextProto, client.nextProtoLength);
EXPECT_EQ(selected.compare("blub"), 0);
}
TEST(AsyncSSLSocketTest, RandomizedNpnTest) { TEST(AsyncSSLSocketTest, RandomizedNpnTest) {
// Probability that this test will fail is 2^-64, which could be considered // Probability that this test will fail is 2^-64, which could be considered
// as negligible. // as negligible.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment