Commit 4c2bc928 authored by Brandon Schlinker's avatar Brandon Schlinker Committed by Facebook GitHub Bot

Unify socket message generation, sendSocketMessage

Summary: Both `AsyncSocket` and `AsyncSSLSocket` currently have code to generate a socket message and control messages with ancillary data. Merge this code into a new function in `AsyncSocket`.

Differential Revision: D24096351

fbshipit-source-id: 87d90648c10c87832f868e322acf59c97b8ac8b7
parent bb168bd8
......@@ -1756,27 +1756,10 @@ int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
flags = unSet(flags, folly::WriteFlags::EOR);
}
struct msghdr msg;
struct iovec iov;
iov.iov_base = const_cast<char*>(in);
iov.iov_len = size_t(inl);
memset(&msg, 0, sizeof(msg));
msg.msg_iov = &iov;
msg.msg_iovlen = 1;
int msg_flags =
sslSock->getSendMsgParamsCB()->getFlags(flags, false /*zeroCopyEnabled*/);
msg.msg_controllen =
sslSock->getSendMsgParamsCB()->getAncillaryDataSize(flags);
CHECK_GE(
AsyncSocket::SendMsgParamsCallback::maxAncillaryDataSize,
msg.msg_controllen);
if (msg.msg_controllen != 0) {
msg.msg_control = reinterpret_cast<char*>(alloca(msg.msg_controllen));
sslSock->getSendMsgParamsCB()->getAncillaryData(flags, msg.msg_control);
}
auto result =
sslSock->sendSocketMessage(OpenSSLUtils::getBioFd(b), &msg, msg_flags);
struct iovec vec;
vec.iov_base = const_cast<char*>(in);
vec.iov_len = size_t(inl);
auto result = sslSock->sendSocketMessage(&vec, 1, flags);
BIO_clear_retry_flags(b);
if (!result.exception && result.writeReturn <= 0) {
if (OpenSSLUtils::getBioShouldRetryWrite(int(result.writeReturn))) {
......
......@@ -2542,6 +2542,40 @@ ssize_t AsyncSocket::tfoSendMsg(
return detail::tfo_sendmsg(fd, msg, msg_flags);
}
AsyncSocket::WriteResult AsyncSocket::sendSocketMessage(
const iovec* vec, size_t count, WriteFlags flags) {
struct msghdr msg = {};
msg.msg_name = nullptr;
msg.msg_namelen = 0;
msg.msg_iov = const_cast<struct iovec*>(vec);
msg.msg_iovlen = std::min<size_t>(count, kIovMax);
msg.msg_flags = 0; // ignored, must forward flags via sendmsg parameter
msg.msg_control = nullptr;
msg.msg_controllen = sendMsgParamCallback_->getAncillaryDataSize(flags);
CHECK_GE(
AsyncSocket::SendMsgParamsCallback::maxAncillaryDataSize,
msg.msg_controllen);
if (msg.msg_controllen != 0) {
msg.msg_control = reinterpret_cast<char*>(alloca(msg.msg_controllen));
sendMsgParamCallback_->getAncillaryData(flags, msg.msg_control);
}
int msg_flags = sendMsgParamCallback_->getFlags(flags, zeroCopyEnabled_);
auto writeResult = sendSocketMessage(fd_, &msg, msg_flags);
if (writeResult.writeReturn < 0 && zeroCopyEnabled_ && errno == ENOBUFS) {
// workaround for running with zerocopy enabled but without a big enough
// memlock value - see ulimit -l
zeroCopyEnabled_ = false;
zeroCopyReenableCounter_ = zeroCopyReenableThreshold_;
msg_flags = sendMsgParamCallback_->getFlags(flags, zeroCopyEnabled_);
writeResult = sendSocketMessage(fd_, &msg, msg_flags);
}
return writeResult;
}
AsyncSocket::WriteResult AsyncSocket::sendSocketMessage(
NetworkSocket fd, struct msghdr* msg, int msg_flags) {
ssize_t totalWritten = 0;
......@@ -2615,40 +2649,8 @@ AsyncSocket::WriteResult AsyncSocket::performWrite(
WriteFlags flags,
uint32_t* countWritten,
uint32_t* partialWritten) {
// We use sendmsg() instead of writev() so that we can pass in MSG_NOSIGNAL
// We correctly handle EPIPE errors, so we never want to receive SIGPIPE
// (since it may terminate the program if the main program doesn't explicitly
// ignore it).
struct msghdr msg;
msg.msg_name = nullptr;
msg.msg_namelen = 0;
msg.msg_iov = const_cast<iovec*>(vec);
msg.msg_iovlen = std::min<size_t>(count, kIovMax);
msg.msg_flags = 0;
msg.msg_controllen = sendMsgParamCallback_->getAncillaryDataSize(flags);
CHECK_GE(
AsyncSocket::SendMsgParamsCallback::maxAncillaryDataSize,
msg.msg_controllen);
if (msg.msg_controllen != 0) {
msg.msg_control = reinterpret_cast<char*>(alloca(msg.msg_controllen));
sendMsgParamCallback_->getAncillaryData(flags, msg.msg_control);
} else {
msg.msg_control = nullptr;
}
int msg_flags = sendMsgParamCallback_->getFlags(flags, zeroCopyEnabled_);
auto writeResult = sendSocketMessage(fd_, &msg, msg_flags);
auto writeResult = sendSocketMessage(vec, count, flags);
auto totalWritten = writeResult.writeReturn;
if (totalWritten < 0 && zeroCopyEnabled_ && errno == ENOBUFS) {
// workaround for running with zerocopy enabled but without a big enough
// memlock value - see ulimit -l
zeroCopyEnabled_ = false;
zeroCopyReenableCounter_ = zeroCopyReenableThreshold_;
msg_flags = sendMsgParamCallback_->getFlags(flags, zeroCopyEnabled_);
writeResult = sendSocketMessage(fd_, &msg, msg_flags);
totalWritten = writeResult.writeReturn;
}
if (totalWritten < 0) {
bool tryAgain = (errno == EAGAIN);
#ifdef __APPLE__
......
......@@ -1282,6 +1282,16 @@ class AsyncSocket : public AsyncTransport {
uint32_t* countWritten,
uint32_t* partialWritten);
/**
* Prepares a msghdr and sends the message over the socket using sendmsg
*
* @param vec The iovec array pointing to the buffers to write.
* @param count The length of the iovec array.
* @param flags Set of write flags.
*/
virtual AsyncSocket::WriteResult sendSocketMessage(
const iovec* vec, size_t count, WriteFlags flags);
/**
* Sends the message over the socket using sendmsg
*
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment