Commit 8aac0e33 authored by Subodh Iyengar's avatar Subodh Iyengar Committed by Facebook Github Bot 7

Always override write bio method

Summary:
Always overriding write bio method
allows us to more cleanly implement
features like eor tracking, support
multiple ssl libraries, and also TFO

Reviewed By: anirudhvr

Differential Revision: D3350482

fbshipit-source-id: ddd2333431f9d636d69c8325b2c18d7cc043b848
parent 9fc18f0a
......@@ -223,16 +223,16 @@ void setup_SSL_CTX(SSL_CTX *ctx) {
}
BIO_METHOD eorAwareBioMethod;
BIO_METHOD sslWriteBioMethod;
void* initEorBioMethod(void) {
memcpy(&eorAwareBioMethod, BIO_s_socket(), sizeof(eorAwareBioMethod));
void* initsslWriteBioMethod(void) {
memcpy(&sslWriteBioMethod, BIO_s_socket(), sizeof(sslWriteBioMethod));
// override the bwrite method for MSG_EOR support
eorAwareBioMethod.bwrite = AsyncSSLSocket::eorAwareBioWrite;
sslWriteBioMethod.bwrite = AsyncSSLSocket::bioWrite;
// Note that the eorAwareBioMethod.type and eorAwareBioMethod.name are not
// Note that the sslWriteBioMethod.type and sslWriteBioMethod.name are not
// set here. openssl code seems to be checking ".type == BIO_TYPE_SOCKET" and
// then have specific handlings. The eorAwareBioWrite should be compatible
// then have specific handlings. The sslWriteBioWrite should be compatible
// with the one in openssl.
// Return something here to enable AsyncSSLSocket to call this method using
......@@ -314,8 +314,8 @@ AsyncSSLSocket::~AsyncSSLSocket() {
void AsyncSSLSocket::init() {
// Do this here to ensure we initialize this once before any use of
// AsyncSSLSocket instances and not as part of library load.
static const auto eorAwareBioMethodInitializer = initEorBioMethod();
(void)eorAwareBioMethodInitializer;
static const auto sslWriteBioMethodInitializer = initsslWriteBioMethod();
(void)sslWriteBioMethodInitializer;
setup_SSL_CTX(ctx_->getSSLCtx());
}
......@@ -401,37 +401,15 @@ std::string AsyncSSLSocket::getApplicationProtocol() noexcept {
}
bool AsyncSSLSocket::isEorTrackingEnabled() const {
if (ssl_ == nullptr) {
return false;
}
const BIO *wb = SSL_get_wbio(ssl_);
return wb && wb->method == &eorAwareBioMethod;
return trackEor_;
}
void AsyncSSLSocket::setEorTracking(bool track) {
BIO *wb = SSL_get_wbio(ssl_);
if (!wb) {
throw AsyncSocketException(AsyncSocketException::INVALID_STATE,
"setting EOR tracking without an initialized "
"BIO");
}
if (track) {
if (wb->method != &eorAwareBioMethod) {
// only do this if we didn't
wb->method = &eorAwareBioMethod;
BIO_set_app_data(wb, this);
if (trackEor_ != track) {
trackEor_ = track;
appEorByteNo_ = 0;
minEorRawByteNo_ = 0;
}
} else if (wb->method == &eorAwareBioMethod) {
wb->method = BIO_s_socket();
BIO_set_app_data(wb, nullptr);
appEorByteNo_ = 0;
minEorRawByteNo_ = 0;
} else {
CHECK(wb->method == BIO_s_socket());
}
}
size_t AsyncSSLSocket::getRawBytesWritten() const {
......@@ -703,6 +681,19 @@ void AsyncSSLSocket::applyVerificationOptions(SSL * ssl) {
}
}
bool AsyncSSLSocket::setupSSLBio() {
auto wb = BIO_new(&sslWriteBioMethod);
if (!wb) {
return false;
}
BIO_set_app_data(wb, this);
BIO_set_fd(wb, fd_, BIO_NOCLOSE);
SSL_set_bio(ssl_, wb, wb);
return true;
}
void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout,
const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
DestructorGuard dg(this);
......@@ -741,9 +732,15 @@ void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout,
return failHandshake(__func__, ex);
}
if (!setupSSLBio()) {
sslState_ = STATE_ERROR;
AsyncSocketException ex(
AsyncSocketException::INTERNAL_ERROR, "error creating SSL bio");
return failHandshake(__func__, ex);
}
applyVerificationOptions(ssl_);
SSL_set_fd(ssl_, fd_);
if (sslSession_ != nullptr) {
SSL_set_session(ssl_, sslSession_);
SSL_SESSION_free(sslSession_);
......@@ -1010,7 +1007,14 @@ AsyncSSLSocket::handleAccept() noexcept {
<< ", fd=" << fd_ << "): " << e.what();
return failHandshake(__func__, ex);
}
SSL_set_fd(ssl_, fd_);
if (!setupSSLBio()) {
sslState_ = STATE_ERROR;
AsyncSocketException ex(
AsyncSocketException::INTERNAL_ERROR, "error creating write bio");
return failHandshake(__func__, ex);
}
SSL_set_ex_data(ssl_, getSSLExDataIndex(), this);
applyVerificationOptions(ssl_);
......@@ -1448,7 +1452,7 @@ AsyncSocket::WriteResult AsyncSSLSocket::performWrite(
int AsyncSSLSocket::eorAwareSSLWrite(SSL *ssl, const void *buf, int n,
bool eor) {
if (eor && SSL_get_wbio(ssl)->method == &eorAwareBioMethod) {
if (eor && trackEor_) {
if (appEorByteNo_) {
// cannot track for more than one app byte EOR
CHECK(appEorByteNo_ == appBytesWritten_ + n);
......@@ -1493,34 +1497,37 @@ void AsyncSSLSocket::sslInfoCallback(const SSL* ssl, int where, int ret) {
}
}
int AsyncSSLSocket::eorAwareBioWrite(BIO *b, const char *in, int inl) {
int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
int ret;
struct msghdr msg;
struct iovec iov;
int flags = 0;
AsyncSSLSocket *tsslSock;
AsyncSSLSocket* tsslSock;
iov.iov_base = const_cast<char *>(in);
iov.iov_base = const_cast<char*>(in);
iov.iov_len = inl;
memset(&msg, 0, sizeof(msg));
msg.msg_iov = &iov;
msg.msg_iovlen = 1;
tsslSock =
reinterpret_cast<AsyncSSLSocket*>(BIO_get_app_data(b));
if (tsslSock &&
tsslSock->minEorRawByteNo_ &&
auto appData = BIO_get_app_data(b);
CHECK(appData);
tsslSock = reinterpret_cast<AsyncSSLSocket*>(appData);
CHECK(tsslSock);
if (tsslSock->trackEor_ && tsslSock->minEorRawByteNo_ &&
tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) {
flags = MSG_EOR;
}
ret = sendmsg(b->num, &msg, flags);
ret = sendmsg(BIO_get_fd(b, nullptr), &msg, flags);
BIO_clear_retry_flags(b);
if (ret <= 0) {
if (BIO_sock_should_retry(ret))
BIO_set_retry_write(b);
}
return(ret);
return ret;
}
int AsyncSSLSocket::sslVerifyCallback(int preverifyOk,
......
......@@ -652,7 +652,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
static int getSSLExDataIndex();
static AsyncSSLSocket* getFromSSL(const SSL *ssl);
static int eorAwareBioWrite(BIO *b, const char *in, int inl);
static int bioWrite(BIO* b, const char* in, int inl);
void resetClientHelloParsing(SSL *ssl);
static void clientHelloParsingCallback(int write_p, int version,
int content_type, const void *buf, size_t len, SSL *ssl, void *arg);
......@@ -774,6 +774,13 @@ class AsyncSSLSocket : public virtual AsyncSocket {
*/
void applyVerificationOptions(SSL * ssl);
/**
* Sets up SSL with a custom write bio which intercepts all writes.
*
* @return true, if succeeds and false if there is an error creating the bio.
*/
bool setupSSLBio();
/**
* A SSL_write wrapper that understand EOR
*
......@@ -815,6 +822,9 @@ class AsyncSSLSocket : public virtual AsyncSocket {
// whether the SSL session was resumed using session ID or not
bool sessionIDResumed_{false};
// Whether to track EOR or not.
bool trackEor_{false};
// The app byte num that we are tracking for the MSG_EOR
// Only one app EOR byte can be tracked.
size_t appEorByteNo_{0};
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment