Commit 889fe563 authored by Subodh Iyengar's avatar Subodh Iyengar Committed by Facebook Github Bot 7

Add TFO support to AsyncSSLSocket

Summary:
This adds TFO support to AsyncSSLSocket which
uses the support for TFO from AsyncSocket.

Because of the way AsyncSSLSocket inherits from
AsyncSocket it is tricky.

The following changes were made:
1. Openssl internally will treat only errors with return
code -1 as READ_REQUIRED or WRITE_REQUIRED errors. So this
diff changes the return value of the errors in the TFO fallback
cases to -1.

2. In case we fallback after SSL_connect() to a normal connect,
we would have to restart the connection process after connect
succeeds. To do this this overrides the connection success callback
and restarts the connection before sending the callback to AsyncSocket
because sometimes callbacks might synchronously call sslConn() in the
normal connect cases.

3. Delegated bioWrite to call sendSocketMessage instead of sendmsg directly.

Reviewed By: djwatson

Differential Revision: D3391735

fbshipit-source-id: 61434f6de4a9c3d03973c9ab9e51eb49e751e5cf
parent 0c620c8f
......@@ -1084,8 +1084,9 @@ AsyncSSLSocket::handleConnect() noexcept {
return AsyncSocket::handleConnect();
}
assert(state_ == StateEnum::ESTABLISHED &&
sslState_ == STATE_CONNECTING);
assert(
(state_ == StateEnum::FAST_OPEN || state_ == StateEnum::ESTABLISHED) &&
sslState_ == STATE_CONNECTING);
assert(ssl_);
int ret = SSL_connect(ssl_);
......@@ -1138,6 +1139,16 @@ AsyncSSLSocket::handleConnect() noexcept {
AsyncSocket::handleInitialReadWrite();
}
void AsyncSSLSocket::invokeConnectSuccess() {
if (sslState_ == SSLStateEnum::STATE_CONNECTING) {
// If we failed TFO, we'd fall back to trying to connect the socket,
// when we succeed we should handle the writes that caused us to start
// TFO.
handleWrite();
}
AsyncSocket::invokeConnectSuccess();
}
void AsyncSSLSocket::setReadCB(ReadCallback *callback) {
#ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
// turn on the buffer movable in openssl
......@@ -1498,7 +1509,6 @@ void AsyncSSLSocket::sslInfoCallback(const SSL* ssl, int where, int ret) {
}
int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
int ret;
struct msghdr msg;
struct iovec iov;
int flags = 0;
......@@ -1521,17 +1531,20 @@ int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
flags = MSG_EOR;
}
ret = sendmsg(BIO_get_fd(b, nullptr), &msg, flags);
auto result =
tsslSock->sendSocketMessage(BIO_get_fd(b, nullptr), &msg, flags);
BIO_clear_retry_flags(b);
if (ret <= 0) {
if (BIO_sock_should_retry(ret))
if (!result.exception && result.writeReturn <= 0) {
if (BIO_sock_should_retry(result.writeReturn)) {
BIO_set_retry_write(b);
}
}
return ret;
return result.writeReturn;
}
int AsyncSSLSocket::sslVerifyCallback(int preverifyOk,
X509_STORE_CTX* x509Ctx) {
int AsyncSSLSocket::sslVerifyCallback(
int preverifyOk,
X509_STORE_CTX* x509Ctx) {
SSL* ssl = (SSL*) X509_STORE_CTX_get_ex_data(
x509Ctx, SSL_get_ex_data_X509_STORE_CTX_idx());
AsyncSSLSocket* self = AsyncSSLSocket::getFromSSL(ssl);
......
......@@ -798,6 +798,8 @@ class AsyncSSLSocket : public virtual AsyncSocket {
void invokeHandshakeErr(const AsyncSocketException& ex);
void invokeHandshakeCB();
void invokeConnectSuccess() override;
void cacheLocalPeerAddr();
static void sslInfoCallback(const SSL *ssl, int type, int val);
......
......@@ -1752,9 +1752,8 @@ ssize_t AsyncSocket::tfoSendMsg(int fd, struct msghdr* msg, int msg_flags) {
return detail::tfo_sendmsg(fd, msg, msg_flags);
}
AsyncSocket::WriteResult AsyncSocket::sendSocketMessage(
struct msghdr* msg,
int msg_flags) {
AsyncSocket::WriteResult
AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) {
ssize_t totalWritten = 0;
if (state_ == StateEnum::FAST_OPEN) {
sockaddr_storage addr;
......@@ -1778,11 +1777,9 @@ AsyncSocket::WriteResult AsyncSocket::sendSocketMessage(
return WriteResult(
WRITE_ERROR, folly::make_unique<AsyncSocketException>(ex));
}
// Let's fake it that no bytes were written.
// Some clients check errno even if return code is 0, so we
// set it just in case.
// Let's fake it that no bytes were written and return an errno.
errno = EAGAIN;
totalWritten = 0;
totalWritten = -1;
} else if (errno == EOPNOTSUPP) {
VLOG(4) << "TFO not supported";
// Try falling back to connecting.
......@@ -1797,10 +1794,8 @@ AsyncSocket::WriteResult AsyncSocket::sendSocketMessage(
}
// If there was no exception during connections,
// we would return that no bytes were written.
// Some clients check errno even if return code is 0, so we
// set it just in case.
errno = EAGAIN;
totalWritten = 0;
totalWritten = -1;
} catch (const AsyncSocketException& ex) {
return WriteResult(
WRITE_ERROR, folly::make_unique<AsyncSocketException>(ex));
......@@ -1816,7 +1811,7 @@ AsyncSocket::WriteResult AsyncSocket::sendSocketMessage(
AsyncSocketException::UNKNOWN, "No more free local ports"));
}
} else {
totalWritten = ::sendmsg(fd_, msg, msg_flags);
totalWritten = ::sendmsg(fd, msg, msg_flags);
}
return WriteResult(totalWritten);
}
......@@ -1855,7 +1850,7 @@ AsyncSocket::WriteResult AsyncSocket::performWrite(
// marks that this is the last byte of a record (response)
msg_flags |= MSG_EOR;
}
auto writeResult = sendSocketMessage(&msg, msg_flags);
auto writeResult = sendSocketMessage(fd_, &msg, msg_flags);
auto totalWritten = writeResult.writeReturn;
if (totalWritten < 0) {
if (!writeResult.exception && errno == EAGAIN) {
......
......@@ -817,7 +817,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
* @param msg Message to send
* @param msg_flags Flags to pass to sendmsg
*/
AsyncSocket::WriteResult sendSocketMessage(struct msghdr* msg, int msg_flags);
AsyncSocket::WriteResult
sendSocketMessage(int fd, struct msghdr* msg, int msg_flags);
virtual ssize_t tfoSendMsg(int fd, struct msghdr* msg, int msg_flags);
......@@ -855,7 +856,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
void failWrite(const char* fn, const AsyncSocketException& ex);
void failAllWrites(const AsyncSocketException& ex);
void invokeConnectErr(const AsyncSocketException& ex);
void invokeConnectSuccess();
virtual void invokeConnectSuccess();
void invalidState(ConnectCallback* callback);
void invalidState(ReadCallback* callback);
void invalidState(WriteCallback* callback);
......
......@@ -15,26 +15,28 @@
*/
#include <folly/io/async/test/AsyncSSLSocketTest.h>
#include <signal.h>
#include <pthread.h>
#include <signal.h>
#include <folly/SocketAddress.h>
#include <folly/io/async/AsyncSSLSocket.h>
#include <folly/io/async/EventBase.h>
#include <folly/SocketAddress.h>
#include <folly/portability/Sockets.h>
#include <folly/portability/Unistd.h>
#include <folly/io/async/test/BlockingSocket.h>
#include <fstream>
#include <fcntl.h>
#include <folly/io/Cursor.h>
#include <gtest/gtest.h>
#include <openssl/bio.h>
#include <sys/types.h>
#include <fstream>
#include <iostream>
#include <list>
#include <set>
#include <fcntl.h>
#include <openssl/bio.h>
#include <sys/types.h>
#include <folly/io/Cursor.h>
#include <gmock/gmock.h>
using std::string;
using std::vector;
......@@ -43,6 +45,8 @@ using std::cerr;
using std::endl;
using std::list;
using namespace testing;
namespace folly {
uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
......@@ -55,7 +59,7 @@ const char* testCA = "folly/io/async/test/certs/ca-cert.pem";
constexpr size_t SSLClient::kMaxReadBufferSz;
constexpr size_t SSLClient::kMaxReadsPerEvent;
TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb)
TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb, bool enableTFO)
: ctx_(new folly::SSLContext),
acb_(acb),
socket_(folly::AsyncServerSocket::newSocket(&evb_)) {
......@@ -67,7 +71,13 @@ TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb)
acb_->ctx_ = ctx_;
acb_->base_ = &evb_;
//set up the listening socket
// Enable TFO
if (enableTFO) {
LOG(INFO) << "server TFO enabled";
socket_->setTFOEnabled(true, 1000);
}
// set up the listening socket
socket_->bind(0);
socket_->getAddress(&address_);
socket_->listen(100);
......@@ -1674,6 +1684,203 @@ TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
std::string::npos);
}
#if FOLLY_ALLOW_TFO
class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
public:
using UniquePtr = std::unique_ptr<MockAsyncTFOSSLSocket, Destructor>;
explicit MockAsyncTFOSSLSocket(
std::shared_ptr<folly::SSLContext> sslCtx,
EventBase* evb)
: AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {}
MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
};
/**
* Test connecting to, writing to, reading from, and closing the
* connection to the SSL server with TFO.
*/
TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) {
// Start listening on a local port
WriteCallbackBase writeCallback;
ReadCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback);
SSLServerAcceptCallback acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback, true);
// Set up SSL context.
auto sslContext = std::make_shared<SSLContext>();
// connect
auto socket =
std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
socket->enableTFO();
socket->open();
// write()
std::array<uint8_t, 128> buf;
memset(buf.data(), 'a', buf.size());
socket->write(buf.data(), buf.size());
// read()
std::array<uint8_t, 128> readbuf;
uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
EXPECT_EQ(bytesRead, 128);
EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
// close()
socket->close();
}
/**
* Test connecting to, writing to, reading from, and closing the
* connection to the SSL server with TFO.
*/
TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) {
// Start listening on a local port
WriteCallbackBase writeCallback;
ReadCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback);
SSLServerAcceptCallback acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback, false);
// Set up SSL context.
auto sslContext = std::make_shared<SSLContext>();
// connect
auto socket =
std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
socket->enableTFO();
socket->open();
// write()
std::array<uint8_t, 128> buf;
memset(buf.data(), 'a', buf.size());
socket->write(buf.data(), buf.size());
// read()
std::array<uint8_t, 128> readbuf;
uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
EXPECT_EQ(bytesRead, 128);
EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
// close()
socket->close();
}
class ConnCallback : public AsyncSocket::ConnectCallback {
public:
virtual void connectSuccess() noexcept override {
state = State::SUCCESS;
}
virtual void connectErr(const AsyncSocketException&) noexcept override {
state = State::ERROR;
}
enum class State { WAITING, SUCCESS, ERROR };
State state{State::WAITING};
};
MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
EventBase* evb,
const SocketAddress& address) {
// Set up SSL context.
auto sslContext = std::make_shared<SSLContext>();
// connect
auto socket = MockAsyncTFOSSLSocket::UniquePtr(
new MockAsyncTFOSSLSocket(sslContext, evb));
socket->enableTFO();
EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
.WillOnce(Invoke([&](int fd, struct msghdr*, int) {
sockaddr_storage addr;
auto len = address.getAddress(&addr);
return connect(fd, (const struct sockaddr*)&addr, len);
}));
return socket;
}
TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) {
// Start listening on a local port
WriteCallbackBase writeCallback;
ReadCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback);
SSLServerAcceptCallback acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback, true);
EventBase evb;
auto socket = setupSocketWithFallback(&evb, server.getAddress());
ConnCallback ccb;
socket->connect(&ccb, server.getAddress(), 30);
evb.loop();
EXPECT_EQ(ConnCallback::State::SUCCESS, ccb.state);
evb.runInEventBaseThread([&] { socket->detachEventBase(); });
evb.loop();
BlockingSocket sock(std::move(socket));
// write()
std::array<uint8_t, 128> buf;
memset(buf.data(), 'a', buf.size());
sock.write(buf.data(), buf.size());
// read()
std::array<uint8_t, 128> readbuf;
uint32_t bytesRead = sock.readAll(readbuf.data(), readbuf.size());
EXPECT_EQ(bytesRead, 128);
EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
// close()
sock.close();
}
TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
// Start listening on a local port
WriteCallbackBase writeCallback;
ReadErrorCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback);
SSLServerAcceptCallback acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback, true);
// Set up SSL context.
auto sslContext = std::make_shared<SSLContext>();
// connect
auto socket =
std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
socket->enableTFO();
EXPECT_THROW(
socket->open(std::chrono::milliseconds(1)), AsyncSocketException);
}
TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
// Start listening on a local port
WriteCallbackBase writeCallback;
ReadErrorCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback);
SSLServerAcceptCallback acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback, true);
EventBase evb;
auto socket = setupSocketWithFallback(&evb, server.getAddress());
ConnCallback ccb;
// Set a short timeout
socket->connect(&ccb, server.getAddress(), 1);
evb.loop();
EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
}
#endif
} // namespace
///////////////////////////////////////////////////////////////////////////
......
......@@ -607,7 +607,9 @@ class TestSSLServer {
public:
// Create a TestSSLServer.
// This immediately starts listening on the given port.
explicit TestSSLServer(SSLServerAcceptCallbackBase *acb);
explicit TestSSLServer(
SSLServerAcceptCallbackBase* acb,
bool enableTFO = false);
// Kill the thread.
~TestSSLServer() {
......
......@@ -16,36 +16,41 @@
#pragma once
#include <folly/Optional.h>
#include <folly/io/async/SSLContext.h>
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/AsyncSSLSocket.h>
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/SSLContext.h>
class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
public folly::AsyncTransportWrapper::ReadCallback,
public folly::AsyncTransportWrapper::WriteCallback
{
public folly::AsyncTransportWrapper::WriteCallback {
public:
explicit BlockingSocket(int fd)
: sock_(new folly::AsyncSocket(&eventBase_, fd)) {
}
: sock_(new folly::AsyncSocket(&eventBase_, fd)) {}
BlockingSocket(folly::SocketAddress address,
std::shared_ptr<folly::SSLContext> sslContext)
: sock_(sslContext ? new folly::AsyncSSLSocket(sslContext, &eventBase_) :
new folly::AsyncSocket(&eventBase_)),
address_(address) {}
BlockingSocket(
folly::SocketAddress address,
std::shared_ptr<folly::SSLContext> sslContext)
: sock_(
sslContext ? new folly::AsyncSSLSocket(sslContext, &eventBase_)
: new folly::AsyncSocket(&eventBase_)),
address_(address) {}
explicit BlockingSocket(folly::AsyncSocket::UniquePtr socket)
: sock_(std::move(socket)) {
sock_->attachEventBase(&eventBase_);
}
void enableTFO() {
sock_->enableTFO();
}
void setAddress(folly::SocketAddress address) {
address_ = address;
}
void open() {
sock_->connect(this, address_);
void open(
std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()) {
sock_->connect(this, address_, timeout.count());
eventBase_.loop();
if (err_.hasValue()) {
throw err_.value();
......@@ -54,7 +59,9 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
void close() {
sock_->close();
}
void closeWithReset() { sock_->closeWithReset(); }
void closeWithReset() {
sock_->closeWithReset();
}
int32_t write(uint8_t const* buf, size_t len) {
sock_->write(this, buf, len);
......@@ -67,11 +74,11 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
void flush() {}
int32_t readAll(uint8_t *buf, size_t len) {
int32_t readAll(uint8_t* buf, size_t len) {
return readHelper(buf, len, true);
}
int32_t read(uint8_t *buf, size_t len) {
int32_t read(uint8_t* buf, size_t len) {
return readHelper(buf, len, false);
}
......@@ -83,7 +90,7 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
folly::EventBase eventBase_;
folly::AsyncSocket::UniquePtr sock_;
folly::Optional<folly::AsyncSocketException> err_;
uint8_t *readBuf_{nullptr};
uint8_t* readBuf_{nullptr};
size_t readLen_{0};
folly::SocketAddress address_;
......@@ -102,18 +109,18 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
sock_->setReadCB(nullptr);
}
}
void readEOF() noexcept override {
}
void readEOF() noexcept override {}
void readErr(const folly::AsyncSocketException& ex) noexcept override {
err_ = ex;
}
void writeSuccess() noexcept override {}
void writeErr(size_t /* bytesWritten */,
const folly::AsyncSocketException& ex) noexcept override {
void writeErr(
size_t /* bytesWritten */,
const folly::AsyncSocketException& ex) noexcept override {
err_ = ex;
}
int32_t readHelper(uint8_t *buf, size_t len, bool all) {
int32_t readHelper(uint8_t* buf, size_t len, bool all) {
if (!sock_->good()) {
return 0;
}
......@@ -132,8 +139,8 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
throw err_.value();
}
if (all && readLen_ > 0) {
throw folly::AsyncSocketException(folly::AsyncSocketException::UNKNOWN,
"eof");
throw folly::AsyncSocketException(
folly::AsyncSocketException::UNKNOWN, "eof");
}
return len - readLen_;
}
......
......@@ -24,6 +24,7 @@ DEFINE_string(host, "localhost", "Host");
DEFINE_int32(port, 0, "port");
DEFINE_bool(tfo, false, "enable tfo");
DEFINE_string(msg, "", "Message to send");
DEFINE_bool(ssl, false, "use ssl");
int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
......@@ -35,7 +36,13 @@ int main(int argc, char** argv) {
// Prep the socket
EventBase evb;
AsyncSocket::UniquePtr socket(new AsyncSocket(&evb));
AsyncSocket::UniquePtr socket;
if (FLAGS_ssl) {
auto sslContext = std::make_shared<SSLContext>();
socket = AsyncSocket::UniquePtr(new AsyncSSLSocket(sslContext, &evb));
} else {
socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
}
socket->detachEventBase();
if (FLAGS_tfo) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment