Commit b310ff2e authored by Dan Melnic's avatar Dan Melnic Committed by Facebook GitHub Bot

Add support for AsyncServerSocket bind to device

Summary:
Add support for AsyncServerSocket bind to device

(Note: this ignores all push blocking failures!)

Reviewed By: danobi

Differential Revision: D27778929

fbshipit-source-id: fa051be2aa5c1b3df5e0c0f30adc3a58de7b7704
parent bd600cd4
...@@ -300,11 +300,29 @@ void AsyncServerSocket::useExistingSocket(NetworkSocket fd) { ...@@ -300,11 +300,29 @@ void AsyncServerSocket::useExistingSocket(NetworkSocket fd) {
} }
void AsyncServerSocket::bindSocket( void AsyncServerSocket::bindSocket(
NetworkSocket fd, const SocketAddress& address, bool isExistingSocket) { NetworkSocket fd,
const SocketAddress& address,
bool isExistingSocket,
const std::string& ifName) {
sockaddr_storage addrStorage; sockaddr_storage addrStorage;
address.getAddress(&addrStorage); address.getAddress(&addrStorage);
auto saddr = reinterpret_cast<sockaddr*>(&addrStorage); auto saddr = reinterpret_cast<sockaddr*>(&addrStorage);
#if defined(__linux__)
if (!ifName.empty() &&
netops::setsockopt(
fd, SOL_SOCKET, SO_BINDTODEVICE, ifName.c_str(), ifName.length())) {
auto errnoCopy = errno;
if (!isExistingSocket) {
closeNoInt(fd);
}
folly::throwSystemErrorExplicit(
errnoCopy, "failed to bind to device: " + ifName);
}
#else
(void)ifName;
#endif
if (netops::bind(fd, saddr, address.getActualSize()) != 0) { if (netops::bind(fd, saddr, address.getActualSize()) != 0) {
if (errno != EINPROGRESS) { if (errno != EINPROGRESS) {
// Get a copy of errno so that it is not overwritten by subsequent calls. // Get a copy of errno so that it is not overwritten by subsequent calls.
...@@ -350,7 +368,8 @@ bool AsyncServerSocket::setZeroCopy(bool enable) { ...@@ -350,7 +368,8 @@ bool AsyncServerSocket::setZeroCopy(bool enable) {
return false; return false;
} }
void AsyncServerSocket::bind(const SocketAddress& address) { void AsyncServerSocket::bindInternal(
const SocketAddress& address, const std::string& ifName) {
if (eventBase_) { if (eventBase_) {
eventBase_->dcheckIsInEventBaseThread(); eventBase_->dcheckIsInEventBaseThread();
} }
...@@ -373,7 +392,16 @@ void AsyncServerSocket::bind(const SocketAddress& address) { ...@@ -373,7 +392,16 @@ void AsyncServerSocket::bind(const SocketAddress& address) {
throw std::invalid_argument("Attempted to bind to multiple fds"); throw std::invalid_argument("Attempted to bind to multiple fds");
} }
bindSocket(fd, address, !sockets_.empty()); bindSocket(fd, address, !sockets_.empty(), ifName);
}
void AsyncServerSocket::bind(const SocketAddress& address) {
bindInternal(address, "");
}
void AsyncServerSocket::bind(
const SocketAddress& address, const std::string& ifName) {
bindInternal(address, ifName);
} }
void AsyncServerSocket::bind( void AsyncServerSocket::bind(
...@@ -389,7 +417,28 @@ void AsyncServerSocket::bind( ...@@ -389,7 +417,28 @@ void AsyncServerSocket::bind(
SocketAddress address(ipAddress.toFullyQualified(), port); SocketAddress address(ipAddress.toFullyQualified(), port);
auto fd = createSocket(address.getFamily()); auto fd = createSocket(address.getFamily());
bindSocket(fd, address, false); bindSocket(fd, address, false, "");
}
if (sockets_.empty()) {
throw std::runtime_error(
"did not bind any async server socket for port and addresses");
}
}
void AsyncServerSocket::bind(
const std::vector<IPAddressIfNamePair>& addresses, uint16_t port) {
if (addresses.empty()) {
throw std::invalid_argument("No ip addresses were provided");
}
if (eventBase_) {
eventBase_->dcheckIsInEventBaseThread();
}
for (const auto& addr : addresses) {
SocketAddress address(addr.first.toFullyQualified(), port);
auto fd = createSocket(address.getFamily());
bindSocket(fd, address, false, addr.second);
} }
if (sockets_.empty()) { if (sockets_.empty()) {
throw std::runtime_error( throw std::runtime_error(
......
...@@ -338,6 +338,8 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase { ...@@ -338,6 +338,8 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase {
*/ */
bool setZeroCopy(bool enable); bool setZeroCopy(bool enable);
using IPAddressIfNamePair = std::pair<IPAddress, std::string>;
/** /**
* Bind to the specified address. * Bind to the specified address.
* *
...@@ -347,6 +349,15 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase { ...@@ -347,6 +349,15 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase {
*/ */
virtual void bind(const SocketAddress& address); virtual void bind(const SocketAddress& address);
/**
* Bind to the specified address/if name
*
* This must be called from the primary EventBase thread.
*
* Throws AsyncSocketException on error.
*/
virtual void bind(const SocketAddress& address, const std::string& ifName);
/** /**
* Bind to the specified port for the specified addresses. * Bind to the specified port for the specified addresses.
* *
...@@ -356,6 +367,16 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase { ...@@ -356,6 +367,16 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase {
*/ */
virtual void bind(const std::vector<IPAddress>& ipAddresses, uint16_t port); virtual void bind(const std::vector<IPAddress>& ipAddresses, uint16_t port);
/**
* Bind to the specified port for the specified addresses/if names.
*
* This must be called from the primary EventBase thread.
*
* Throws AsyncSocketException on error.
*/
virtual void bind(
const std::vector<IPAddressIfNamePair>& addresses, uint16_t port);
/** /**
* Bind to the specified port. * Bind to the specified port.
* *
...@@ -829,8 +850,12 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase { ...@@ -829,8 +850,12 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase {
NetworkSocket createSocket(int family); NetworkSocket createSocket(int family);
void setupSocket(NetworkSocket fd, int family); void setupSocket(NetworkSocket fd, int family);
void bindInternal(const SocketAddress& address, const std::string& ifName);
void bindSocket( void bindSocket(
NetworkSocket fd, const SocketAddress& address, bool isExistingSocket); NetworkSocket fd,
const SocketAddress& address,
bool isExistingSocket,
const std::string& ifName);
void dispatchSocket(NetworkSocket socket, SocketAddress&& address); void dispatchSocket(NetworkSocket socket, SocketAddress&& address);
void dispatchError(const char* msg, int errnoValue); void dispatchError(const char* msg, int errnoValue);
void enterBackoff(); void enterBackoff();
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment