Unverified Commit 3a9c992b authored by Ali Güngör's avatar Ali Güngör Committed by GitHub

Merge pull request #444 from louisroyer/ipv6-ran

Allow RLS to work over IPv6
parents 59c13021 801d2fb3
...@@ -48,7 +48,7 @@ static nr::gnb::GnbConfig *ReadConfigYaml() ...@@ -48,7 +48,7 @@ static nr::gnb::GnbConfig *ReadConfigYaml()
result->gnbIdLength = yaml::GetInt32(config, "idLength", 22, 32); result->gnbIdLength = yaml::GetInt32(config, "idLength", 22, 32);
result->tac = yaml::GetInt32(config, "tac", 0, 0xFFFFFF); result->tac = yaml::GetInt32(config, "tac", 0, 0xFFFFFF);
result->portalIp = yaml::GetIp4(config, "linkIp"); result->portalIp = yaml::GetIp(config, "linkIp");
result->ngapIp = yaml::GetIp4(config, "ngapIp"); result->ngapIp = yaml::GetIp4(config, "ngapIp");
result->gtpIp = yaml::GetIp4(config, "gtpIp"); result->gtpIp = yaml::GetIp4(config, "gtpIp");
......
...@@ -9,31 +9,49 @@ ...@@ -9,31 +9,49 @@
#include "server.hpp" #include "server.hpp"
#include <cstring> #include <cstring>
#include <utils/common.hpp>
namespace udp namespace udp
{ {
UdpServer::UdpServer() : socket{Socket::CreateUdp4()} UdpServer::UdpServer(): sockets{}
{ {
sockets.push_back(Socket::CreateUdp6());
sockets.push_back(Socket::CreateUdp4());
} }
UdpServer::UdpServer(const std::string &address, uint16_t port) : socket{Socket::CreateAndBindUdp({address, port})} UdpServer::UdpServer(const std::string &address, uint16_t port): sockets{}
{ {
sockets.push_back(Socket::CreateAndBindUdp({address, port}));
} }
int UdpServer::Receive(uint8_t *buffer, size_t bufferSize, int timeoutMs, InetAddress &outPeerAddress) const int UdpServer::Receive(uint8_t *buffer, size_t bufferSize, int timeoutMs, InetAddress &outPeerAddress)
{ {
return socket.receive(buffer, bufferSize, timeoutMs, outPeerAddress); // Choose at random a ready socket for receiving data
std::vector<Socket> ws;
return Socket::Select(sockets, ws, timeoutMs).receive(buffer, bufferSize, 0, outPeerAddress);
} }
void UdpServer::Send(const InetAddress &address, const uint8_t *buffer, size_t bufferSize) const int UdpServer::Send(const InetAddress &address, const uint8_t *buffer, size_t bufferSize) const
{ {
socket.send(address, buffer, bufferSize); int version = address.getIpVersion();
// invalid family
if (!version)
return -1;
// send on first socket matching ip version
for(const Socket &s : sockets)
{
if (s.getIpVersion() == version)
return s.send(address, buffer, bufferSize);
}
// no socket found
return -1;
} }
UdpServer::~UdpServer() UdpServer::~UdpServer()
{ {
socket.close(); for (Socket &s : sockets)
s.close();
} }
} // namespace udp } // namespace udp
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#pragma once #pragma once
#include <string> #include <string>
#include <unistd.h>
#include <utils/network.hpp> #include <utils/network.hpp>
...@@ -18,15 +19,15 @@ namespace udp ...@@ -18,15 +19,15 @@ namespace udp
class UdpServer class UdpServer
{ {
private: private:
Socket socket; std::vector<Socket> sockets;
public: public:
UdpServer(); UdpServer();
UdpServer(const std::string &address, uint16_t port); UdpServer(const std::string &address, uint16_t port);
~UdpServer(); ~UdpServer();
int Receive(uint8_t *buffer, size_t bufferSize, int timeoutMs, InetAddress &outPeerAddress) const; int Receive(uint8_t *buffer, size_t bufferSize, int timeoutMs, InetAddress &outPeerAddress);
void Send(const InetAddress &address, const uint8_t *buffer, size_t bufferSize) const; int Send(const InetAddress &address, const uint8_t *buffer, size_t bufferSize) const;
}; };
} // namespace udp } // namespace udp
...@@ -30,7 +30,7 @@ static_assert(sizeof(long long) == sizeof(uint64_t)); ...@@ -30,7 +30,7 @@ static_assert(sizeof(long long) == sizeof(uint64_t));
static std::atomic<int> g_idCounter = 1; static std::atomic<int> g_idCounter = 1;
static bool IPv6FromString(const char *szAddress, uint8_t *address) static bool IPv6FromString(const char *szAddress, std::vector<uint8_t>& address)
{ {
auto asciiToHex = [](char c) -> int { auto asciiToHex = [](char c) -> int {
c |= 0x20; c |= 0x20;
...@@ -46,7 +46,9 @@ static bool IPv6FromString(const char *szAddress, uint8_t *address) ...@@ -46,7 +46,9 @@ static bool IPv6FromString(const char *szAddress, uint8_t *address)
uint8_t colons = 0; uint8_t colons = 0;
uint8_t pos = 0; uint8_t pos = 0;
memset(address, 0, 16); address.clear();
std::vector<uint8_t> emptyAddress{16};
address.insert(address.begin(), emptyAddress.begin(), emptyAddress.end());
for (uint8_t i = 1; i <= 39; i++) for (uint8_t i = 1; i <= 39; i++)
{ {
...@@ -60,12 +62,12 @@ static bool IPv6FromString(const char *szAddress, uint8_t *address) ...@@ -60,12 +62,12 @@ static bool IPv6FromString(const char *szAddress, uint8_t *address)
else if (szAddress[i] == '\0') else if (szAddress[i] == '\0')
break; break;
} }
for (uint8_t i = 0; i <= 39 && pos < 16; i++) for (uint8_t i = 0; i <= 39 && pos < address.size(); i++)
{ {
if (szAddress[i] == ':' || szAddress[i] == '\0') if (szAddress[i] == ':' || szAddress[i] == '\0')
{ {
address[pos] = acc >> 8; address.at(pos) = acc >> 8;
address[pos + 1] = acc; address.at(pos + 1) = acc;
acc = 0; acc = 0;
if (colons && i && szAddress[i - 1] == ':') if (colons && i && szAddress[i - 1] == ':')
...@@ -197,7 +199,7 @@ OctetString utils::IpToOctetString(const std::string &address) ...@@ -197,7 +199,7 @@ OctetString utils::IpToOctetString(const std::string &address)
else if (ipVersion == 6) else if (ipVersion == 6)
{ {
std::vector<uint8_t> data{16}; std::vector<uint8_t> data{16};
if (!IPv6FromString(address.c_str(), data.data())) if (!IPv6FromString(address.c_str(), data))
return {}; return {};
return OctetString(std::move(data)); return OctetString(std::move(data));
} }
......
...@@ -219,6 +219,36 @@ std::string GetIp4OfInterface(const std::string &ifName) ...@@ -219,6 +219,36 @@ std::string GetIp4OfInterface(const std::string &ifName)
return std::string{str}; return std::string{str};
} }
std::string GetIp6OfInterface(const std::string &ifName)
{
std::string res;
struct ifreq ifr = {};
int fd = socket(AF_INET6, SOCK_DGRAM, 0);
if (fd <= 0)
return "";
ifr.ifr_addr.sa_family = AF_INET6;
strncpy(ifr.ifr_name, ifName.c_str(), IFNAMSIZ - 1);
if (ioctl(fd, SIOCGIFADDR, &ifr))
{
close(fd);
return "";
}
close(fd);
auto address = ((struct sockaddr_in *)&ifr.ifr_addr)->sin_addr;
char str[INET6_ADDRSTRLEN] = {0};
if (inet_ntop(AF_INET6, &address, str, INET6_ADDRSTRLEN) == nullptr)
return "";
return std::string{str};
}
std::string GetHostByName(const std::string &name) std::string GetHostByName(const std::string &name)
{ {
struct addrinfo hints = {}; struct addrinfo hints = {};
......
...@@ -42,6 +42,8 @@ void AppendPath(std::string &source, const std::string &target); ...@@ -42,6 +42,8 @@ void AppendPath(std::string &source, const std::string &target);
std::string GetIp4OfInterface(const std::string &ifName); std::string GetIp4OfInterface(const std::string &ifName);
std::string GetIp6OfInterface(const std::string &ifName);
std::string GetHostByName(const std::string& name); std::string GetHostByName(const std::string& name);
} // namespace io } // namespace io
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include <arpa/inet.h> #include <arpa/inet.h>
#include <netdb.h> #include <netdb.h>
#include <random>
#include <stdexcept> #include <stdexcept>
#include <sys/socket.h> #include <sys/socket.h>
#include <sys/types.h> #include <sys/types.h>
...@@ -78,7 +79,7 @@ InetAddress::InetAddress(const std::string &address, uint16_t port) : storage{}, ...@@ -78,7 +79,7 @@ InetAddress::InetAddress(const std::string &address, uint16_t port) : storage{},
if (s != 0) if (s != 0)
throw LibError("Bad Inet address: " + address, errno); throw LibError("Bad Inet address: " + address, errno);
if (result->ai_family != AF_INET && result->ai_family == AF_INET6) if (result->ai_family != AF_INET && result->ai_family != AF_INET6)
{ {
freeaddrinfo(result); freeaddrinfo(result);
throw std::runtime_error("Bad Inet address: " + address); throw std::runtime_error("Bad Inet address: " + address);
...@@ -89,6 +90,16 @@ InetAddress::InetAddress(const std::string &address, uint16_t port) : storage{}, ...@@ -89,6 +90,16 @@ InetAddress::InetAddress(const std::string &address, uint16_t port) : storage{},
freeaddrinfo(result); freeaddrinfo(result);
} }
int InetAddress::getIpVersion() const
{
if (storage.ss_family == AF_INET)
return 4;
else if (storage.ss_family == AF_INET6)
return 6;
else
return 0;
}
InetAddress::InetAddress(const OctetString &address, uint16_t port) : InetAddress(OctetStringToIpString(address), port) InetAddress::InetAddress(const OctetString &address, uint16_t port) : InetAddress(OctetStringToIpString(address), port)
{ {
} }
...@@ -111,6 +122,7 @@ uint16_t InetAddress::getPort() const ...@@ -111,6 +122,7 @@ uint16_t InetAddress::getPort() const
Socket::Socket(int domain, int type, int protocol) Socket::Socket(int domain, int type, int protocol)
{ {
int sd = socket(domain, type, protocol); int sd = socket(domain, type, protocol);
socketDomain = domain;
if (sd < 0) if (sd < 0)
throw LibError("Socket could not be created:", errno); throw LibError("Socket could not be created:", errno);
this->fd = sd; this->fd = sd;
...@@ -185,7 +197,7 @@ int Socket::receive(uint8_t *buffer, size_t bufferSize, int timeoutMs, InetAddre ...@@ -185,7 +197,7 @@ int Socket::receive(uint8_t *buffer, size_t bufferSize, int timeoutMs, InetAddre
return 0; return 0;
} }
void Socket::send(const InetAddress &address, const uint8_t *buffer, size_t size) const int Socket::send(const InetAddress &address, const uint8_t *buffer, size_t size) const
{ {
ssize_t rc = sendto(fd, buffer, size, MSG_DONTWAIT, address.getSockAddr(), address.getSockLen()); ssize_t rc = sendto(fd, buffer, size, MSG_DONTWAIT, address.getSockAddr(), address.getSockLen());
if (rc == -1) if (rc == -1)
...@@ -194,6 +206,7 @@ void Socket::send(const InetAddress &address, const uint8_t *buffer, size_t size ...@@ -194,6 +206,7 @@ void Socket::send(const InetAddress &address, const uint8_t *buffer, size_t size
if (err != EAGAIN) if (err != EAGAIN)
throw LibError("sendto failed: ", errno); throw LibError("sendto failed: ", errno);
} }
return rc;
} }
bool Socket::hasFd() const bool Socket::hasFd() const
...@@ -261,10 +274,20 @@ Socket Socket::Select(const std::vector<Socket> &readSockets, const std::vector< ...@@ -261,10 +274,20 @@ Socket Socket::Select(const std::vector<Socket> &readSockets, const std::vector<
std::vector<Socket> rs, ws; std::vector<Socket> rs, ws;
Select(readSockets, writeSockets, rs, ws, timeout); Select(readSockets, writeSockets, rs, ws, timeout);
// Return a socket choosen at random from selection
// to avoid starvation
std::default_random_engine generator;
if (!rs.empty()) if (!rs.empty())
return rs[0]; {
std::uniform_int_distribution<int> drs(0, rs.size()-1);
return rs[drs(generator)];
}
if (!ws.empty()) if (!ws.empty())
return rs[0]; {
std::uniform_int_distribution<int> dws(0, ws.size()-1);
return rs[dws(generator)];
}
return {}; return {};
} }
...@@ -297,3 +320,13 @@ InetAddress Socket::getAddress() const ...@@ -297,3 +320,13 @@ InetAddress Socket::getAddress() const
return {storage, len}; return {storage, len};
} }
int Socket::getIpVersion() const
{
if (socketDomain == AF_INET6)
return 6;
else if (socketDomain == AF_INET)
return 4;
else
return 0;
}
...@@ -38,6 +38,7 @@ struct InetAddress ...@@ -38,6 +38,7 @@ struct InetAddress
return len; return len;
} }
[[nodiscard]] int getIpVersion() const;
[[nodiscard]] uint16_t getPort() const; [[nodiscard]] uint16_t getPort() const;
}; };
...@@ -45,6 +46,7 @@ class Socket ...@@ -45,6 +46,7 @@ class Socket
{ {
private: private:
int fd; int fd;
int socketDomain;
public: public:
Socket(); Socket();
...@@ -53,10 +55,11 @@ class Socket ...@@ -53,10 +55,11 @@ class Socket
public: public:
void bind(const InetAddress &address) const; void bind(const InetAddress &address) const;
int receive(uint8_t *buffer, size_t bufferSize, int timeoutMs, InetAddress &outAddress) const; int receive(uint8_t *buffer, size_t bufferSize, int timeoutMs, InetAddress &outAddress) const;
void send(const InetAddress &address, const uint8_t *buffer, size_t size) const; int send(const InetAddress &address, const uint8_t *buffer, size_t size) const;
void close(); void close();
[[nodiscard]] bool hasFd() const; [[nodiscard]] bool hasFd() const;
[[nodiscard]] InetAddress getAddress() const; [[nodiscard]] InetAddress getAddress() const;
[[nodiscard]] int getIpVersion() const;
/* Socket options */ /* Socket options */
void setReuseAddress() const; void setReuseAddress() const;
......
...@@ -158,6 +158,24 @@ std::string GetIp4(const YAML::Node &node, const std::string &name) ...@@ -158,6 +158,24 @@ std::string GetIp4(const YAML::Node &node, const std::string &name)
return ipFromIf; return ipFromIf;
} }
std::string GetIp(const YAML::Node &node, const std::string & name)
{
std::string s = GetString(node, name);
int version = utils::GetIpVersion(s);
if (version == 6 || version == 4)
return s;
auto ip4FromIf = io::GetIp4OfInterface(s);
if (!ip4FromIf.empty())
return ip4FromIf;
auto ip6FromIf = io::GetIp6OfInterface(s);
if (!ip6FromIf.empty())
return ip6FromIf;
FieldError(name, "must be a valid IP address or a valid network interface with an IP address");
}
void AssertHasBool(const YAML::Node &node, const std::string &name) void AssertHasBool(const YAML::Node &node, const std::string &name)
{ {
AssertHasField(node, name); AssertHasField(node, name);
......
...@@ -42,6 +42,7 @@ std::string GetString(const YAML::Node &node, const std::string &name, std::opti ...@@ -42,6 +42,7 @@ std::string GetString(const YAML::Node &node, const std::string &name, std::opti
std::optional<int> maxLength); std::optional<int> maxLength);
std::string GetIp4(const YAML::Node &node, const std::string &name); std::string GetIp4(const YAML::Node &node, const std::string &name);
std::string GetIp(const YAML::Node &node, const std::string &name);
bool GetBool(const YAML::Node &node, const std::string &name); bool GetBool(const YAML::Node &node, const std::string &name);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment