Commit 8ecc23d9 authored by Maxim Georgiev's avatar Maxim Georgiev Committed by Facebook Github Bot

Implementing a callback interface for folly::AsyncSocket allowing to supply an...

Implementing a callback interface for folly::AsyncSocket allowing to supply an ancillary data buffer with msghdr structure to sendmsg() system call

Summary: Implementing a callback interface for folly::AsyncSocket allowing to supply an ancillary data buffer with msghdr structure to sendmsg() system call.

Reviewed By: afrind

Differential Revision: D4422168

fbshipit-source-id: 29a23b05f704aff796d368f4ac9514c49b7ce578
parent d9bf016d
......@@ -34,7 +34,6 @@
#include <folly/io/Cursor.h>
#include <folly/io/IOBuf.h>
#include <folly/portability/OpenSSL.h>
#include <folly/portability/Unistd.h>
using folly::SocketAddress;
using folly::SSLContext;
......@@ -59,6 +58,7 @@ using folly::SSLContext;
using namespace folly::ssl;
using folly::ssl::OpenSSLUtils;
// We have one single dummy SSL context so that we can implement attach
// and detach methods in a thread safe fashion without modifying opnessl.
static SSLContext *dummyCtx = nullptr;
......@@ -1624,7 +1624,6 @@ void AsyncSSLSocket::sslInfoCallback(const SSL* ssl, int where, int ret) {
int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
struct msghdr msg;
struct iovec iov;
int flags = 0;
AsyncSSLSocket* tsslSock;
iov.iov_base = const_cast<char*>(in);
......@@ -1639,23 +1638,28 @@ int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
tsslSock = reinterpret_cast<AsyncSSLSocket*>(appData);
CHECK(tsslSock);
WriteFlags flags = WriteFlags::NONE;
if (tsslSock->isEorTrackingEnabled() && tsslSock->minEorRawByteNo_ &&
tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) {
flags = MSG_EOR;
flags |= WriteFlags::EOR;
}
#ifdef MSG_NOSIGNAL
flags |= MSG_NOSIGNAL;
#endif
#ifdef MSG_MORE
if (tsslSock->corkCurrentWrite_) {
flags |= MSG_MORE;
flags |= WriteFlags::CORK;
}
int msg_flags = tsslSock->getSendMsgParamsCB()->getFlags(flags);
msg.msg_controllen =
tsslSock->getSendMsgParamsCB()->getAncillaryDataSize(flags);
CHECK_GE(AsyncSocket::SendMsgParamsCallback::maxAncillaryDataSize,
msg.msg_controllen);
if (msg.msg_controllen != 0) {
msg.msg_control = reinterpret_cast<char*>(alloca(msg.msg_controllen));
tsslSock->getSendMsgParamsCB()->getAncillaryData(flags, msg.msg_control);
}
#endif
auto result = tsslSock->sendSocketMessage(
OpenSSLUtils::getBioFd(b, nullptr), &msg, flags);
OpenSSLUtils::getBioFd(b, nullptr), &msg, msg_flags);
BIO_clear_retry_flags(b);
if (!result.exception && result.writeReturn <= 0) {
if (OpenSSLUtils::getBioShouldRetryWrite(int(result.writeReturn))) {
......
......@@ -185,6 +185,33 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
struct iovec writeOps_[]; ///< write operation(s) list
};
int AsyncSocket::SendMsgParamsCallback::getDefaultFlags(folly::WriteFlags flags)
noexcept {
int msg_flags = MSG_DONTWAIT;
#ifdef MSG_NOSIGNAL // Linux-only
msg_flags |= MSG_NOSIGNAL;
#ifdef MSG_MORE
if (isSet(flags, WriteFlags::CORK)) {
// MSG_MORE tells the kernel we have more data to send, so wait for us to
// give it the rest of the data rather than immediately sending a partial
// frame, even when TCP_NODELAY is enabled.
msg_flags |= MSG_MORE;
}
#endif // MSG_MORE
#endif // MSG_NOSIGNAL
if (isSet(flags, WriteFlags::EOR)) {
// marks that this is the last byte of a record (response)
msg_flags |= MSG_EOR;
}
return msg_flags;
}
namespace {
static AsyncSocket::SendMsgParamsCallback defaultSendMsgParamsCallback;
}
AsyncSocket::AsyncSocket()
: eventBase_(nullptr),
writeTimeout_(this, nullptr),
......@@ -254,6 +281,7 @@ void AsyncSocket::init() {
shutdownSocketSet_ = nullptr;
appBytesWritten_ = 0;
appBytesReceived_ = 0;
sendMsgParamCallback_ = &defaultSendMsgParamsCallback;
}
AsyncSocket::~AsyncSocket() {
......@@ -625,6 +653,14 @@ AsyncSocket::ErrMessageCallback* AsyncSocket::getErrMessageCallback() const {
return errMessageCallback_;
}
void AsyncSocket::setSendMsgParamCB(SendMsgParamsCallback* callback) {
sendMsgParamCallback_ = callback;
}
AsyncSocket::SendMsgParamsCallback* AsyncSocket::getSendMsgParamsCB() const {
return sendMsgParamCallback_;
}
void AsyncSocket::setReadCB(ReadCallback *callback) {
VLOG(6) << "AsyncSocket::setReadCallback() this=" << this << ", fd=" << fd_
<< ", callback=" << callback << ", state=" << state_;
......@@ -1363,7 +1399,7 @@ int AsyncSocket::setTCPProfile(int profd) {
}
void AsyncSocket::ioReady(uint16_t events) noexcept {
VLOG(7) << "AsyncSocket::ioRead() this=" << this << ", fd" << fd_
VLOG(7) << "AsyncSocket::ioRead() this=" << this << ", fd=" << fd_
<< ", events=" << std::hex << events << ", state=" << state_;
DestructorGuard dg(this);
assert(events & EventHandler::READ_WRITE);
......@@ -2023,25 +2059,19 @@ AsyncSocket::WriteResult AsyncSocket::performWrite(
msg.msg_namelen = 0;
msg.msg_iov = const_cast<iovec *>(vec);
msg.msg_iovlen = std::min<size_t>(count, kIovMax);
msg.msg_control = nullptr;
msg.msg_controllen = 0;
msg.msg_flags = 0;
msg.msg_controllen = sendMsgParamCallback_->getAncillaryDataSize(flags);
CHECK_GE(AsyncSocket::SendMsgParamsCallback::maxAncillaryDataSize,
msg.msg_controllen);
int msg_flags = MSG_DONTWAIT;
#ifdef MSG_NOSIGNAL // Linux-only
msg_flags |= MSG_NOSIGNAL;
if (isSet(flags, WriteFlags::CORK)) {
// MSG_MORE tells the kernel we have more data to send, so wait for us to
// give it the rest of the data rather than immediately sending a partial
// frame, even when TCP_NODELAY is enabled.
msg_flags |= MSG_MORE;
}
#endif
if (isSet(flags, WriteFlags::EOR)) {
// marks that this is the last byte of a record (response)
msg_flags |= MSG_EOR;
if (msg.msg_controllen != 0) {
msg.msg_control = reinterpret_cast<char*>(alloca(msg.msg_controllen));
sendMsgParamCallback_->getAncillaryData(flags, msg.msg_control);
} else {
msg.msg_control = nullptr;
}
int msg_flags = sendMsgParamCallback_->getFlags(flags);
auto writeResult = sendSocketMessage(fd_, &msg, msg_flags);
auto totalWritten = writeResult.writeReturn;
if (totalWritten < 0) {
......
......@@ -139,6 +139,77 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
virtual void errMessageError(const AsyncSocketException& ex) noexcept = 0;
};
class SendMsgParamsCallback {
public:
virtual ~SendMsgParamsCallback() = default;
/**
* getFlags() will be invoked to retrieve the desired flags to be passed
* to ::sendmsg() system call. This method was intentionally declared
* non-virtual, so there is no way to override it. Instead feel free to
* override getFlagsImpl(flags, defaultFlags) method instead, and enjoy
* the convenience of defaultFlags passed there.
*
* @param flags Write flags requested for the given write operation
*/
int getFlags(folly::WriteFlags flags) noexcept {
return getFlagsImpl(flags, getDefaultFlags(flags));
}
/**
* getAncillaryData() will be invoked to initialize ancillary data
* buffer referred by "msg_control" field of msghdr structure passed to
* ::sendmsg() system call. The function assumes that the size of buffer
* is not smaller than the value returned by getAncillaryDataSize() method
* for the same combination of flags.
*
* @param flags Write flags requested for the given write operation
* @param data Pointer to ancillary data buffer to initialize.
*/
virtual void getAncillaryData(
folly::WriteFlags /*flags*/,
void* /*data*/) noexcept {}
/**
* getAncillaryDataSize() will be invoked to retrieve the size of
* ancillary data buffer which should be passed to ::sendmsg() system call
*
* @param flags Write flags requested for the given write operation
*/
virtual uint32_t getAncillaryDataSize(folly::WriteFlags /*flags*/)
noexcept {
return 0;
}
static const size_t maxAncillaryDataSize{0x5000};
private:
/**
* getFlagsImpl() will be invoked by getFlags(folly::WriteFlags flags)
* method to retrieve the flags to be passed to ::sendmsg() system call.
* SendMsgParamsCallback::getFlags() is calling this method, and returns
* its results directly to the caller in AsyncSocket.
* Classes inheriting from SendMsgParamsCallback are welcome to override
* this method to force SendMsgParamsCallback to return its own set
* of flags.
*
* @param flags Write flags requested for the given write operation
* @param defaultflags A set of message flags returned by getDefaultFlags()
* method for the given "flags" mask.
*/
virtual int getFlagsImpl(folly::WriteFlags /*flags*/, int defaultFlags) {
return defaultFlags;
}
/**
* getDefaultFlags() will be invoked by getFlags(folly::WriteFlags flags)
* to retrieve the default set of flags, and pass them to getFlagsImpl(...)
*
* @param flags Write flags requested for the given write operation
*/
int getDefaultFlags(folly::WriteFlags flags) noexcept;
};
explicit AsyncSocket();
/**
* Create a new unconnected AsyncSocket.
......@@ -411,6 +482,20 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
*/
ErrMessageCallback* getErrMessageCallback() const;
/**
* Set a pointer to SendMsgParamsCallback implementation which
* will be used to form ::sendmsg() system call parameters
*
*/
void setSendMsgParamCB(SendMsgParamsCallback* callback);
/**
* Get a pointer to SendMsgParamsCallback implementation currently
* registered with this socket.
*
*/
SendMsgParamsCallback* getSendMsgParamsCB() const;
// Read and write methods
void setReadCB(ReadCallback* callback) override;
ReadCallback* getReadCallback() const override;
......@@ -1010,6 +1095,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
ConnectCallback* connectCallback_; ///< ConnectCallback
ErrMessageCallback* errMessageCallback_; ///< TimestampCallback
SendMsgParamsCallback* ///< Callback for retreaving
sendMsgParamCallback_; ///< ::sendmsg() parameters
ReadCallback* readCallback_; ///< ReadCallback
WriteRequest* writeReqHead_; ///< Chain of WriteRequests
WriteRequest* writeReqTail_; ///< End of WriteRequest chain
......@@ -1022,7 +1109,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
// socket.
std::unique_ptr<IOBuf> preReceivedData_;
int8_t readErr_{READ_NO_ERROR}; ///< The read error encountered, if any.
int8_t readErr_{READ_NO_ERROR}; ///< The read error encountered, if any
std::chrono::steady_clock::time_point connectStartTime_;
std::chrono::steady_clock::time_point connectEndTime_;
......
......@@ -32,6 +32,7 @@
#include <folly/io/Cursor.h>
#include <openssl/bio.h>
#include <sys/types.h>
#include <sys/utsname.h>
#include <fstream>
#include <iostream>
#include <list>
......@@ -1958,6 +1959,120 @@ TEST(AsyncSSLSocketTest, TestPreReceivedData) {
serverSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
}
/**
* Test overriding the flags passed to "sendmsg()" system call,
* and verifying that write requests fail properly.
*/
TEST(AsyncSSLSocketTest, SendMsgParamsCallback) {
// Start listening on a local port
SendMsgFlagsCallback msgCallback;
ExpectWriteErrorCallback writeCallback(&msgCallback);
ReadCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback);
SSLServerAcceptCallback acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback);
// Set up SSL context.
auto sslContext = std::make_shared<SSLContext>();
sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
// connect
auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
sslContext);
socket->open();
// Setting flags to "-1" to trigger "Invalid argument" error
// on attempt to use this flags in sendmsg() system call.
msgCallback.resetFlags(-1);
// write()
std::vector<uint8_t> buf(128, 'a');
ASSERT_EQ(socket->write(buf.data(), buf.size()), buf.size());
// close()
socket->close();
cerr << "SendMsgParamsCallback test completed" << endl;
}
#ifdef MSG_ERRQUEUE
/**
* Test connecting to, writing to, reading from, and closing the
* connection to the SSL server.
*/
TEST(AsyncSSLSocketTest, SendMsgDataCallback) {
// This test requires Linux kernel v4.6 or later
struct utsname s_uname;
memset(&s_uname, 0, sizeof(s_uname));
ASSERT_EQ(uname(&s_uname), 0);
int major, minor;
folly::StringPiece extra;
if (folly::split<false>(
'.', std::string(s_uname.release) + ".", major, minor, extra)) {
if (major < 4 || (major == 4 && minor < 6)) {
LOG(INFO) << "Kernel version: 4.6 and newer required for this test ("
<< "kernel ver. " << s_uname.release << " detected).";
return;
}
}
// Start listening on a local port
SendMsgDataCallback msgCallback;
WriteCheckTimestampCallback writeCallback(&msgCallback);
ReadCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback);
SSLServerAcceptCallback acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback);
// Set up SSL context.
auto sslContext = std::make_shared<SSLContext>();
sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
// connect
auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
sslContext);
socket->open();
// Adding MSG_EOR flag to the message flags - it'll trigger
// timestamp generation for the last byte of the message.
msgCallback.resetFlags(MSG_DONTWAIT|MSG_NOSIGNAL|MSG_EOR);
// Init ancillary data buffer to trigger timestamp notification
union {
uint8_t ctrl_data[CMSG_LEN(sizeof(uint32_t))];
struct cmsghdr cmsg;
} u;
u.cmsg.cmsg_level = SOL_SOCKET;
u.cmsg.cmsg_type = SO_TIMESTAMPING;
u.cmsg.cmsg_len = CMSG_LEN(sizeof(uint32_t));
uint32_t flags =
SOF_TIMESTAMPING_TX_SCHED |
SOF_TIMESTAMPING_TX_SOFTWARE |
SOF_TIMESTAMPING_TX_ACK;
memcpy(CMSG_DATA(&u.cmsg), &flags, sizeof(uint32_t));
std::vector<char> ctrl(CMSG_LEN(sizeof(uint32_t)));
memcpy(ctrl.data(), u.ctrl_data, CMSG_LEN(sizeof(uint32_t)));
msgCallback.resetData(std::move(ctrl));
// write()
std::vector<uint8_t> buf(128, 'a');
socket->write(buf.data(), buf.size());
// read()
std::vector<uint8_t> readbuf(buf.size());
uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
EXPECT_EQ(bytesRead, buf.size());
EXPECT_TRUE(std::equal(buf.begin(), buf.end(), readbuf.begin()));
writeCallback.checkForTimestampNotifications();
// close()
socket->close();
cerr << "SendMsgDataCallback test completed" << endl;
}
#endif // MSG_ERRQUEUE
#endif
} // namespace
......
......@@ -46,24 +46,106 @@ namespace folly {
// are responsible for setting the succeeded state properly before the
// destructors are called.
class SendMsgParamsCallbackBase :
public folly::AsyncSocket::SendMsgParamsCallback {
public:
SendMsgParamsCallbackBase() {}
void setSocket(
const std::shared_ptr<AsyncSSLSocket> &socket) {
socket_ = socket;
oldCallback_ = socket_->getSendMsgParamsCB();
socket_->setSendMsgParamCB(this);
}
int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
override {
return oldCallback_->getFlags(flags);
}
void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
oldCallback_->getAncillaryData(flags, data);
}
uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
return oldCallback_->getAncillaryDataSize(flags);
}
std::shared_ptr<AsyncSSLSocket> socket_;
folly::AsyncSocket::SendMsgParamsCallback* oldCallback_{nullptr};
};
class SendMsgFlagsCallback : public SendMsgParamsCallbackBase {
public:
SendMsgFlagsCallback() {}
void resetFlags(int flags) {
flags_ = flags;
}
int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
override {
if (flags_) {
return flags_;
} else {
return oldCallback_->getFlags(flags);
}
}
int flags_{0};
};
class SendMsgDataCallback : public SendMsgFlagsCallback {
public:
SendMsgDataCallback() {}
void resetData(std::vector<char>&& data) {
ancillaryData_.swap(data);
}
void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
if (ancillaryData_.size()) {
std::cerr << "getAncillaryData: copying data" << std::endl;
memcpy(data, ancillaryData_.data(), ancillaryData_.size());
} else {
oldCallback_->getAncillaryData(flags, data);
}
}
uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
if (ancillaryData_.size()) {
std::cerr << "getAncillaryDataSize: returning size" << std::endl;
return ancillaryData_.size();
} else {
return oldCallback_->getAncillaryDataSize(flags);
}
}
std::vector<char> ancillaryData_;
};
class WriteCallbackBase :
public AsyncTransportWrapper::WriteCallback {
public:
WriteCallbackBase()
explicit WriteCallbackBase(SendMsgParamsCallbackBase* mcb = nullptr)
: state(STATE_WAITING)
, bytesWritten(0)
, exception(AsyncSocketException::UNKNOWN, "none") {}
, exception(AsyncSocketException::UNKNOWN, "none")
, mcb_(mcb) {}
~WriteCallbackBase() {
EXPECT_EQ(STATE_SUCCEEDED, state);
}
void setSocket(
virtual void setSocket(
const std::shared_ptr<AsyncSSLSocket> &socket) {
socket_ = socket;
if (mcb_) {
mcb_->setSocket(socket);
}
}
void writeSuccess() noexcept override {
virtual void writeSuccess() noexcept override {
std::cerr << "writeSuccess" << std::endl;
state = STATE_SUCCEEDED;
}
......@@ -84,7 +166,116 @@ public:
StateEnum state;
size_t bytesWritten;
AsyncSocketException exception;
SendMsgParamsCallbackBase* mcb_;
};
class ExpectWriteErrorCallback :
public WriteCallbackBase {
public:
explicit ExpectWriteErrorCallback(SendMsgParamsCallbackBase* mcb = nullptr)
: WriteCallbackBase(mcb) {}
~ExpectWriteErrorCallback() {
EXPECT_EQ(STATE_FAILED, state);
EXPECT_EQ(exception.type_,
AsyncSocketException::AsyncSocketExceptionType::NETWORK_ERROR);
EXPECT_EQ(exception.errno_, 22);
// Suppress the assert in ~WriteCallbackBase()
state = STATE_SUCCEEDED;
}
};
#ifdef MSG_ERRQUEUE
/* copied from include/uapi/linux/net_tstamp.h */
/* SO_TIMESTAMPING gets an integer bit field comprised of these values */
enum SOF_TIMESTAMPING {
SOF_TIMESTAMPING_TX_SOFTWARE = (1 << 1),
SOF_TIMESTAMPING_SOFTWARE = (1 << 4),
SOF_TIMESTAMPING_OPT_ID = (1 << 7),
SOF_TIMESTAMPING_TX_SCHED = (1 << 8),
SOF_TIMESTAMPING_TX_ACK = (1 << 9),
SOF_TIMESTAMPING_OPT_TSONLY = (1 << 11),
};
class WriteCheckTimestampCallback :
public WriteCallbackBase {
public:
explicit WriteCheckTimestampCallback(SendMsgParamsCallbackBase* mcb = nullptr)
: WriteCallbackBase(mcb) {}
~WriteCheckTimestampCallback() {
EXPECT_EQ(STATE_SUCCEEDED, state);
EXPECT_TRUE(gotTimestamp_);
EXPECT_TRUE(gotByteSeq_);
}
void setSocket(
const std::shared_ptr<AsyncSSLSocket> &socket) override {
WriteCallbackBase::setSocket(socket);
EXPECT_NE(socket_->getFd(), 0);
int flags = SOF_TIMESTAMPING_OPT_ID
| SOF_TIMESTAMPING_OPT_TSONLY
| SOF_TIMESTAMPING_SOFTWARE;
AsyncSocket::OptionKey tstampingOpt = {SOL_SOCKET, SO_TIMESTAMPING};
int ret = tstampingOpt.apply(socket_->getFd(), flags);
EXPECT_EQ(ret, 0);
}
void checkForTimestampNotifications() noexcept {
int fd = socket_->getFd();
std::vector<char> ctrl(1024, 0);
unsigned char data;
struct msghdr msg;
iovec entry;
memset(&msg, 0, sizeof(msg));
entry.iov_base = &data;
entry.iov_len = sizeof(data);
msg.msg_iov = &entry;
msg.msg_iovlen = 1;
msg.msg_control = ctrl.data();
msg.msg_controllen = ctrl.size();
int ret;
while (true) {
ret = recvmsg(fd, &msg, MSG_ERRQUEUE);
if (ret < 0) {
if (errno != EAGAIN) {
auto errnoCopy = errno;
std::cerr << "::recvmsg exited with code " << ret
<< ", errno: " << errnoCopy << std::endl;
AsyncSocketException ex(
AsyncSocketException::INTERNAL_ERROR,
"recvmsg() failed",
errnoCopy);
exception = ex;
}
return;
}
for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
cmsg != nullptr && cmsg->cmsg_len != 0;
cmsg = CMSG_NXTHDR(&msg, cmsg)) {
if (cmsg->cmsg_level == SOL_SOCKET &&
cmsg->cmsg_type == SCM_TIMESTAMPING) {
gotTimestamp_ = true;
continue;
}
if ((cmsg->cmsg_level == SOL_IP && cmsg->cmsg_type == IP_RECVERR) ||
(cmsg->cmsg_level == SOL_IPV6 && cmsg->cmsg_type == IPV6_RECVERR)) {
gotByteSeq_ = true;
continue;
}
}
}
}
bool gotTimestamp_{false};
bool gotByteSeq_{false};
};
#endif // MSG_ERRQUEUE
class ReadCallbackBase :
public AsyncTransportWrapper::ReadCallback {
......
......@@ -229,6 +229,64 @@ class TestErrMessageCallback : public folly::AsyncSocket::ErrMessageCallback {
bool gotByteSeq_{false};
};
class TestSendMsgParamsCallback :
public folly::AsyncSocket::SendMsgParamsCallback {
public:
TestSendMsgParamsCallback(int flags, uint32_t dataSize, void* data)
: flags_(flags),
writeFlags_(folly::WriteFlags::NONE),
dataSize_(dataSize),
data_(data),
queriedFlags_(false),
queriedData_(false)
{}
void reset(int flags) {
flags_ = flags;
writeFlags_ = folly::WriteFlags::NONE;
queriedFlags_ = false;
queriedData_ = false;
}
int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
override {
queriedFlags_ = true;
if (writeFlags_ == folly::WriteFlags::NONE) {
writeFlags_ = flags;
} else {
assert(flags == writeFlags_);
}
return flags_;
}
void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
queriedData_ = true;
if (writeFlags_ == folly::WriteFlags::NONE) {
writeFlags_ = flags;
} else {
assert(flags == writeFlags_);
}
assert(data != nullptr);
memcpy(data, data_, dataSize_);
}
uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
if (writeFlags_ == folly::WriteFlags::NONE) {
writeFlags_ = flags;
} else {
assert(flags == writeFlags_);
}
return dataSize_;
}
int flags_;
folly::WriteFlags writeFlags_;
uint32_t dataSize_;
void* data_;
bool queriedFlags_;
bool queriedData_;
};
class TestServer {
public:
// Create a TestServer.
......
......@@ -2823,21 +2823,11 @@ TEST(AsyncSocketTest, EvbCallbacks) {
/* copied from include/uapi/linux/net_tstamp.h */
/* SO_TIMESTAMPING gets an integer bit field comprised of these values */
enum SOF_TIMESTAMPING {
// SOF_TIMESTAMPING_TX_HARDWARE = (1 << 0),
// SOF_TIMESTAMPING_TX_SOFTWARE = (1 << 1),
// SOF_TIMESTAMPING_RX_HARDWARE = (1 << 2),
// SOF_TIMESTAMPING_RX_SOFTWARE = (1 << 3),
SOF_TIMESTAMPING_SOFTWARE = (1 << 4),
// SOF_TIMESTAMPING_SYS_HARDWARE = (1 << 5),
// SOF_TIMESTAMPING_RAW_HARDWARE = (1 << 6),
SOF_TIMESTAMPING_OPT_ID = (1 << 7),
SOF_TIMESTAMPING_TX_SCHED = (1 << 8),
// SOF_TIMESTAMPING_TX_ACK = (1 << 9),
SOF_TIMESTAMPING_OPT_CMSG = (1 << 10),
SOF_TIMESTAMPING_OPT_TSONLY = (1 << 11),
// SOF_TIMESTAMPING_LAST = SOF_TIMESTAMPING_OPT_TSONLY,
// SOF_TIMESTAMPING_MASK = (SOF_TIMESTAMPING_LAST - 1) | SOF_TIMESTAMPING_LAST,
};
TEST(AsyncSocketTest, ErrMessageCallback) {
TestServer server;
......@@ -3039,3 +3029,167 @@ TEST(AsyncSocket, PreReceivedDataTakeover) {
evb.loop();
}
TEST(AsyncSocketTest, SendMessageFlags) {
TestServer server;
TestSendMsgParamsCallback sendMsgCB(
MSG_DONTWAIT|MSG_NOSIGNAL|MSG_MORE, 0, nullptr);
// connect()
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
ConnCallback ccb;
socket->connect(&ccb, server.getAddress(), 30);
std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
evb.loop();
ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
// Set SendMsgParamsCallback
socket->setSendMsgParamCB(&sendMsgCB);
ASSERT_EQ(socket->getSendMsgParamsCB(), &sendMsgCB);
// Write the first portion of data. This data is expected to be
// sent out immediately.
std::vector<uint8_t> buf(128, 'a');
WriteCallback wcb;
sendMsgCB.reset(MSG_DONTWAIT | MSG_NOSIGNAL);
socket->write(&wcb, buf.data(), buf.size());
ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
ASSERT_TRUE(sendMsgCB.queriedFlags_);
ASSERT_FALSE(sendMsgCB.queriedData_);
// Using different flags for the second write operation.
// MSG_MORE flag is expected to delay sending this
// data to the wire.
sendMsgCB.reset(MSG_DONTWAIT | MSG_NOSIGNAL | MSG_MORE);
socket->write(&wcb, buf.data(), buf.size());
ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
ASSERT_TRUE(sendMsgCB.queriedFlags_);
ASSERT_FALSE(sendMsgCB.queriedData_);
// Make sure the accepted socket saw only the data from
// the first write request.
std::vector<uint8_t> readbuf(2 * buf.size());
uint32_t bytesRead = acceptedSocket->read(readbuf.data(), readbuf.size());
ASSERT_TRUE(std::equal(buf.begin(), buf.end(), readbuf.begin()));
ASSERT_EQ(bytesRead, buf.size());
// Make sure the server got a connection and received the data
acceptedSocket->close();
socket->close();
ASSERT_TRUE(socket->isClosedBySelf());
ASSERT_FALSE(socket->isClosedByPeer());
}
TEST(AsyncSocketTest, SendMessageAncillaryData) {
struct sockaddr_un addr = {AF_UNIX,
"AsyncSocketTest.SendMessageAncillaryData\0"};
// Clean up the name in the name space we're going to use
ASSERT_FALSE(remove(addr.sun_path) == -1 && errno != ENOENT);
// Set up listening socket
int lfd = fsp::socket(AF_UNIX, SOCK_STREAM, 0);
ASSERT_NE(lfd, -1);
ASSERT_NE(bind(lfd, (struct sockaddr*)&addr, sizeof(addr)), -1)
<< "Bind failed: " << errno;
// Create the connecting socket
int csd = fsp::socket(AF_UNIX, SOCK_STREAM, 0);
ASSERT_NE(csd, -1);
// Listen for incoming connect
ASSERT_NE(listen(lfd, 5), -1);
// Connect to the listening socket
ASSERT_NE(fsp::connect(csd, (struct sockaddr*)&addr, sizeof(addr)), -1)
<< "Connect request failed: " << errno;
// Accept the connection
int sfd = accept(lfd, nullptr, nullptr);
ASSERT_NE(sfd, -1);
// Instantiate AsyncSocket object for the connected socket
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb, csd);
// Open a temporary file and write a magic string to it
// We'll transfer the file handle to test the message parameters
// callback logic.
int tmpfd = open("/var/tmp", O_RDWR | O_TMPFILE);
ASSERT_NE(tmpfd, -1) << "Failed to open a temporary file";
std::string magicString("Magic string");
ASSERT_EQ(write(tmpfd, magicString.c_str(), magicString.length()),
magicString.length());
// Send message
union {
// Space large enough to hold an 'int'
char control[CMSG_SPACE(sizeof(int))];
struct cmsghdr cmh;
} s_u;
s_u.cmh.cmsg_len = CMSG_LEN(sizeof(int));
s_u.cmh.cmsg_level = SOL_SOCKET;
s_u.cmh.cmsg_type = SCM_RIGHTS;
memcpy(CMSG_DATA(&s_u.cmh), &tmpfd, sizeof(int));
// Set up the callback providing message parameters
TestSendMsgParamsCallback sendMsgCB(
MSG_DONTWAIT | MSG_NOSIGNAL, sizeof(s_u.control), s_u.control);
socket->setSendMsgParamCB(&sendMsgCB);
// We must transmit at least 1 byte of real data in order
// to send ancillary data
int s_data = 12345;
WriteCallback wcb;
socket->write(&wcb, &s_data, sizeof(s_data));
ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
// Receive the message
union {
// Space large enough to hold an 'int'
char control[CMSG_SPACE(sizeof(int))];
struct cmsghdr cmh;
} r_u;
struct msghdr msgh;
struct iovec iov;
int r_data = 0;
msgh.msg_control = r_u.control;
msgh.msg_controllen = sizeof(r_u.control);
msgh.msg_name = nullptr;
msgh.msg_namelen = 0;
msgh.msg_iov = &iov;
msgh.msg_iovlen = 1;
iov.iov_base = &r_data;
iov.iov_len = sizeof(r_data);
// Receive data
ASSERT_NE(recvmsg(sfd, &msgh, 0), -1) << "recvmsg failed: " << errno;
// Validate the received message
ASSERT_EQ(r_u.cmh.cmsg_len, CMSG_LEN(sizeof(int)));
ASSERT_EQ(r_u.cmh.cmsg_level, SOL_SOCKET);
ASSERT_EQ(r_u.cmh.cmsg_type, SCM_RIGHTS);
ASSERT_EQ(r_data, s_data);
int fd = 0;
memcpy(&fd, CMSG_DATA(&r_u.cmh), sizeof(int));
ASSERT_NE(fd, 0);
std::vector<uint8_t> transferredMagicString(magicString.length() + 1, 0);
// Reposition to the beginning of the file
ASSERT_EQ(0, lseek(fd, 0, SEEK_SET));
// Read the magic string back, and compare it with the original
ASSERT_EQ(
magicString.length(),
read(fd, transferredMagicString.data(), transferredMagicString.size()));
ASSERT_TRUE(std::equal(
magicString.begin(),
magicString.end(),
transferredMagicString.begin()));
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment