Commit 35594e09 authored by Tatsuhiro Tsujikawa's avatar Tatsuhiro Tsujikawa

Merge branch 'nghttpx-more-block-allocator'

parents 68a6d8c5 96ff3be5
......@@ -29,6 +29,9 @@
#include <string>
#include "template.h"
#include "allocator.h"
namespace nghttp2 {
namespace base64 {
......@@ -87,7 +90,8 @@ InputIt next_decode_input(InputIt first, InputIt last, const int *tbl) {
return first;
}
template <typename InputIt> std::string decode(InputIt first, InputIt last) {
template <typename InputIt, typename OutputIt>
OutputIt decode(InputIt first, InputIt last, OutputIt d_first) {
static constexpr int INDEX_TABLE[] = {
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
......@@ -104,37 +108,29 @@ template <typename InputIt> std::string decode(InputIt first, InputIt last) {
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1};
auto len = last - first;
if (len % 4 != 0) {
return "";
}
std::string res;
res.resize(len / 4 * 3);
auto p = std::begin(res);
assert(std::distance(first, last) % 4 == 0);
auto p = d_first;
for (; first != last;) {
uint32_t n = 0;
for (int i = 1; i <= 4; ++i, ++first) {
auto idx = INDEX_TABLE[static_cast<size_t>(*first)];
if (idx == -1) {
if (i <= 2) {
return "";
return d_first;
}
if (i == 3) {
if (*first == '=' && *(first + 1) == '=' && first + 2 == last) {
*p++ = n >> 16;
res.resize(p - std::begin(res));
return res;
return p;
}
return "";
return d_first;
}
if (*first == '=' && first + 1 == last) {
*p++ = n >> 16;
*p++ = n >> 8 & 0xffu;
res.resize(p - std::begin(res));
return res;
return p;
}
return "";
return d_first;
}
n += idx << (24 - i * 6);
......@@ -145,9 +141,37 @@ template <typename InputIt> std::string decode(InputIt first, InputIt last) {
*p++ = n & 0xffu;
}
return p;
}
template <typename InputIt> std::string decode(InputIt first, InputIt last) {
auto len = std::distance(first, last);
if (len % 4 != 0) {
return "";
}
std::string res;
res.resize(len / 4 * 3);
res.erase(decode(first, last, std::begin(res)), std::end(res));
return res;
}
template <typename InputIt>
StringRef decode(BlockAllocator &balloc, InputIt first, InputIt last) {
auto len = std::distance(first, last);
if (len % 4 != 0) {
return StringRef::from_lit("");
}
auto iov = make_byte_ref(balloc, len / 4 * 3 + 1);
auto p = iov.base;
p = decode(first, last, p);
*p = '\0';
return StringRef{iov.base, p};
}
} // namespace base64
} // namespace nghttp2
......
......@@ -59,31 +59,40 @@ void test_base64_encode(void) {
}
void test_base64_decode(void) {
BlockAllocator balloc(4096, 4096);
{
std::string in = "/w==";
auto out = base64::decode(std::begin(in), std::end(in));
CU_ASSERT("\xff" == out);
CU_ASSERT("\xff" == base64::decode(balloc, std::begin(in), std::end(in)));
}
{
std::string in = "//4=";
auto out = base64::decode(std::begin(in), std::end(in));
CU_ASSERT("\xff\xfe" == out);
CU_ASSERT("\xff\xfe" ==
base64::decode(balloc, std::begin(in), std::end(in)));
}
{
std::string in = "//79";
auto out = base64::decode(std::begin(in), std::end(in));
CU_ASSERT("\xff\xfe\xfd" == out);
CU_ASSERT("\xff\xfe\xfd" ==
base64::decode(balloc, std::begin(in), std::end(in)));
}
{
std::string in = "//79/A==";
auto out = base64::decode(std::begin(in), std::end(in));
CU_ASSERT("\xff\xfe\xfd\xfc" == out);
CU_ASSERT("\xff\xfe\xfd\xfc" ==
base64::decode(balloc, std::begin(in), std::end(in)));
}
{
// we check the number of valid input must be multiples of 4
std::string in = "//79=";
auto out = base64::decode(std::begin(in), std::end(in));
CU_ASSERT("" == out);
CU_ASSERT("" == base64::decode(balloc, std::begin(in), std::end(in)));
}
{
// ending invalid character at the boundary of multiples of 4 is
......@@ -91,18 +100,21 @@ void test_base64_decode(void) {
std::string in = "bmdodHRw\n";
auto out = base64::decode(std::begin(in), std::end(in));
CU_ASSERT("" == out);
CU_ASSERT("" == base64::decode(balloc, std::begin(in), std::end(in)));
}
{
// after seeing '=', subsequent input must be also '='.
std::string in = "//79/A=A";
auto out = base64::decode(std::begin(in), std::end(in));
CU_ASSERT("" == out);
CU_ASSERT("" == base64::decode(balloc, std::begin(in), std::end(in)));
}
{
// additional '=' at the end is bad
std::string in = "//79/A======";
auto out = base64::decode(std::begin(in), std::end(in));
CU_ASSERT("" == out);
CU_ASSERT("" == base64::decode(balloc, std::begin(in), std::end(in)));
}
}
......
......@@ -176,6 +176,9 @@ int main(int argc, char *argv[]) {
!CU_add_test(pSuite, "util_make_hostport",
shrpx::test_util_make_hostport) ||
!CU_add_test(pSuite, "util_strifind", shrpx::test_util_strifind) ||
!CU_add_test(pSuite, "util_random_alpha_digit",
shrpx::test_util_random_alpha_digit) ||
!CU_add_test(pSuite, "util_format_hex", shrpx::test_util_format_hex) ||
!CU_add_test(pSuite, "gzip_inflate", test_nghttp2_gzip_inflate) ||
!CU_add_test(pSuite, "buffer_write", nghttp2::test_buffer_write) ||
!CU_add_test(pSuite, "pool_recycle", nghttp2::test_pool_recycle) ||
......
......@@ -86,6 +86,7 @@
#include "app_helper.h"
#include "ssl.h"
#include "template.h"
#include "allocator.h"
extern char **environ;
......@@ -151,7 +152,7 @@ StartupConfig suconfig;
struct InheritedAddr {
// IP address if TCP socket. Otherwise, UNIX domain socket path.
ImmutableString host;
StringRef host;
uint16_t port;
// true if UNIX domain socket path
bool host_unix;
......@@ -574,7 +575,7 @@ int create_unix_domain_server_socket(UpstreamAddr &faddr,
<< (faddr.tls ? ", tls" : "");
(*found).used = true;
faddr.fd = (*found).fd;
faddr.hostport = ImmutableString::from_lit("localhost");
faddr.hostport = StringRef::from_lit("localhost");
return 0;
}
......@@ -639,7 +640,7 @@ int create_unix_domain_server_socket(UpstreamAddr &faddr,
<< (faddr.tls ? ", tls" : "");
faddr.fd = fd;
faddr.hostport = ImmutableString::from_lit("localhost");
faddr.hostport = StringRef::from_lit("localhost");
return 0;
}
......@@ -791,8 +792,8 @@ int create_tcp_server_socket(UpstreamAddr &faddr,
}
faddr.fd = fd;
faddr.hostport = ImmutableString{
util::make_http_hostport(StringRef{host.data()}, faddr.port)};
faddr.hostport = util::make_http_hostport(mod_config()->balloc,
StringRef{host.data()}, faddr.port);
LOG(NOTICE) << "Listening on " << faddr.hostport
<< (faddr.tls ? ", tls" : "");
......@@ -806,7 +807,7 @@ namespace {
// function is intended to be used when reloading configuration, and
// |config| is usually a current configuration.
std::vector<InheritedAddr>
get_inherited_addr_from_config(const Config *config) {
get_inherited_addr_from_config(BlockAllocator &balloc, Config *config) {
int rv;
auto &listenerconf = config->conn.listener;
......@@ -856,7 +857,7 @@ get_inherited_addr_from_config(const Config *config) {
continue;
}
iaddr.host = ImmutableString{host.data()};
iaddr.host = make_string_ref(balloc, StringRef{host.data()});
}
return iaddrs;
......@@ -867,7 +868,7 @@ namespace {
// Returns array of InheritedAddr constructed from environment
// variables. This function handles the old environment variable
// names used in 1.7.0 or earlier.
std::vector<InheritedAddr> get_inherited_addr_from_env() {
std::vector<InheritedAddr> get_inherited_addr_from_env(Config *config) {
int rv;
std::vector<InheritedAddr> iaddrs;
......@@ -888,15 +889,14 @@ std::vector<InheritedAddr> get_inherited_addr_from_env() {
}
}
} else {
auto pathenv = getenv(ENV_UNIX_PATH.c_str());
auto fdenv = getenv(ENV_UNIX_FD.c_str());
if (pathenv && fdenv) {
// The return value of getenv may be allocated statically.
if (getenv(ENV_UNIX_PATH.c_str()) && getenv(ENV_UNIX_FD.c_str())) {
auto name = ENV_ACCEPT_PREFIX.str();
name += '1';
std::string value = "unix,";
value += fdenv;
value += getenv(ENV_UNIX_FD.c_str());
value += ',';
value += pathenv;
value += getenv(ENV_UNIX_PATH.c_str());
setenv(name.c_str(), value.c_str(), 0);
}
}
......@@ -948,7 +948,7 @@ std::vector<InheritedAddr> get_inherited_addr_from_env() {
}
InheritedAddr addr{};
addr.host = ImmutableString{path};
addr.host = make_string_ref(config->balloc, StringRef{path});
addr.host_unix = true;
addr.fd = static_cast<int>(fd);
iaddrs.push_back(std::move(addr));
......@@ -1002,7 +1002,7 @@ std::vector<InheritedAddr> get_inherited_addr_from_env() {
}
InheritedAddr addr{};
addr.host = ImmutableString{host.data()};
addr.host = make_string_ref(config->balloc, StringRef{host.data()});
addr.port = static_cast<uint16_t>(port);
addr.fd = static_cast<int>(fd);
iaddrs.push_back(std::move(addr));
......@@ -1209,7 +1209,7 @@ int event_loop() {
redirect_stderr_to_errorlog();
}
auto iaddrs = get_inherited_addr_from_env();
auto iaddrs = get_inherited_addr_from_env(mod_config());
if (create_acceptor_socket(mod_config(), iaddrs) != 0) {
return -1;
......@@ -1274,7 +1274,7 @@ constexpr auto DEFAULT_ACCESSLOG_FORMAT = StringRef::from_lit(
namespace {
void fill_default_config(Config *config) {
config->num_worker = 1;
config->conf_path = ImmutableString::from_lit("/etc/nghttpx/nghttpx.conf");
config->conf_path = StringRef::from_lit("/etc/nghttpx/nghttpx.conf");
config->pid = getpid();
if (ev_supported_backends() & ~ev_recommended_backends() & EVBACKEND_KQUEUE) {
......@@ -1306,7 +1306,7 @@ void fill_default_config(Config *config) {
// ocsp update interval = 14400 secs = 4 hours, borrowed from h2o
ocspconf.update_interval = 4_h;
ocspconf.fetch_ocsp_response_file =
ImmutableString::from_lit(PKGDATADIR "/fetch-ocsp-response");
StringRef::from_lit(PKGDATADIR "/fetch-ocsp-response");
}
{
......@@ -1319,7 +1319,7 @@ void fill_default_config(Config *config) {
auto &httpconf = config->http;
httpconf.server_name =
ImmutableString::from_lit("nghttpx nghttp2/" NGHTTP2_VERSION);
StringRef::from_lit("nghttpx nghttp2/" NGHTTP2_VERSION);
httpconf.no_host_rewrite = true;
httpconf.request_header_field_buffer = 64_k;
httpconf.max_request_header_fields = 100;
......@@ -1387,10 +1387,11 @@ void fill_default_config(Config *config) {
auto &loggingconf = config->logging;
{
auto &accessconf = loggingconf.access;
accessconf.format = parse_log_format(DEFAULT_ACCESSLOG_FORMAT);
accessconf.format =
parse_log_format(config->balloc, DEFAULT_ACCESSLOG_FORMAT);
auto &errorconf = loggingconf.error;
errorconf.file = ImmutableString::from_lit("/dev/stderr");
errorconf.file = StringRef::from_lit("/dev/stderr");
}
loggingconf.syslog_facility = LOG_DAEMON;
......@@ -2446,11 +2447,10 @@ int process_options(Config *config,
auto &tlsconf = config->tls;
if (tlsconf.npn_list.empty()) {
tlsconf.npn_list = util::parse_config_str_list(DEFAULT_NPN_LIST);
tlsconf.npn_list = util::split_str(DEFAULT_NPN_LIST, ',');
}
if (tlsconf.tls_proto_list.empty()) {
tlsconf.tls_proto_list =
util::parse_config_str_list(DEFAULT_TLS_PROTO_LIST);
tlsconf.tls_proto_list = util::split_str(DEFAULT_TLS_PROTO_LIST, ',');
}
tlsconf.tls_proto_mask = ssl::create_tls_proto_mask(tlsconf.tls_proto_list);
......@@ -2466,7 +2466,7 @@ int process_options(Config *config,
if (listenerconf.addrs.empty()) {
UpstreamAddr addr{};
addr.host = ImmutableString::from_lit("*");
addr.host = StringRef::from_lit("*");
addr.port = 3000;
addr.tls = true;
addr.family = AF_INET;
......@@ -2567,10 +2567,14 @@ int process_options(Config *config,
if (fwdconf.by_node_type == FORWARDED_NODE_OBFUSCATED &&
fwdconf.by_obfuscated.empty()) {
// 2 for '_' and terminal NULL
auto iov = make_byte_ref(config->balloc, SHRPX_OBFUSCATED_NODE_LENGTH + 2);
auto p = iov.base;
*p++ = '_';
std::mt19937 gen(rd());
auto &dst = fwdconf.by_obfuscated;
dst = "_";
dst += util::random_alpha_digit(gen, SHRPX_OBFUSCATED_NODE_LENGTH);
p = util::random_alpha_digit(p, p + SHRPX_OBFUSCATED_NODE_LENGTH, gen);
*p = '\0';
fwdconf.by_obfuscated = StringRef{iov.base, p};
}
if (config->http2.upstream.debug.frame_debug) {
......@@ -2616,12 +2620,13 @@ void reload_config(WorkerProcess *wp) {
LOG(NOTICE) << "Reloading configuration";
auto cur_config = get_config();
auto cur_config = mod_config();
auto new_config = make_unique<Config>();
fill_default_config(new_config.get());
new_config->conf_path = cur_config->conf_path;
new_config->conf_path =
make_string_ref(new_config->balloc, cur_config->conf_path);
// daemon option is ignored here.
new_config->daemon = cur_config->daemon;
// loop is reused, and ev_loop_flags gets ignored
......@@ -2633,7 +2638,7 @@ void reload_config(WorkerProcess *wp) {
return;
}
auto iaddrs = get_inherited_addr_from_config(cur_config);
auto iaddrs = get_inherited_addr_from_config(new_config->balloc, cur_config);
if (create_acceptor_socket(new_config.get(), iaddrs) != 0) {
close_not_inherited_fd(new_config.get(), iaddrs);
......@@ -3027,7 +3032,8 @@ int main(int argc, char **argv) {
break;
case 12:
// --conf
mod_config()->conf_path = ImmutableString{optarg};
mod_config()->conf_path =
make_string_ref(mod_config()->balloc, StringRef{optarg});
break;
case 14:
// --syslog-facility
......
This diff is collapsed.
......@@ -37,6 +37,7 @@
#include "shrpx_connection.h"
#include "buffer.h"
#include "memchunk.h"
#include "allocator.h"
using namespace nghttp2;
......@@ -54,8 +55,8 @@ struct DownstreamAddr;
class ClientHandler {
public:
ClientHandler(Worker *worker, int fd, SSL *ssl, const char *ipaddr,
const char *port, int family, const UpstreamAddr *faddr);
ClientHandler(Worker *worker, int fd, SSL *ssl, const StringRef &ipaddr,
const StringRef &port, int family, const UpstreamAddr *faddr);
~ClientHandler();
int noop();
......@@ -90,8 +91,7 @@ public:
void reset_upstream_write_timeout(ev_tstamp t);
int validate_next_proto();
const std::string &get_ipaddr() const;
const std::string &get_port() const;
const StringRef &get_ipaddr() const;
bool get_should_close_after_write() const;
void set_should_close_after_write(bool f);
Upstream *get_upstream();
......@@ -162,21 +162,27 @@ public:
// Returns TLS SNI extension value client sent in this connection.
StringRef get_tls_sni() const;
BlockAllocator &get_block_allocator();
private:
// Allocator to allocate memory for connection-wide objects. Make
// sure that the allocations must be bounded, and not proportional
// to the number of requests.
BlockAllocator balloc_;
Connection conn_;
ev_timer reneg_shutdown_timer_;
std::unique_ptr<Upstream> upstream_;
// IP address of client. If UNIX domain socket is used, this is
// "localhost".
std::string ipaddr_;
std::string port_;
StringRef ipaddr_;
StringRef port_;
// The ALPN identifier negotiated for this connection.
std::string alpn_;
StringRef alpn_;
// The client address used in "for" parameter of Forwarded header
// field.
std::string forwarded_for_;
StringRef forwarded_for_;
// lowercased TLS SNI which client sent.
std::string sni_;
StringRef sni_;
std::function<int(ClientHandler &)> read_, write_;
std::function<int(ClientHandler &)> on_read_, on_write_;
// Address of frontend listening socket
......
This diff is collapsed.
This diff is collapsed.
......@@ -37,40 +37,45 @@
namespace shrpx {
void test_shrpx_config_parse_header(void) {
auto p = parse_header(StringRef::from_lit("a: b"));
BlockAllocator balloc(4096, 4096);
auto p = parse_header(balloc, StringRef::from_lit("a: b"));
CU_ASSERT("a" == p.name);
CU_ASSERT("b" == p.value);
p = parse_header(StringRef::from_lit("a: b"));
p = parse_header(balloc, StringRef::from_lit("a: b"));
CU_ASSERT("a" == p.name);
CU_ASSERT("b" == p.value);
p = parse_header(StringRef::from_lit(":a: b"));
p = parse_header(balloc, StringRef::from_lit(":a: b"));
CU_ASSERT(p.name.empty());
p = parse_header(StringRef::from_lit("a: :b"));
p = parse_header(balloc, StringRef::from_lit("a: :b"));
CU_ASSERT("a" == p.name);
CU_ASSERT(":b" == p.value);
p = parse_header(StringRef::from_lit(": b"));
p = parse_header(balloc, StringRef::from_lit(": b"));
CU_ASSERT(p.name.empty());
p = parse_header(StringRef::from_lit("alpha: bravo charlie"));
p = parse_header(balloc, StringRef::from_lit("alpha: bravo charlie"));
CU_ASSERT("alpha" == p.name);
CU_ASSERT("bravo charlie" == p.value);
p = parse_header(StringRef::from_lit("a,: b"));
p = parse_header(balloc, StringRef::from_lit("a,: b"));
CU_ASSERT(p.name.empty());
p = parse_header(StringRef::from_lit("a: b\x0a"));
p = parse_header(balloc, StringRef::from_lit("a: b\x0a"));
CU_ASSERT(p.name.empty());
}
void test_shrpx_config_parse_log_format(void) {
auto res = parse_log_format(StringRef::from_lit(
R"($remote_addr - $remote_user [$time_local] )"
R"("$request" $status $body_bytes_sent )"
R"("${http_referer}" $http_host "$http_user_agent")"));
BlockAllocator balloc(4096, 4096);
auto res = parse_log_format(
balloc, StringRef::from_lit(
R"($remote_addr - $remote_user [$time_local] )"
R"("$request" $status $body_bytes_sent )"
R"("${http_referer}" $http_host "$http_user_agent")"));
CU_ASSERT(16 == res.size());
CU_ASSERT(SHRPX_LOGF_REMOTE_ADDR == res[0].type);
......@@ -115,35 +120,35 @@ void test_shrpx_config_parse_log_format(void) {
CU_ASSERT(SHRPX_LOGF_LITERAL == res[15].type);
CU_ASSERT("\"" == res[15].value);
res = parse_log_format(StringRef::from_lit("$"));
res = parse_log_format(balloc, StringRef::from_lit("$"));
CU_ASSERT(1 == res.size());
CU_ASSERT(SHRPX_LOGF_LITERAL == res[0].type);
CU_ASSERT("$" == res[0].value);
res = parse_log_format(StringRef::from_lit("${"));
res = parse_log_format(balloc, StringRef::from_lit("${"));
CU_ASSERT(1 == res.size());
CU_ASSERT(SHRPX_LOGF_LITERAL == res[0].type);
CU_ASSERT("${" == res[0].value);
res = parse_log_format(StringRef::from_lit("${a"));
res = parse_log_format(balloc, StringRef::from_lit("${a"));
CU_ASSERT(1 == res.size());
CU_ASSERT(SHRPX_LOGF_LITERAL == res[0].type);
CU_ASSERT("${a" == res[0].value);
res = parse_log_format(StringRef::from_lit("${a "));
res = parse_log_format(balloc, StringRef::from_lit("${a "));
CU_ASSERT(1 == res.size());
CU_ASSERT(SHRPX_LOGF_LITERAL == res[0].type);
CU_ASSERT("${a " == res[0].value);
res = parse_log_format(StringRef::from_lit("$$remote_addr"));
res = parse_log_format(balloc, StringRef::from_lit("$$remote_addr"));
CU_ASSERT(2 == res.size());
......@@ -168,8 +173,8 @@ void test_shrpx_config_read_tls_ticket_key_file(void) {
close(fd1);
close(fd2);
auto ticket_keys =
read_tls_ticket_key_file({file1, file2}, EVP_aes_128_cbc(), EVP_sha256());
auto ticket_keys = read_tls_ticket_key_file(
{StringRef{file1}, StringRef{file2}}, EVP_aes_128_cbc(), EVP_sha256());
unlink(file1);
unlink(file2);
CU_ASSERT(ticket_keys.get() != nullptr);
......@@ -211,8 +216,8 @@ void test_shrpx_config_read_tls_ticket_key_file_aes_256(void) {
close(fd1);
close(fd2);
auto ticket_keys =
read_tls_ticket_key_file({file1, file2}, EVP_aes_256_cbc(), EVP_sha256());
auto ticket_keys = read_tls_ticket_key_file(
{StringRef{file1}, StringRef{file2}}, EVP_aes_256_cbc(), EVP_sha256());
unlink(file1);
unlink(file2);
CU_ASSERT(ticket_keys.get() != nullptr);
......
......@@ -224,8 +224,8 @@ int ConnectionHandler::create_single_worker() {
#ifdef HAVE_NEVERBLEED
nb_.get(),
#endif // HAVE_NEVERBLEED
StringRef{tlsconf.cacert}, StringRef{memcachedconf.cert_file},
StringRef{memcachedconf.private_key_file}, nullptr);
tlsconf.cacert, memcachedconf.cert_file, memcachedconf.private_key_file,
nullptr);
all_ssl_ctx_.push_back(session_cache_ssl_ctx);
}
......@@ -280,8 +280,8 @@ int ConnectionHandler::create_worker_thread(size_t num) {
#ifdef HAVE_NEVERBLEED
nb_.get(),
#endif // HAVE_NEVERBLEED
StringRef{tlsconf.cacert}, StringRef{memcachedconf.cert_file},
StringRef{memcachedconf.private_key_file}, nullptr);
tlsconf.cacert, memcachedconf.cert_file,
memcachedconf.private_key_file, nullptr);
all_ssl_ctx_.push_back(session_cache_ssl_ctx);
}
auto worker = make_unique<Worker>(
......@@ -835,8 +835,8 @@ SSL_CTX *ConnectionHandler::create_tls_ticket_key_memcached_ssl_ctx() {
#ifdef HAVE_NEVERBLEED
nb_.get(),
#endif // HAVE_NEVERBLEED
StringRef{tlsconf.cacert}, StringRef{memcachedconf.cert_file},
StringRef{memcachedconf.private_key_file}, nullptr);
tlsconf.cacert, memcachedconf.cert_file, memcachedconf.private_key_file,
nullptr);
all_ssl_ctx_.push_back(ssl_ctx);
......
......@@ -46,12 +46,11 @@ StringRef create_error_html(BlockAllocator &balloc, unsigned int http_status) {
}
auto status_string = http2::get_status_string(balloc, http_status);
const auto &server_name = httpconf.server_name;
return concat_string_ref(
balloc, StringRef::from_lit(R"(<!DOCTYPE html><html lang="en"><title>)"),
status_string, StringRef::from_lit("</title><body><h1>"), status_string,
StringRef::from_lit("</h1><footer>"), StringRef{server_name},
StringRef::from_lit("</h1><footer>"), httpconf.server_name,
StringRef::from_lit("</footer></body></html>"));
}
......
......@@ -357,7 +357,7 @@ int Http2DownstreamConnection::push_request_headers() {
if (xffconf.add) {
StringRef xff_value;
auto addr = StringRef{upstream->get_client_handler()->get_ipaddr()};
const auto &addr = upstream->get_client_handler()->get_ipaddr();
if (xff) {
xff_value = concat_string_ref(balloc, xff->value,
StringRef::from_lit(", "), addr);
......
......@@ -108,15 +108,16 @@ int on_stream_close_callback(nghttp2_session *session, int32_t stream_id,
int Http2Upstream::upgrade_upstream(HttpsUpstream *http) {
int rv;
auto http2_settings = http->get_downstream()->get_http2_settings().str();
util::to_base64(http2_settings);
auto &balloc = http->get_downstream()->get_block_allocator();
auto settings_payload =
base64::decode(std::begin(http2_settings), std::end(http2_settings));
auto http2_settings = http->get_downstream()->get_http2_settings();
http2_settings = util::to_base64(balloc, http2_settings);
auto settings_payload = base64::decode(balloc, std::begin(http2_settings),
std::end(http2_settings));
rv = nghttp2_session_upgrade2(
session_, reinterpret_cast<const uint8_t *>(settings_payload.c_str()),
settings_payload.size(),
session_, settings_payload.byte(), settings_payload.size(),
http->get_downstream()->request().method == HTTP_HEAD, nullptr);
if (rv != 0) {
if (LOG_ENABLED(INFO)) {
......@@ -1429,8 +1430,8 @@ int Http2Upstream::send_reply(Downstream *downstream, const uint8_t *body,
}
if (!resp.fs.header(http2::HD_SERVER)) {
nva.push_back(http2::make_nv_ls_nocopy(
"server", StringRef{get_config()->http.server_name}));
nva.push_back(
http2::make_nv_ls_nocopy("server", get_config()->http.server_name));
}
for (auto &p : httpconf.add_response_headers) {
......@@ -1481,8 +1482,7 @@ int Http2Upstream::error_reply(Downstream *downstream,
auto nva = std::array<nghttp2_nv, 5>{
{http2::make_nv_ls_nocopy(":status", response_status),
http2::make_nv_ll("content-type", "text/html; charset=UTF-8"),
http2::make_nv_ls_nocopy("server",
StringRef{get_config()->http.server_name}),
http2::make_nv_ls_nocopy("server", get_config()->http.server_name),
http2::make_nv_ls_nocopy("content-length", content_length),
http2::make_nv_ls_nocopy("date", date)}};
......@@ -1629,8 +1629,7 @@ int Http2Upstream::on_downstream_header_complete(Downstream *downstream) {
http2::copy_headers_to_nva_nocopy(nva, resp.fs.headers());
if (!get_config()->http2_proxy && !httpconf.no_server_rewrite) {
nva.push_back(
http2::make_nv_ls_nocopy("server", StringRef{httpconf.server_name}));
nva.push_back(http2::make_nv_ls_nocopy("server", httpconf.server_name));
} else {
auto server = resp.fs.header(http2::HD_SERVER);
if (server) {
......
......@@ -960,13 +960,14 @@ std::unique_ptr<Downstream> HttpsUpstream::pop_downstream() {
}
namespace {
void write_altsvc(DefaultMemchunks *buf, const AltSvc &altsvc) {
buf->append(util::percent_encode_token(altsvc.protocol_id));
void write_altsvc(DefaultMemchunks *buf, BlockAllocator &balloc,
const AltSvc &altsvc) {
buf->append(util::percent_encode_token(balloc, altsvc.protocol_id));
buf->append("=\"");
buf->append(util::quote_string(altsvc.host));
buf->append(":");
buf->append(util::quote_string(balloc, altsvc.host));
buf->append(':');
buf->append(altsvc.service);
buf->append("\"");
buf->append('"');
}
} // namespace
......@@ -1073,10 +1074,10 @@ int HttpsUpstream::on_downstream_header_complete(Downstream *downstream) {
buf->append("Alt-Svc: ");
auto &altsvcs = httpconf.altsvcs;
write_altsvc(buf, altsvcs[0]);
write_altsvc(buf, downstream->get_block_allocator(), altsvcs[0]);
for (size_t i = 1; i < altsvcs.size(); ++i) {
buf->append(", ");
write_altsvc(buf, altsvcs[i]);
write_altsvc(buf, downstream->get_block_allocator(), altsvcs[i]);
}
buf->append("\r\n");
}
......
......@@ -290,7 +290,7 @@ void upstream_accesslog(const std::vector<LogFragment> &lfv,
break;
case SHRPX_LOGF_HTTP:
if (req) {
auto hd = req->fs.header(StringRef(lf.value));
auto hd = req->fs.header(lf.value);
if (hd) {
std::tie(p, avail) = copy((*hd).value, avail, p);
break;
......
......@@ -137,10 +137,10 @@ enum LogFragmentType {
};
struct LogFragment {
LogFragment(LogFragmentType type, ImmutableString value = ImmutableString())
LogFragment(LogFragmentType type, StringRef value = StringRef::from_lit(""))
: type(type), value(std::move(value)) {}
LogFragmentType type;
ImmutableString value;
StringRef value;
};
struct LogSpec {
......
......@@ -100,7 +100,7 @@ MemcachedConnection::MemcachedConnection(const Address *addr,
connectcb, readcb, timeoutcb, this, 0, 0., PROTO_MEMCACHED),
do_read_(&MemcachedConnection::noop),
do_write_(&MemcachedConnection::noop),
sni_name_(sni_name.str()),
sni_name_(sni_name),
connect_blocker_(gen, loop, [] {}, [] {}),
parse_state_{},
addr_(addr),
......@@ -268,7 +268,7 @@ int MemcachedConnection::tls_handshake() {
auto &tlsconf = get_config()->tls;
if (!tlsconf.insecure &&
ssl::check_cert(conn_.tls.ssl, addr_, StringRef(sni_name_)) != 0) {
ssl::check_cert(conn_.tls.ssl, addr_, sni_name_) != 0) {
connect_blocker_.on_failure();
return -1;
}
......
......@@ -135,7 +135,7 @@ private:
std::deque<std::unique_ptr<MemcachedRequest>> sendq_;
std::deque<MemcachedSendbuf> sendbufv_;
std::function<int(MemcachedConnection &)> do_read_, do_write_;
std::string sni_name_;
StringRef sni_name_;
ssl::TLSSessionCache tls_session_cache_;
ConnectBlocker connect_blocker_;
MemcachedParseState parse_state_;
......
......@@ -103,7 +103,7 @@ int verify_callback(int preverify_ok, X509_STORE_CTX *ctx) {
} // namespace
int set_alpn_prefs(std::vector<unsigned char> &out,
const std::vector<std::string> &protos) {
const std::vector<StringRef> &protos) {
size_t len = 0;
for (const auto &proto : protos) {
......@@ -125,8 +125,7 @@ int set_alpn_prefs(std::vector<unsigned char> &out,
for (const auto &proto : protos) {
*ptr++ = proto.size();
memcpy(ptr, proto.c_str(), proto.size());
ptr += proto.size();
ptr = std::copy(std::begin(proto), std::end(proto), ptr);
}
return 0;
......@@ -243,6 +242,7 @@ int tls_session_new_cb(SSL *ssl, SSL_SESSION *session) {
auto handler = static_cast<ClientHandler *>(conn->data);
auto worker = handler->get_worker();
auto dispatcher = worker->get_session_cache_memcached_dispatcher();
auto &balloc = handler->get_block_allocator();
const unsigned char *id;
unsigned int idlen;
......@@ -256,7 +256,8 @@ int tls_session_new_cb(SSL *ssl, SSL_SESSION *session) {
auto req = make_unique<MemcachedRequest>();
req->op = MEMCACHED_OP_ADD;
req->key = MEMCACHED_SESSION_CACHE_KEY_PREFIX.str();
req->key += util::format_hex(id, idlen);
req->key +=
util::format_hex(balloc, StringRef{id, static_cast<size_t>(idlen)});
auto sessionlen = i2d_SSL_SESSION(session, nullptr);
req->value.resize(sessionlen);
......@@ -295,6 +296,7 @@ SSL_SESSION *tls_session_get_cb(SSL *ssl,
auto handler = static_cast<ClientHandler *>(conn->data);
auto worker = handler->get_worker();
auto dispatcher = worker->get_session_cache_memcached_dispatcher();
auto &balloc = handler->get_block_allocator();
if (conn->tls.cached_session) {
if (LOG_ENABLED(INFO)) {
......@@ -318,7 +320,8 @@ SSL_SESSION *tls_session_get_cb(SSL *ssl,
auto req = make_unique<MemcachedRequest>();
req->op = MEMCACHED_OP_GET;
req->key = MEMCACHED_SESSION_CACHE_KEY_PREFIX.str();
req->key += util::format_hex(id, idlen);
req->key +=
util::format_hex(balloc, StringRef{id, static_cast<size_t>(idlen)});
req->cb = [conn](MemcachedRequest *, MemcachedResult res) {
if (LOG_ENABLED(INFO)) {
LOG(INFO) << "Memcached: returned status code " << res.status_code;
......@@ -465,8 +468,7 @@ int alpn_select_proto_cb(SSL *ssl, const unsigned char **out,
auto proto_len = *p;
if (proto_id + proto_len <= end &&
util::streq(StringRef{target_proto_id},
StringRef{proto_id, proto_len})) {
util::streq(target_proto_id, StringRef{proto_id, proto_len})) {
*out = reinterpret_cast<const unsigned char *>(proto_id);
*outlen = proto_len;
......@@ -493,7 +495,7 @@ constexpr TLSProtocol TLS_PROTOS[] = {
TLSProtocol{StringRef::from_lit("TLSv1.1"), SSL_OP_NO_TLSv1_1},
TLSProtocol{StringRef::from_lit("TLSv1.0"), SSL_OP_NO_TLSv1}};
long int create_tls_proto_mask(const std::vector<std::string> &tls_proto_list) {
long int create_tls_proto_mask(const std::vector<StringRef> &tls_proto_list) {
long int res = 0;
for (auto &supported : TLS_PROTOS) {
......@@ -829,16 +831,16 @@ SSL *create_ssl(SSL_CTX *ssl_ctx) {
ClientHandler *accept_connection(Worker *worker, int fd, sockaddr *addr,
int addrlen, const UpstreamAddr *faddr) {
char host[NI_MAXHOST];
char service[NI_MAXSERV];
std::array<char, NI_MAXHOST> host;
std::array<char, NI_MAXSERV> service;
int rv;
if (addr->sa_family == AF_UNIX) {
std::copy_n("localhost", sizeof("localhost"), host);
std::copy_n("localhost", sizeof("localhost"), std::begin(host));
service[0] = '\0';
} else {
rv = getnameinfo(addr, addrlen, host, sizeof(host), service,
sizeof(service), NI_NUMERICHOST | NI_NUMERICSERV);
rv = getnameinfo(addr, addrlen, host.data(), host.size(), service.data(),
service.size(), NI_NUMERICHOST | NI_NUMERICSERV);
if (rv != 0) {
LOG(ERROR) << "getnameinfo() failed: " << gai_strerror(rv);
......@@ -867,8 +869,8 @@ ClientHandler *accept_connection(Worker *worker, int fd, sockaddr *addr,
}
}
return new ClientHandler(worker, fd, ssl, host, service, addr->sa_family,
faddr);
return new ClientHandler(worker, fd, ssl, StringRef{host.data()},
StringRef{service.data()}, addr->sa_family, faddr);
}
bool tls_hostname_match(const StringRef &pattern, const StringRef &hostname) {
......@@ -1316,10 +1318,10 @@ int cert_lookup_tree_add_cert_from_x509(CertLookupTree *lt, size_t idx,
return 0;
}
bool in_proto_list(const std::vector<std::string> &protos,
bool in_proto_list(const std::vector<StringRef> &protos,
const StringRef &needle) {
for (auto &proto : protos) {
if (util::streq(StringRef{proto}, needle)) {
if (util::streq(proto, needle)) {
return true;
}
}
......@@ -1443,8 +1445,8 @@ SSL_CTX *setup_downstream_client_ssl_context(
#ifdef HAVE_NEVERBLEED
nb,
#endif // HAVE_NEVERBLEED
StringRef{tlsconf.cacert}, StringRef{tlsconf.client.cert_file},
StringRef{tlsconf.client.private_key_file}, select_next_proto_cb);
tlsconf.cacert, tlsconf.client.cert_file, tlsconf.client.private_key_file,
select_next_proto_cb);
}
void setup_downstream_http2_alpn(SSL *ssl) {
......
......@@ -101,11 +101,6 @@ ClientHandler *accept_connection(Worker *worker, int fd, sockaddr *addr,
int check_cert(SSL *ssl, const Address *addr, const StringRef &host);
int check_cert(SSL *ssl, const DownstreamAddr *addr);
// Retrieves DNS and IP address in subjectAltNames and commonName from
// the |cert|.
void get_altnames(X509 *cert, std::vector<std::string> &dns_names,
std::vector<std::string> &ip_addrs, std::string &common_name);
struct WildcardRevPrefix {
WildcardRevPrefix(const StringRef &prefix, size_t idx)
: prefix(std::begin(prefix), std::end(prefix)), idx(idx) {}
......@@ -172,7 +167,7 @@ int cert_lookup_tree_add_cert_from_x509(CertLookupTree *lt, size_t idx,
// Returns true if |proto| is included in the
// protocol list |protos|.
bool in_proto_list(const std::vector<std::string> &protos,
bool in_proto_list(const std::vector<StringRef> &protos,
const StringRef &proto);
// Returns true if security requirement for HTTP/2 is fulfilled.
......@@ -181,10 +176,10 @@ bool check_http2_requirement(SSL *ssl);
// Returns SSL/TLS option mask to disable SSL/TLS protocol version not
// included in |tls_proto_list|. The returned mask can be directly
// passed to SSL_CTX_set_options().
long int create_tls_proto_mask(const std::vector<std::string> &tls_proto_list);
long int create_tls_proto_mask(const std::vector<StringRef> &tls_proto_list);
int set_alpn_prefs(std::vector<unsigned char> &out,
const std::vector<std::string> &protos);
const std::vector<StringRef> &protos);
// Setups server side SSL_CTX. This function inspects get_config()
// and if upstream_no_tls is true, returns nullptr. Otherwise
......
......@@ -182,7 +182,8 @@ void Worker::replace_downstream_config(
auto &dst = downstream_addr_groups_[i];
dst = std::make_shared<DownstreamAddrGroup>();
dst->pattern = src.pattern;
dst->pattern =
ImmutableString{std::begin(src.pattern), std::end(src.pattern)};
auto shared_addr = std::make_shared<SharedDownstreamAddr>();
......@@ -198,13 +199,14 @@ void Worker::replace_downstream_config(
auto &dst_addr = shared_addr->addrs[j];
dst_addr.addr = src_addr.addr;
dst_addr.host = src_addr.host;
dst_addr.hostport = src_addr.hostport;
dst_addr.host = make_string_ref(shared_addr->balloc, src_addr.host);
dst_addr.hostport =
make_string_ref(shared_addr->balloc, src_addr.hostport);
dst_addr.port = src_addr.port;
dst_addr.host_unix = src_addr.host_unix;
dst_addr.proto = src_addr.proto;
dst_addr.tls = src_addr.tls;
dst_addr.sni = src_addr.sni;
dst_addr.sni = make_string_ref(shared_addr->balloc, src_addr.sni);
dst_addr.fall = src_addr.fall;
dst_addr.rise = src_addr.rise;
......
......@@ -48,6 +48,7 @@
#include "shrpx_ssl.h"
#include "shrpx_live_check.h"
#include "shrpx_connect_blocker.h"
#include "allocator.h"
using namespace nghttp2;
......@@ -75,15 +76,15 @@ struct DownstreamAddr {
Address addr;
// backend address. If |host_unix| is true, this is UNIX domain
// socket path.
ImmutableString host;
ImmutableString hostport;
StringRef host;
StringRef hostport;
// backend port. 0 if |host_unix| is true.
uint16_t port;
// true if |host| contains UNIX domain socket path.
bool host_unix;
// sni field to send remote server if TLS is enabled.
ImmutableString sni;
StringRef sni;
std::unique_ptr<ConnectBlocker> connect_blocker;
std::unique_ptr<LiveCheck> live_check;
......@@ -128,8 +129,18 @@ struct WeightedPri {
struct SharedDownstreamAddr {
SharedDownstreamAddr()
: next{0}, http1_pri{}, http2_pri{}, affinity{AFFINITY_NONE} {}
: balloc(1024, 1024),
next{0},
http1_pri{},
http2_pri{},
affinity{AFFINITY_NONE} {}
SharedDownstreamAddr(const SharedDownstreamAddr &) = delete;
SharedDownstreamAddr(SharedDownstreamAddr &&) = delete;
SharedDownstreamAddr &operator=(const SharedDownstreamAddr &) = delete;
SharedDownstreamAddr &operator=(SharedDownstreamAddr &&) = delete;
BlockAllocator balloc;
std::vector<DownstreamAddr> addrs;
// Bunch of session affinity hash. Only used if affinity ==
// AFFINITY_IP.
......@@ -162,6 +173,11 @@ struct SharedDownstreamAddr {
struct DownstreamAddrGroup {
DownstreamAddrGroup() : retired{false} {};
DownstreamAddrGroup(const DownstreamAddrGroup &) = delete;
DownstreamAddrGroup(DownstreamAddrGroup &&) = delete;
DownstreamAddrGroup &operator=(const DownstreamAddrGroup &) = delete;
DownstreamAddrGroup &operator=(DownstreamAddrGroup &&) = delete;
ImmutableString pattern;
std::shared_ptr<SharedDownstreamAddr> shared_addr;
// true if this group is no longer used for new request. If this is
......
......@@ -131,11 +131,10 @@ bool in_attr_char(char c) {
std::find(std::begin(bad), std::end(bad), c) == std::end(bad);
}
std::string percent_encode_token(const std::string &target) {
std::string dest;
dest.resize(target.size() * 3);
auto p = std::begin(dest);
StringRef percent_encode_token(BlockAllocator &balloc,
const StringRef &target) {
auto iov = make_byte_ref(balloc, target.size() * 3 + 1);
auto p = iov.base;
for (auto first = std::begin(target); first != std::end(target); ++first) {
uint8_t c = *first;
......@@ -149,8 +148,10 @@ std::string percent_encode_token(const std::string &target) {
*p++ = UPPER_XDIGITS[c >> 4];
*p++ = UPPER_XDIGITS[(c & 0x0f)];
}
dest.resize(p - std::begin(dest));
return dest;
*p = '\0';
return StringRef{iov.base, p};
}
uint32_t hex_to_uint(char c) {
......@@ -166,25 +167,27 @@ uint32_t hex_to_uint(char c) {
return c;
}
std::string quote_string(const std::string &target) {
StringRef quote_string(BlockAllocator &balloc, const StringRef &target) {
auto cnt = std::count(std::begin(target), std::end(target), '"');
if (cnt == 0) {
return target;
return make_string_ref(balloc, target);
}
std::string res;
res.reserve(target.size() + cnt);
auto iov = make_byte_ref(balloc, target.size() + cnt + 1);
auto p = iov.base;
for (auto c : target) {
if (c == '"') {
res += "\\\"";
*p++ = '\\';
*p++ = '"';
} else {
res += c;
*p++ = c;
}
}
*p = '\0';
return res;
return StringRef{iov.base, p};
}
namespace {
......@@ -376,6 +379,21 @@ std::string format_hex(const unsigned char *s, size_t len) {
return res;
}
StringRef format_hex(BlockAllocator &balloc, const StringRef &s) {
auto iov = make_byte_ref(balloc, s.size() * 2 + 1);
auto p = iov.base;
for (auto cc : s) {
uint8_t c = cc;
*p++ = LOWER_XDIGITS[c >> 4];
*p++ = LOWER_XDIGITS[c & 0xf];
}
*p = '\0';
return StringRef{iov.base, p};
}
void to_token68(std::string &base64str) {
std::transform(std::begin(base64str), std::end(base64str),
std::begin(base64str), [](char c) {
......@@ -392,22 +410,32 @@ void to_token68(std::string &base64str) {
std::end(base64str));
}
void to_base64(std::string &token68str) {
std::transform(std::begin(token68str), std::end(token68str),
std::begin(token68str), [](char c) {
switch (c) {
case '-':
return '+';
case '_':
return '/';
default:
return c;
}
});
if (token68str.size() & 0x3) {
token68str.append(4 - (token68str.size() & 0x3), '=');
StringRef to_base64(BlockAllocator &balloc, const StringRef &token68str) {
// At most 3 padding '='
auto len = token68str.size() + 3;
auto iov = make_byte_ref(balloc, len + 1);
auto p = iov.base;
p = std::transform(std::begin(token68str), std::end(token68str), p,
[](char c) {
switch (c) {
case '-':
return '+';
case '_':
return '/';
default:
return c;
}
});
auto rem = token68str.size() & 0x3;
if (rem) {
p = std::fill_n(p, 4 - rem, '=');
}
return;
*p = '\0';
return StringRef{iov.base, p};
}
namespace {
......@@ -1119,29 +1147,30 @@ std::string dtos(double n) {
return utos(static_cast<int64_t>(n)) + "." + (f.size() == 1 ? "0" : "") + f;
}
std::string make_http_hostport(const StringRef &host, uint16_t port) {
StringRef make_http_hostport(BlockAllocator &balloc, const StringRef &host,
uint16_t port) {
if (port != 80 && port != 443) {
return make_hostport(host, port);
return make_hostport(balloc, host, port);
}
auto ipv6 = ipv6_numeric_addr(host.c_str());
std::string hostport;
hostport.resize(host.size() + (ipv6 ? 2 : 0));
auto p = &hostport[0];
auto iov = make_byte_ref(balloc, host.size() + (ipv6 ? 2 : 0) + 1);
auto p = iov.base;
if (ipv6) {
*p++ = '[';
}
p = std::copy_n(host.c_str(), host.size(), p);
p = std::copy(std::begin(host), std::end(host), p);
if (ipv6) {
*p++ = ']';
}
return hostport;
*p = '\0';
return StringRef{iov.base, p};
}
std::string make_hostport(const StringRef &host, uint16_t port) {
......@@ -1169,6 +1198,34 @@ std::string make_hostport(const StringRef &host, uint16_t port) {
return hostport;
}
StringRef make_hostport(BlockAllocator &balloc, const StringRef &host,
uint16_t port) {
auto ipv6 = ipv6_numeric_addr(host.c_str());
auto serv = utos(port);
auto iov =
make_byte_ref(balloc, host.size() + (ipv6 ? 2 : 0) + 1 + serv.size());
auto p = iov.base;
if (ipv6) {
*p++ = '[';
}
p = std::copy(std::begin(host), std::end(host), p);
if (ipv6) {
*p++ = ']';
}
*p++ = ':';
p = std::copy(std::begin(serv), std::end(serv), p);
*p = '\0';
return StringRef{iov.base, p};
}
namespace {
void hexdump8(FILE *out, const uint8_t *first, const uint8_t *last) {
auto stop = std::min(first + 8, last);
......
......@@ -129,11 +129,11 @@ std::string percent_decode(InputIt first, InputIt last) {
StringRef percent_decode(BlockAllocator &balloc, const StringRef &src);
// Percent encode |target| if character is not in token or '%'.
std::string percent_encode_token(const std::string &target);
StringRef percent_encode_token(BlockAllocator &balloc, const StringRef &target);
// Returns quotedString version of |target|. Currently, this function
// just replace '"' with '\"'.
std::string quote_string(const std::string &target);
StringRef quote_string(BlockAllocator &balloc, const StringRef &target);
std::string format_hex(const unsigned char *s, size_t len);
......@@ -145,6 +145,8 @@ template <size_t N> std::string format_hex(const std::array<uint8_t, N> &s) {
return format_hex(s.data(), s.size());
}
StringRef format_hex(BlockAllocator &balloc, const StringRef &s);
std::string http_date(time_t t);
// Returns given time |t| from epoch in Common Log format (e.g.,
......@@ -424,7 +426,8 @@ template <typename T> std::string utox(T n) {
}
void to_token68(std::string &base64str);
void to_base64(std::string &token68str);
StringRef to_base64(BlockAllocator &balloc, const StringRef &token68str);
void show_candidates(const char *unkopt, option *options);
......@@ -630,12 +633,16 @@ std::string format_duration(double t);
// Creates "host:port" string using given |host| and |port|. If
// |host| is numeric IPv6 address (e.g., ::1), it is enclosed by "["
// and "]". If |port| is 80 or 443, port part is omitted.
std::string make_http_hostport(const StringRef &host, uint16_t port);
StringRef make_http_hostport(BlockAllocator &balloc, const StringRef &host,
uint16_t port);
// Just like make_http_hostport(), but doesn't treat 80 and 443
// specially.
std::string make_hostport(const StringRef &host, uint16_t port);
StringRef make_hostport(BlockAllocator &balloc, const StringRef &host,
uint16_t port);
// Dumps |src| of length |len| in the format similar to `hexdump -C`.
void hexdump(FILE *out, const uint8_t *src, size_t len);
......@@ -665,16 +672,17 @@ uint64_t get_uint64(const uint8_t *data);
int read_mime_types(std::map<std::string, std::string> &res,
const char *filename);
template <typename Generator>
std::string random_alpha_digit(Generator &gen, size_t len) {
std::string res;
res.reserve(len);
// Fills random alpha and digit byte to the range [|first|, |last|).
// Returns the one beyond the |last|.
template <typename OutputIt, typename Generator>
OutputIt random_alpha_digit(OutputIt first, OutputIt last, Generator &gen) {
constexpr uint8_t s[] =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
std::uniform_int_distribution<> dis(0, 26 * 2 + 10 - 1);
for (; len > 0; --len) {
res += "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"[dis(
gen)];
for (; first != last; ++first) {
*first = s[dis(gen)];
}
return res;
return first;
}
template <typename OutputIterator, typename CharT, size_t N>
......
......@@ -26,6 +26,7 @@
#include <cstring>
#include <iostream>
#include <random>
#include <CUnit/CUnit.h>
......@@ -113,13 +114,12 @@ void test_util_inp_strlower(void) {
}
void test_util_to_base64(void) {
std::string x = "AAA--B_";
util::to_base64(x);
CU_ASSERT("AAA++B/=" == x);
BlockAllocator balloc(4096, 4096);
x = "AAA--B_B";
util::to_base64(x);
CU_ASSERT("AAA++B/B" == x);
CU_ASSERT("AAA++B/=" ==
util::to_base64(balloc, StringRef::from_lit("AAA--B_")));
CU_ASSERT("AAA++B/B" ==
util::to_base64(balloc, StringRef::from_lit("AAA--B_B")));
}
void test_util_to_token68(void) {
......@@ -133,10 +133,15 @@ void test_util_to_token68(void) {
}
void test_util_percent_encode_token(void) {
CU_ASSERT("h2" == util::percent_encode_token("h2"));
CU_ASSERT("h3~" == util::percent_encode_token("h3~"));
CU_ASSERT("100%25" == util::percent_encode_token("100%"));
CU_ASSERT("http%202" == util::percent_encode_token("http 2"));
BlockAllocator balloc(4096, 4096);
CU_ASSERT("h2" ==
util::percent_encode_token(balloc, StringRef::from_lit("h2")));
CU_ASSERT("h3~" ==
util::percent_encode_token(balloc, StringRef::from_lit("h3~")));
CU_ASSERT("100%25" ==
util::percent_encode_token(balloc, StringRef::from_lit("100%")));
CU_ASSERT("http%202" ==
util::percent_encode_token(balloc, StringRef::from_lit("http 2")));
}
void test_util_percent_encode_path(void) {
......@@ -169,9 +174,12 @@ void test_util_percent_decode(void) {
}
void test_util_quote_string(void) {
CU_ASSERT("alpha" == util::quote_string("alpha"));
CU_ASSERT("" == util::quote_string(""));
CU_ASSERT("\\\"alpha\\\"" == util::quote_string("\"alpha\""));
BlockAllocator balloc(4096, 4096);
CU_ASSERT("alpha" ==
util::quote_string(balloc, StringRef::from_lit("alpha")));
CU_ASSERT("" == util::quote_string(balloc, StringRef::from_lit("")));
CU_ASSERT("\\\"alpha\\\"" ==
util::quote_string(balloc, StringRef::from_lit("\"alpha\"")));
}
void test_util_utox(void) {
......@@ -494,12 +502,15 @@ void test_util_parse_config_str_list(void) {
}
void test_util_make_http_hostport(void) {
CU_ASSERT("localhost" ==
util::make_http_hostport(StringRef::from_lit("localhost"), 80));
BlockAllocator balloc(4096, 4096);
CU_ASSERT("localhost" == util::make_http_hostport(
balloc, StringRef::from_lit("localhost"), 80));
CU_ASSERT("[::1]" ==
util::make_http_hostport(StringRef::from_lit("::1"), 443));
CU_ASSERT("localhost:3000" ==
util::make_http_hostport(StringRef::from_lit("localhost"), 3000));
util::make_http_hostport(balloc, StringRef::from_lit("::1"), 443));
CU_ASSERT(
"localhost:3000" ==
util::make_http_hostport(balloc, StringRef::from_lit("localhost"), 3000));
}
void test_util_make_hostport(void) {
......@@ -507,6 +518,12 @@ void test_util_make_hostport(void) {
util::make_hostport(StringRef::from_lit("localhost"), 80));
CU_ASSERT("[::1]:443" ==
util::make_hostport(StringRef::from_lit("::1"), 443));
BlockAllocator balloc(4096, 4096);
CU_ASSERT("localhost:80" ==
util::make_hostport(balloc, StringRef::from_lit("localhost"), 80));
CU_ASSERT("[::1]:443" ==
util::make_hostport(balloc, StringRef::from_lit("::1"), 443));
}
void test_util_strifind(void) {
......@@ -528,4 +545,27 @@ void test_util_strifind(void) {
StringRef::from_lit("http1")));
}
void test_util_random_alpha_digit(void) {
std::random_device rd;
std::mt19937 gen(rd());
std::array<uint8_t, 19> data;
auto p = util::random_alpha_digit(std::begin(data), std::end(data), gen);
CU_ASSERT(std::end(data) == p);
for (auto b : data) {
CU_ASSERT(('A' <= b && b <= 'Z') || ('a' <= b && b <= 'z') ||
('0' <= b && b <= '9'));
}
}
void test_util_format_hex(void) {
BlockAllocator balloc(4096, 4096);
CU_ASSERT("0ff0" ==
util::format_hex(balloc, StringRef::from_lit("\x0f\xf0")));
CU_ASSERT("" == util::format_hex(balloc, StringRef::from_lit("")));
}
} // namespace shrpx
......@@ -62,6 +62,8 @@ void test_util_parse_config_str_list(void);
void test_util_make_http_hostport(void);
void test_util_make_hostport(void);
void test_util_strifind(void);
void test_util_random_alpha_digit(void);
void test_util_format_hex(void);
} // namespace shrpx
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment