Commit e31eb32a authored by Junqi Wang's avatar Junqi Wang Committed by Facebook GitHub Bot

Allow connect before bind

Summary: connect will automatically bind the socket if the socket is not bound yet

Reviewed By: yangchi

Differential Revision: D21845740

fbshipit-source-id: 27a5b44476dfc0b2ae5ff2f0a6c1bd4e976eadc9
parent d11cbbc9
......@@ -415,7 +415,7 @@ void testAsyncUDPRecvmsg(bool useRegisteredFds) {
serverSocketVec.emplace_back(std::move(serverSock));
// connect the client
CHECK_EQ(clientSock->connect(addr), 0);
clientSock->connect(addr);
for (size_t j = 0; j < kNumPackets; j++) {
auto buf = folly::IOBuf::copyBuffer(data.c_str(), data.size());
CHECK_EQ(clientSock->write(addr, std::move(buf)), data.size());
......
......@@ -58,11 +58,9 @@ AsyncUDPSocket::~AsyncUDPSocket() {
}
}
void AsyncUDPSocket::bind(const folly::SocketAddress& address) {
NetworkSocket socket = netops::socket(
address.getFamily(),
SOCK_DGRAM,
address.getFamily() != AF_UNIX ? IPPROTO_UDP : 0);
void AsyncUDPSocket::init(sa_family_t family) {
NetworkSocket socket =
netops::socket(family, SOCK_DGRAM, family != AF_UNIX ? IPPROTO_UDP : 0);
if (socket == NetworkSocket()) {
throw AsyncSocketException(
AsyncSocketException::NOT_OPEN,
......@@ -148,7 +146,7 @@ void AsyncUDPSocket::bind(const folly::SocketAddress& address) {
}
// If we're using IPv6, make sure we don't accept V4-mapped connections
if (address.getFamily() == AF_INET6) {
if (family == AF_INET6) {
int flag = 1;
if (netops::setsockopt(
socket, IPPROTO_IPV6, IPV6_V6ONLY, &flag, sizeof(flag))) {
......@@ -157,25 +155,29 @@ void AsyncUDPSocket::bind(const folly::SocketAddress& address) {
}
}
// success
g.dismiss();
fd_ = socket;
ownership_ = FDOwnership::OWNS;
// attach to EventHandler
EventHandler::changeHandlerFD(fd_);
}
void AsyncUDPSocket::bind(const folly::SocketAddress& address) {
init(address.getFamily());
// bind to the address
sockaddr_storage addrStorage;
address.getAddress(&addrStorage);
auto& saddr = reinterpret_cast<sockaddr&>(addrStorage);
if (netops::bind(socket, &saddr, address.getActualSize()) != 0) {
if (netops::bind(fd_, &saddr, address.getActualSize()) != 0) {
throw AsyncSocketException(
AsyncSocketException::NOT_OPEN,
"failed to bind the async udp socket for:" + address.describe(),
errno);
}
// success
g.dismiss();
fd_ = socket;
ownership_ = FDOwnership::OWNS;
// attach to EventHandler
EventHandler::changeHandlerFD(fd_);
if (address.getFamily() == AF_UNIX || address.getPort() != 0) {
localAddress_ = address;
} else {
......@@ -183,17 +185,29 @@ void AsyncUDPSocket::bind(const folly::SocketAddress& address) {
}
}
int AsyncUDPSocket::connect(const folly::SocketAddress& address) {
CHECK_NE(NetworkSocket(), fd_) << "Socket not yet bound";
void AsyncUDPSocket::connect(const folly::SocketAddress& address) {
// not bound yet
if (fd_ == NetworkSocket()) {
init(address.getFamily());
}
sockaddr_storage addrStorage;
address.getAddress(&addrStorage);
int ret = netops::connect(
fd_, reinterpret_cast<sockaddr*>(&addrStorage), address.getActualSize());
if (ret == 0) {
connected_ = true;
connectedAddress_ = address;
if (netops::connect(
fd_,
reinterpret_cast<sockaddr*>(&addrStorage),
address.getActualSize()) != 0) {
throw AsyncSocketException(
AsyncSocketException::NOT_OPEN,
"Failed to connect the udp socket to:" + address.describe(),
errno);
}
connected_ = true;
connectedAddress_ = address;
if (!localAddress_.isInitialized()) {
localAddress_.setFromLocalAddress(fd_);
}
return ret;
}
void AsyncUDPSocket::dontFragment(bool df) {
......
......@@ -152,7 +152,8 @@ class AsyncUDPSocket : public EventHandler {
* state on connects.
* Using connect has many quirks, and you should be aware of them before using
* this API:
* 1. This must only be called after binding the socket.
* 1. If this is called before bind, the socket will be automatically bound to
* the IP address of the current default network interface.
* 2. Normally UDP can use the 2 tuple (src ip, src port) to steer packets
* sent by the peer to the socket, however after connecting the socket, only
* packets destined to the destination address specified in connect() will be
......@@ -164,7 +165,7 @@ class AsyncUDPSocket : public EventHandler {
*
* Returns the result of calling the connect syscall.
*/
virtual int connect(const folly::SocketAddress& address);
virtual void connect(const folly::SocketAddress& address);
/**
* Use an already bound file descriptor. You can either transfer ownership
......@@ -440,6 +441,8 @@ class AsyncUDPSocket : public EventHandler {
AsyncUDPSocket(const AsyncUDPSocket&) = delete;
AsyncUDPSocket& operator=(const AsyncUDPSocket&) = delete;
void init(sa_family_t family);
// EventHandler
void handlerReady(uint16_t events) noexcept override;
......
......@@ -303,11 +303,7 @@ class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout {
}
void connect() {
int ret = socket_->connect(*connectAddr_);
if (ret != 0) {
throw folly::AsyncSocketException(
folly::AsyncSocketException::NOT_OPEN, "ConnectFail", errno);
}
socket_->connect(*connectAddr_);
VLOG(2) << "Client connected to address=" << *connectAddr_;
}
......
......@@ -212,11 +212,7 @@ class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout {
}
void connect() {
int ret = socket_->connect(*connectAddr_);
if (ret != 0) {
throw folly::AsyncSocketException(
folly::AsyncSocketException::NOT_OPEN, "ConnectFail", errno);
}
socket_->connect(*connectAddr_);
VLOG(2) << "Client connected to address=" << *connectAddr_;
}
......
......@@ -173,6 +173,8 @@ class UDPServer {
bool changePortForWrites_{true};
};
enum class BindSocket { YES, NO };
class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout {
public:
using AsyncUDPSocket::ReadCallback::OnDataAvailableParams;
......@@ -188,9 +190,12 @@ class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout {
socket_ = std::make_unique<AsyncUDPSocket>(evb_);
try {
socket_->bind(folly::SocketAddress("127.0.0.1", 0));
if (bindSocket_ == BindSocket::YES) {
socket_->bind(folly::SocketAddress("127.0.0.1", 0));
}
if (connectAddr_) {
connect();
socket_->connect(*connectAddr_);
VLOG(2) << "Client connected to address=" << *connectAddr_;
}
VLOG(2) << "Client bound to " << socket_->address().describe();
} catch (const std::exception& ex) {
......@@ -209,15 +214,6 @@ class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout {
}
}
void connect() {
int ret = socket_->connect(*connectAddr_);
if (ret != 0) {
throw folly::AsyncSocketException(
folly::AsyncSocketException::NOT_OPEN, "ConnectFail", errno);
}
VLOG(2) << "Client connected to address=" << *connectAddr_;
}
void shutdown() {
CHECK(evb_->isInEventBaseThread());
socket_->pauseRead();
......@@ -295,8 +291,11 @@ class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout {
return *socket_;
}
void setShouldConnect(const folly::SocketAddress& connectAddr) {
void setShouldConnect(
const folly::SocketAddress& connectAddr,
BindSocket bindSocket) {
connectAddr_ = connectAddr;
bindSocket_ = bindSocket;
}
bool error() const {
......@@ -309,6 +308,7 @@ class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout {
protected:
folly::Optional<folly::SocketAddress> connectAddr_;
BindSocket bindSocket_{BindSocket::YES};
EventBase* const evb_{nullptr};
folly::SocketAddress server_;
......@@ -473,16 +473,19 @@ class AsyncSocketIntegrationTest : public Test {
std::unique_ptr<UDPClient> performPingPongTest(
folly::SocketAddress writeAddress,
folly::Optional<folly::SocketAddress> connectedAddress);
folly::Optional<folly::SocketAddress> connectedAddress,
BindSocket bindSocket = BindSocket::YES);
std::unique_ptr<UDPNotifyClient> performPingPongNotifyTest(
folly::SocketAddress writeAddress,
folly::Optional<folly::SocketAddress> connectedAddress);
folly::Optional<folly::SocketAddress> connectedAddress,
BindSocket bindSocket = BindSocket::YES);
std::unique_ptr<UDPNotifyClient> performPingPongNotifyMmsgTest(
folly::SocketAddress writeAddress,
unsigned int numMsgs,
folly::Optional<folly::SocketAddress> connectedAddress);
folly::Optional<folly::SocketAddress> connectedAddress,
BindSocket bindSocket = BindSocket::YES);
folly::EventBase sevb;
folly::EventBase cevb;
......@@ -492,10 +495,11 @@ class AsyncSocketIntegrationTest : public Test {
std::unique_ptr<UDPClient> AsyncSocketIntegrationTest::performPingPongTest(
folly::SocketAddress writeAddress,
folly::Optional<folly::SocketAddress> connectedAddress) {
folly::Optional<folly::SocketAddress> connectedAddress,
BindSocket bindSocket) {
auto client = std::make_unique<UDPClient>(&cevb);
if (connectedAddress) {
client->setShouldConnect(*connectedAddress);
client->setShouldConnect(*connectedAddress, bindSocket);
}
// Start event loop in a separate thread
auto clientThread = std::thread([this]() { cevb.loopForever(); });
......@@ -514,10 +518,11 @@ std::unique_ptr<UDPClient> AsyncSocketIntegrationTest::performPingPongTest(
std::unique_ptr<UDPNotifyClient>
AsyncSocketIntegrationTest::performPingPongNotifyTest(
folly::SocketAddress writeAddress,
folly::Optional<folly::SocketAddress> connectedAddress) {
folly::Optional<folly::SocketAddress> connectedAddress,
BindSocket bindSocket) {
auto client = std::make_unique<UDPNotifyClient>(&cevb);
if (connectedAddress) {
client->setShouldConnect(*connectedAddress);
client->setShouldConnect(*connectedAddress, bindSocket);
}
// Start event loop in a separate thread
auto clientThread = std::thread([this]() { cevb.loopForever(); });
......@@ -537,10 +542,11 @@ std::unique_ptr<UDPNotifyClient>
AsyncSocketIntegrationTest::performPingPongNotifyMmsgTest(
folly::SocketAddress writeAddress,
unsigned int numMsgs,
folly::Optional<folly::SocketAddress> connectedAddress) {
folly::Optional<folly::SocketAddress> connectedAddress,
BindSocket bindSocket) {
auto client = std::make_unique<UDPNotifyClient>(&cevb, true, numMsgs);
if (connectedAddress) {
client->setShouldConnect(*connectedAddress);
client->setShouldConnect(*connectedAddress, bindSocket);
}
// Start event loop in a separate thread
auto clientThread = std::thread([this]() { cevb.loopForever(); });
......@@ -581,44 +587,63 @@ TEST_F(AsyncSocketIntegrationTest, PingPongNotifyMmsg) {
ASSERT_TRUE(pingClient->notifyInvoked);
}
TEST_F(AsyncSocketIntegrationTest, ConnectedPingPong) {
class ConnectedAsyncSocketIntegrationTest
: public AsyncSocketIntegrationTest,
public WithParamInterface<BindSocket> {};
TEST_P(ConnectedAsyncSocketIntegrationTest, ConnectedPingPong) {
server->setChangePortForWrites(false);
startServer();
auto pingClient = performPingPongTest(server->address(), server->address());
auto pingClient =
performPingPongTest(server->address(), server->address(), GetParam());
// This should succeed
ASSERT_GT(pingClient->pongRecvd(), 0);
}
TEST_F(AsyncSocketIntegrationTest, ConnectedPingPongServerWrongAddress) {
TEST_P(
ConnectedAsyncSocketIntegrationTest,
ConnectedPingPongServerWrongAddress) {
server->setChangePortForWrites(true);
startServer();
auto pingClient = performPingPongTest(server->address(), server->address());
auto pingClient =
performPingPongTest(server->address(), server->address(), GetParam());
// This should fail.
ASSERT_EQ(pingClient->pongRecvd(), 0);
}
TEST_F(AsyncSocketIntegrationTest, ConnectedPingPongClientWrongAddress) {
TEST_P(
ConnectedAsyncSocketIntegrationTest,
ConnectedPingPongClientWrongAddress) {
server->setChangePortForWrites(false);
startServer();
folly::SocketAddress connectAddr(
server->address().getIPAddress(), server->address().getPort() + 1);
auto pingClient = performPingPongTest(server->address(), connectAddr);
auto pingClient =
performPingPongTest(server->address(), connectAddr, GetParam());
// This should fail.
ASSERT_EQ(pingClient->pongRecvd(), 0);
EXPECT_TRUE(pingClient->error());
}
TEST_F(AsyncSocketIntegrationTest, ConnectedPingPongDifferentWriteAddress) {
TEST_P(
ConnectedAsyncSocketIntegrationTest,
ConnectedPingPongDifferentWriteAddress) {
server->setChangePortForWrites(false);
startServer();
folly::SocketAddress connectAddr(
server->address().getIPAddress(), server->address().getPort() + 1);
auto pingClient = performPingPongTest(connectAddr, server->address());
auto pingClient =
performPingPongTest(connectAddr, server->address(), GetParam());
// This should fail.
ASSERT_EQ(pingClient->pongRecvd(), 0);
EXPECT_TRUE(pingClient->error());
}
INSTANTIATE_TEST_CASE_P(
ConnectedAsyncSocketIntegrationTests,
ConnectedAsyncSocketIntegrationTest,
Values(BindSocket::YES, BindSocket::NO));
TEST_F(AsyncSocketIntegrationTest, PingPongPauseResumeListening) {
startServer();
......@@ -703,8 +728,20 @@ class AsyncUDPSocketTest : public Test {
folly::SocketAddress addr_;
};
TEST_F(AsyncUDPSocketTest, TestConnectAfterBind) {
socket_->connect(addr_);
}
TEST_F(AsyncUDPSocketTest, TestConnect) {
EXPECT_EQ(socket_->connect(addr_), 0);
AsyncUDPSocket socket(&evb_);
EXPECT_FALSE(socket.isBound());
folly::SocketAddress address("127.0.0.1", 443);
socket.connect(address);
EXPECT_TRUE(socket.isBound());
const auto& localAddr = socket.address();
EXPECT_TRUE(localAddr.isInitialized());
EXPECT_GT(localAddr.getPort(), 0);
}
TEST_F(AsyncUDPSocketTest, TestErrToNonExistentServer) {
......
......@@ -50,7 +50,7 @@ struct MockAsyncUDPSocket : public AsyncUDPSocket {
MOCK_METHOD1(setReuseAddr, void(bool));
MOCK_METHOD1(dontFragment, void(bool));
MOCK_METHOD1(setErrMessageCallback, void(ErrMessageCallback*));
MOCK_METHOD1(connect, int(const SocketAddress&));
MOCK_METHOD1(connect, void(const SocketAddress&));
MOCK_CONST_METHOD0(isBound, bool());
MOCK_METHOD0(getGSO, int());
MOCK_METHOD1(setGSO, bool(int));
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment