Commit 35594e09 authored by Tatsuhiro Tsujikawa's avatar Tatsuhiro Tsujikawa

Merge branch 'nghttpx-more-block-allocator'

parents 68a6d8c5 96ff3be5
...@@ -29,6 +29,9 @@ ...@@ -29,6 +29,9 @@
#include <string> #include <string>
#include "template.h"
#include "allocator.h"
namespace nghttp2 { namespace nghttp2 {
namespace base64 { namespace base64 {
...@@ -87,7 +90,8 @@ InputIt next_decode_input(InputIt first, InputIt last, const int *tbl) { ...@@ -87,7 +90,8 @@ InputIt next_decode_input(InputIt first, InputIt last, const int *tbl) {
return first; return first;
} }
template <typename InputIt> std::string decode(InputIt first, InputIt last) { template <typename InputIt, typename OutputIt>
OutputIt decode(InputIt first, InputIt last, OutputIt d_first) {
static constexpr int INDEX_TABLE[] = { static constexpr int INDEX_TABLE[] = {
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
...@@ -104,37 +108,29 @@ template <typename InputIt> std::string decode(InputIt first, InputIt last) { ...@@ -104,37 +108,29 @@ template <typename InputIt> std::string decode(InputIt first, InputIt last) {
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1}; -1, -1, -1, -1};
auto len = last - first; assert(std::distance(first, last) % 4 == 0);
if (len % 4 != 0) { auto p = d_first;
return "";
}
std::string res;
res.resize(len / 4 * 3);
auto p = std::begin(res);
for (; first != last;) { for (; first != last;) {
uint32_t n = 0; uint32_t n = 0;
for (int i = 1; i <= 4; ++i, ++first) { for (int i = 1; i <= 4; ++i, ++first) {
auto idx = INDEX_TABLE[static_cast<size_t>(*first)]; auto idx = INDEX_TABLE[static_cast<size_t>(*first)];
if (idx == -1) { if (idx == -1) {
if (i <= 2) { if (i <= 2) {
return ""; return d_first;
} }
if (i == 3) { if (i == 3) {
if (*first == '=' && *(first + 1) == '=' && first + 2 == last) { if (*first == '=' && *(first + 1) == '=' && first + 2 == last) {
*p++ = n >> 16; *p++ = n >> 16;
res.resize(p - std::begin(res)); return p;
return res;
} }
return ""; return d_first;
} }
if (*first == '=' && first + 1 == last) { if (*first == '=' && first + 1 == last) {
*p++ = n >> 16; *p++ = n >> 16;
*p++ = n >> 8 & 0xffu; *p++ = n >> 8 & 0xffu;
res.resize(p - std::begin(res)); return p;
return res;
} }
return ""; return d_first;
} }
n += idx << (24 - i * 6); n += idx << (24 - i * 6);
...@@ -145,9 +141,37 @@ template <typename InputIt> std::string decode(InputIt first, InputIt last) { ...@@ -145,9 +141,37 @@ template <typename InputIt> std::string decode(InputIt first, InputIt last) {
*p++ = n & 0xffu; *p++ = n & 0xffu;
} }
return p;
}
template <typename InputIt> std::string decode(InputIt first, InputIt last) {
auto len = std::distance(first, last);
if (len % 4 != 0) {
return "";
}
std::string res;
res.resize(len / 4 * 3);
res.erase(decode(first, last, std::begin(res)), std::end(res));
return res; return res;
} }
template <typename InputIt>
StringRef decode(BlockAllocator &balloc, InputIt first, InputIt last) {
auto len = std::distance(first, last);
if (len % 4 != 0) {
return StringRef::from_lit("");
}
auto iov = make_byte_ref(balloc, len / 4 * 3 + 1);
auto p = iov.base;
p = decode(first, last, p);
*p = '\0';
return StringRef{iov.base, p};
}
} // namespace base64 } // namespace base64
} // namespace nghttp2 } // namespace nghttp2
......
...@@ -59,31 +59,40 @@ void test_base64_encode(void) { ...@@ -59,31 +59,40 @@ void test_base64_encode(void) {
} }
void test_base64_decode(void) { void test_base64_decode(void) {
BlockAllocator balloc(4096, 4096);
{ {
std::string in = "/w=="; std::string in = "/w==";
auto out = base64::decode(std::begin(in), std::end(in)); auto out = base64::decode(std::begin(in), std::end(in));
CU_ASSERT("\xff" == out); CU_ASSERT("\xff" == out);
CU_ASSERT("\xff" == base64::decode(balloc, std::begin(in), std::end(in)));
} }
{ {
std::string in = "//4="; std::string in = "//4=";
auto out = base64::decode(std::begin(in), std::end(in)); auto out = base64::decode(std::begin(in), std::end(in));
CU_ASSERT("\xff\xfe" == out); CU_ASSERT("\xff\xfe" == out);
CU_ASSERT("\xff\xfe" ==
base64::decode(balloc, std::begin(in), std::end(in)));
} }
{ {
std::string in = "//79"; std::string in = "//79";
auto out = base64::decode(std::begin(in), std::end(in)); auto out = base64::decode(std::begin(in), std::end(in));
CU_ASSERT("\xff\xfe\xfd" == out); CU_ASSERT("\xff\xfe\xfd" == out);
CU_ASSERT("\xff\xfe\xfd" ==
base64::decode(balloc, std::begin(in), std::end(in)));
} }
{ {
std::string in = "//79/A=="; std::string in = "//79/A==";
auto out = base64::decode(std::begin(in), std::end(in)); auto out = base64::decode(std::begin(in), std::end(in));
CU_ASSERT("\xff\xfe\xfd\xfc" == out); CU_ASSERT("\xff\xfe\xfd\xfc" == out);
CU_ASSERT("\xff\xfe\xfd\xfc" ==
base64::decode(balloc, std::begin(in), std::end(in)));
} }
{ {
// we check the number of valid input must be multiples of 4 // we check the number of valid input must be multiples of 4
std::string in = "//79="; std::string in = "//79=";
auto out = base64::decode(std::begin(in), std::end(in)); auto out = base64::decode(std::begin(in), std::end(in));
CU_ASSERT("" == out); CU_ASSERT("" == out);
CU_ASSERT("" == base64::decode(balloc, std::begin(in), std::end(in)));
} }
{ {
// ending invalid character at the boundary of multiples of 4 is // ending invalid character at the boundary of multiples of 4 is
...@@ -91,18 +100,21 @@ void test_base64_decode(void) { ...@@ -91,18 +100,21 @@ void test_base64_decode(void) {
std::string in = "bmdodHRw\n"; std::string in = "bmdodHRw\n";
auto out = base64::decode(std::begin(in), std::end(in)); auto out = base64::decode(std::begin(in), std::end(in));
CU_ASSERT("" == out); CU_ASSERT("" == out);
CU_ASSERT("" == base64::decode(balloc, std::begin(in), std::end(in)));
} }
{ {
// after seeing '=', subsequent input must be also '='. // after seeing '=', subsequent input must be also '='.
std::string in = "//79/A=A"; std::string in = "//79/A=A";
auto out = base64::decode(std::begin(in), std::end(in)); auto out = base64::decode(std::begin(in), std::end(in));
CU_ASSERT("" == out); CU_ASSERT("" == out);
CU_ASSERT("" == base64::decode(balloc, std::begin(in), std::end(in)));
} }
{ {
// additional '=' at the end is bad // additional '=' at the end is bad
std::string in = "//79/A======"; std::string in = "//79/A======";
auto out = base64::decode(std::begin(in), std::end(in)); auto out = base64::decode(std::begin(in), std::end(in));
CU_ASSERT("" == out); CU_ASSERT("" == out);
CU_ASSERT("" == base64::decode(balloc, std::begin(in), std::end(in)));
} }
} }
......
...@@ -176,6 +176,9 @@ int main(int argc, char *argv[]) { ...@@ -176,6 +176,9 @@ int main(int argc, char *argv[]) {
!CU_add_test(pSuite, "util_make_hostport", !CU_add_test(pSuite, "util_make_hostport",
shrpx::test_util_make_hostport) || shrpx::test_util_make_hostport) ||
!CU_add_test(pSuite, "util_strifind", shrpx::test_util_strifind) || !CU_add_test(pSuite, "util_strifind", shrpx::test_util_strifind) ||
!CU_add_test(pSuite, "util_random_alpha_digit",
shrpx::test_util_random_alpha_digit) ||
!CU_add_test(pSuite, "util_format_hex", shrpx::test_util_format_hex) ||
!CU_add_test(pSuite, "gzip_inflate", test_nghttp2_gzip_inflate) || !CU_add_test(pSuite, "gzip_inflate", test_nghttp2_gzip_inflate) ||
!CU_add_test(pSuite, "buffer_write", nghttp2::test_buffer_write) || !CU_add_test(pSuite, "buffer_write", nghttp2::test_buffer_write) ||
!CU_add_test(pSuite, "pool_recycle", nghttp2::test_pool_recycle) || !CU_add_test(pSuite, "pool_recycle", nghttp2::test_pool_recycle) ||
......
...@@ -86,6 +86,7 @@ ...@@ -86,6 +86,7 @@
#include "app_helper.h" #include "app_helper.h"
#include "ssl.h" #include "ssl.h"
#include "template.h" #include "template.h"
#include "allocator.h"
extern char **environ; extern char **environ;
...@@ -151,7 +152,7 @@ StartupConfig suconfig; ...@@ -151,7 +152,7 @@ StartupConfig suconfig;
struct InheritedAddr { struct InheritedAddr {
// IP address if TCP socket. Otherwise, UNIX domain socket path. // IP address if TCP socket. Otherwise, UNIX domain socket path.
ImmutableString host; StringRef host;
uint16_t port; uint16_t port;
// true if UNIX domain socket path // true if UNIX domain socket path
bool host_unix; bool host_unix;
...@@ -574,7 +575,7 @@ int create_unix_domain_server_socket(UpstreamAddr &faddr, ...@@ -574,7 +575,7 @@ int create_unix_domain_server_socket(UpstreamAddr &faddr,
<< (faddr.tls ? ", tls" : ""); << (faddr.tls ? ", tls" : "");
(*found).used = true; (*found).used = true;
faddr.fd = (*found).fd; faddr.fd = (*found).fd;
faddr.hostport = ImmutableString::from_lit("localhost"); faddr.hostport = StringRef::from_lit("localhost");
return 0; return 0;
} }
...@@ -639,7 +640,7 @@ int create_unix_domain_server_socket(UpstreamAddr &faddr, ...@@ -639,7 +640,7 @@ int create_unix_domain_server_socket(UpstreamAddr &faddr,
<< (faddr.tls ? ", tls" : ""); << (faddr.tls ? ", tls" : "");
faddr.fd = fd; faddr.fd = fd;
faddr.hostport = ImmutableString::from_lit("localhost"); faddr.hostport = StringRef::from_lit("localhost");
return 0; return 0;
} }
...@@ -791,8 +792,8 @@ int create_tcp_server_socket(UpstreamAddr &faddr, ...@@ -791,8 +792,8 @@ int create_tcp_server_socket(UpstreamAddr &faddr,
} }
faddr.fd = fd; faddr.fd = fd;
faddr.hostport = ImmutableString{ faddr.hostport = util::make_http_hostport(mod_config()->balloc,
util::make_http_hostport(StringRef{host.data()}, faddr.port)}; StringRef{host.data()}, faddr.port);
LOG(NOTICE) << "Listening on " << faddr.hostport LOG(NOTICE) << "Listening on " << faddr.hostport
<< (faddr.tls ? ", tls" : ""); << (faddr.tls ? ", tls" : "");
...@@ -806,7 +807,7 @@ namespace { ...@@ -806,7 +807,7 @@ namespace {
// function is intended to be used when reloading configuration, and // function is intended to be used when reloading configuration, and
// |config| is usually a current configuration. // |config| is usually a current configuration.
std::vector<InheritedAddr> std::vector<InheritedAddr>
get_inherited_addr_from_config(const Config *config) { get_inherited_addr_from_config(BlockAllocator &balloc, Config *config) {
int rv; int rv;
auto &listenerconf = config->conn.listener; auto &listenerconf = config->conn.listener;
...@@ -856,7 +857,7 @@ get_inherited_addr_from_config(const Config *config) { ...@@ -856,7 +857,7 @@ get_inherited_addr_from_config(const Config *config) {
continue; continue;
} }
iaddr.host = ImmutableString{host.data()}; iaddr.host = make_string_ref(balloc, StringRef{host.data()});
} }
return iaddrs; return iaddrs;
...@@ -867,7 +868,7 @@ namespace { ...@@ -867,7 +868,7 @@ namespace {
// Returns array of InheritedAddr constructed from environment // Returns array of InheritedAddr constructed from environment
// variables. This function handles the old environment variable // variables. This function handles the old environment variable
// names used in 1.7.0 or earlier. // names used in 1.7.0 or earlier.
std::vector<InheritedAddr> get_inherited_addr_from_env() { std::vector<InheritedAddr> get_inherited_addr_from_env(Config *config) {
int rv; int rv;
std::vector<InheritedAddr> iaddrs; std::vector<InheritedAddr> iaddrs;
...@@ -888,15 +889,14 @@ std::vector<InheritedAddr> get_inherited_addr_from_env() { ...@@ -888,15 +889,14 @@ std::vector<InheritedAddr> get_inherited_addr_from_env() {
} }
} }
} else { } else {
auto pathenv = getenv(ENV_UNIX_PATH.c_str()); // The return value of getenv may be allocated statically.
auto fdenv = getenv(ENV_UNIX_FD.c_str()); if (getenv(ENV_UNIX_PATH.c_str()) && getenv(ENV_UNIX_FD.c_str())) {
if (pathenv && fdenv) {
auto name = ENV_ACCEPT_PREFIX.str(); auto name = ENV_ACCEPT_PREFIX.str();
name += '1'; name += '1';
std::string value = "unix,"; std::string value = "unix,";
value += fdenv; value += getenv(ENV_UNIX_FD.c_str());
value += ','; value += ',';
value += pathenv; value += getenv(ENV_UNIX_PATH.c_str());
setenv(name.c_str(), value.c_str(), 0); setenv(name.c_str(), value.c_str(), 0);
} }
} }
...@@ -948,7 +948,7 @@ std::vector<InheritedAddr> get_inherited_addr_from_env() { ...@@ -948,7 +948,7 @@ std::vector<InheritedAddr> get_inherited_addr_from_env() {
} }
InheritedAddr addr{}; InheritedAddr addr{};
addr.host = ImmutableString{path}; addr.host = make_string_ref(config->balloc, StringRef{path});
addr.host_unix = true; addr.host_unix = true;
addr.fd = static_cast<int>(fd); addr.fd = static_cast<int>(fd);
iaddrs.push_back(std::move(addr)); iaddrs.push_back(std::move(addr));
...@@ -1002,7 +1002,7 @@ std::vector<InheritedAddr> get_inherited_addr_from_env() { ...@@ -1002,7 +1002,7 @@ std::vector<InheritedAddr> get_inherited_addr_from_env() {
} }
InheritedAddr addr{}; InheritedAddr addr{};
addr.host = ImmutableString{host.data()}; addr.host = make_string_ref(config->balloc, StringRef{host.data()});
addr.port = static_cast<uint16_t>(port); addr.port = static_cast<uint16_t>(port);
addr.fd = static_cast<int>(fd); addr.fd = static_cast<int>(fd);
iaddrs.push_back(std::move(addr)); iaddrs.push_back(std::move(addr));
...@@ -1209,7 +1209,7 @@ int event_loop() { ...@@ -1209,7 +1209,7 @@ int event_loop() {
redirect_stderr_to_errorlog(); redirect_stderr_to_errorlog();
} }
auto iaddrs = get_inherited_addr_from_env(); auto iaddrs = get_inherited_addr_from_env(mod_config());
if (create_acceptor_socket(mod_config(), iaddrs) != 0) { if (create_acceptor_socket(mod_config(), iaddrs) != 0) {
return -1; return -1;
...@@ -1274,7 +1274,7 @@ constexpr auto DEFAULT_ACCESSLOG_FORMAT = StringRef::from_lit( ...@@ -1274,7 +1274,7 @@ constexpr auto DEFAULT_ACCESSLOG_FORMAT = StringRef::from_lit(
namespace { namespace {
void fill_default_config(Config *config) { void fill_default_config(Config *config) {
config->num_worker = 1; config->num_worker = 1;
config->conf_path = ImmutableString::from_lit("/etc/nghttpx/nghttpx.conf"); config->conf_path = StringRef::from_lit("/etc/nghttpx/nghttpx.conf");
config->pid = getpid(); config->pid = getpid();
if (ev_supported_backends() & ~ev_recommended_backends() & EVBACKEND_KQUEUE) { if (ev_supported_backends() & ~ev_recommended_backends() & EVBACKEND_KQUEUE) {
...@@ -1306,7 +1306,7 @@ void fill_default_config(Config *config) { ...@@ -1306,7 +1306,7 @@ void fill_default_config(Config *config) {
// ocsp update interval = 14400 secs = 4 hours, borrowed from h2o // ocsp update interval = 14400 secs = 4 hours, borrowed from h2o
ocspconf.update_interval = 4_h; ocspconf.update_interval = 4_h;
ocspconf.fetch_ocsp_response_file = ocspconf.fetch_ocsp_response_file =
ImmutableString::from_lit(PKGDATADIR "/fetch-ocsp-response"); StringRef::from_lit(PKGDATADIR "/fetch-ocsp-response");
} }
{ {
...@@ -1319,7 +1319,7 @@ void fill_default_config(Config *config) { ...@@ -1319,7 +1319,7 @@ void fill_default_config(Config *config) {
auto &httpconf = config->http; auto &httpconf = config->http;
httpconf.server_name = httpconf.server_name =
ImmutableString::from_lit("nghttpx nghttp2/" NGHTTP2_VERSION); StringRef::from_lit("nghttpx nghttp2/" NGHTTP2_VERSION);
httpconf.no_host_rewrite = true; httpconf.no_host_rewrite = true;
httpconf.request_header_field_buffer = 64_k; httpconf.request_header_field_buffer = 64_k;
httpconf.max_request_header_fields = 100; httpconf.max_request_header_fields = 100;
...@@ -1387,10 +1387,11 @@ void fill_default_config(Config *config) { ...@@ -1387,10 +1387,11 @@ void fill_default_config(Config *config) {
auto &loggingconf = config->logging; auto &loggingconf = config->logging;
{ {
auto &accessconf = loggingconf.access; auto &accessconf = loggingconf.access;
accessconf.format = parse_log_format(DEFAULT_ACCESSLOG_FORMAT); accessconf.format =
parse_log_format(config->balloc, DEFAULT_ACCESSLOG_FORMAT);
auto &errorconf = loggingconf.error; auto &errorconf = loggingconf.error;
errorconf.file = ImmutableString::from_lit("/dev/stderr"); errorconf.file = StringRef::from_lit("/dev/stderr");
} }
loggingconf.syslog_facility = LOG_DAEMON; loggingconf.syslog_facility = LOG_DAEMON;
...@@ -2446,11 +2447,10 @@ int process_options(Config *config, ...@@ -2446,11 +2447,10 @@ int process_options(Config *config,
auto &tlsconf = config->tls; auto &tlsconf = config->tls;
if (tlsconf.npn_list.empty()) { if (tlsconf.npn_list.empty()) {
tlsconf.npn_list = util::parse_config_str_list(DEFAULT_NPN_LIST); tlsconf.npn_list = util::split_str(DEFAULT_NPN_LIST, ',');
} }
if (tlsconf.tls_proto_list.empty()) { if (tlsconf.tls_proto_list.empty()) {
tlsconf.tls_proto_list = tlsconf.tls_proto_list = util::split_str(DEFAULT_TLS_PROTO_LIST, ',');
util::parse_config_str_list(DEFAULT_TLS_PROTO_LIST);
} }
tlsconf.tls_proto_mask = ssl::create_tls_proto_mask(tlsconf.tls_proto_list); tlsconf.tls_proto_mask = ssl::create_tls_proto_mask(tlsconf.tls_proto_list);
...@@ -2466,7 +2466,7 @@ int process_options(Config *config, ...@@ -2466,7 +2466,7 @@ int process_options(Config *config,
if (listenerconf.addrs.empty()) { if (listenerconf.addrs.empty()) {
UpstreamAddr addr{}; UpstreamAddr addr{};
addr.host = ImmutableString::from_lit("*"); addr.host = StringRef::from_lit("*");
addr.port = 3000; addr.port = 3000;
addr.tls = true; addr.tls = true;
addr.family = AF_INET; addr.family = AF_INET;
...@@ -2567,10 +2567,14 @@ int process_options(Config *config, ...@@ -2567,10 +2567,14 @@ int process_options(Config *config,
if (fwdconf.by_node_type == FORWARDED_NODE_OBFUSCATED && if (fwdconf.by_node_type == FORWARDED_NODE_OBFUSCATED &&
fwdconf.by_obfuscated.empty()) { fwdconf.by_obfuscated.empty()) {
// 2 for '_' and terminal NULL
auto iov = make_byte_ref(config->balloc, SHRPX_OBFUSCATED_NODE_LENGTH + 2);
auto p = iov.base;
*p++ = '_';
std::mt19937 gen(rd()); std::mt19937 gen(rd());
auto &dst = fwdconf.by_obfuscated; p = util::random_alpha_digit(p, p + SHRPX_OBFUSCATED_NODE_LENGTH, gen);
dst = "_"; *p = '\0';
dst += util::random_alpha_digit(gen, SHRPX_OBFUSCATED_NODE_LENGTH); fwdconf.by_obfuscated = StringRef{iov.base, p};
} }
if (config->http2.upstream.debug.frame_debug) { if (config->http2.upstream.debug.frame_debug) {
...@@ -2616,12 +2620,13 @@ void reload_config(WorkerProcess *wp) { ...@@ -2616,12 +2620,13 @@ void reload_config(WorkerProcess *wp) {
LOG(NOTICE) << "Reloading configuration"; LOG(NOTICE) << "Reloading configuration";
auto cur_config = get_config(); auto cur_config = mod_config();
auto new_config = make_unique<Config>(); auto new_config = make_unique<Config>();
fill_default_config(new_config.get()); fill_default_config(new_config.get());
new_config->conf_path = cur_config->conf_path; new_config->conf_path =
make_string_ref(new_config->balloc, cur_config->conf_path);
// daemon option is ignored here. // daemon option is ignored here.
new_config->daemon = cur_config->daemon; new_config->daemon = cur_config->daemon;
// loop is reused, and ev_loop_flags gets ignored // loop is reused, and ev_loop_flags gets ignored
...@@ -2633,7 +2638,7 @@ void reload_config(WorkerProcess *wp) { ...@@ -2633,7 +2638,7 @@ void reload_config(WorkerProcess *wp) {
return; return;
} }
auto iaddrs = get_inherited_addr_from_config(cur_config); auto iaddrs = get_inherited_addr_from_config(new_config->balloc, cur_config);
if (create_acceptor_socket(new_config.get(), iaddrs) != 0) { if (create_acceptor_socket(new_config.get(), iaddrs) != 0) {
close_not_inherited_fd(new_config.get(), iaddrs); close_not_inherited_fd(new_config.get(), iaddrs);
...@@ -3027,7 +3032,8 @@ int main(int argc, char **argv) { ...@@ -3027,7 +3032,8 @@ int main(int argc, char **argv) {
break; break;
case 12: case 12:
// --conf // --conf
mod_config()->conf_path = ImmutableString{optarg}; mod_config()->conf_path =
make_string_ref(mod_config()->balloc, StringRef{optarg});
break; break;
case 14: case 14:
// --syslog-facility // --syslog-facility
......
This diff is collapsed.
...@@ -37,6 +37,7 @@ ...@@ -37,6 +37,7 @@
#include "shrpx_connection.h" #include "shrpx_connection.h"
#include "buffer.h" #include "buffer.h"
#include "memchunk.h" #include "memchunk.h"
#include "allocator.h"
using namespace nghttp2; using namespace nghttp2;
...@@ -54,8 +55,8 @@ struct DownstreamAddr; ...@@ -54,8 +55,8 @@ struct DownstreamAddr;
class ClientHandler { class ClientHandler {
public: public:
ClientHandler(Worker *worker, int fd, SSL *ssl, const char *ipaddr, ClientHandler(Worker *worker, int fd, SSL *ssl, const StringRef &ipaddr,
const char *port, int family, const UpstreamAddr *faddr); const StringRef &port, int family, const UpstreamAddr *faddr);
~ClientHandler(); ~ClientHandler();
int noop(); int noop();
...@@ -90,8 +91,7 @@ public: ...@@ -90,8 +91,7 @@ public:
void reset_upstream_write_timeout(ev_tstamp t); void reset_upstream_write_timeout(ev_tstamp t);
int validate_next_proto(); int validate_next_proto();
const std::string &get_ipaddr() const; const StringRef &get_ipaddr() const;
const std::string &get_port() const;
bool get_should_close_after_write() const; bool get_should_close_after_write() const;
void set_should_close_after_write(bool f); void set_should_close_after_write(bool f);
Upstream *get_upstream(); Upstream *get_upstream();
...@@ -162,21 +162,27 @@ public: ...@@ -162,21 +162,27 @@ public:
// Returns TLS SNI extension value client sent in this connection. // Returns TLS SNI extension value client sent in this connection.
StringRef get_tls_sni() const; StringRef get_tls_sni() const;
BlockAllocator &get_block_allocator();
private: private:
// Allocator to allocate memory for connection-wide objects. Make
// sure that the allocations must be bounded, and not proportional
// to the number of requests.
BlockAllocator balloc_;
Connection conn_; Connection conn_;
ev_timer reneg_shutdown_timer_; ev_timer reneg_shutdown_timer_;
std::unique_ptr<Upstream> upstream_; std::unique_ptr<Upstream> upstream_;
// IP address of client. If UNIX domain socket is used, this is // IP address of client. If UNIX domain socket is used, this is
// "localhost". // "localhost".
std::string ipaddr_; StringRef ipaddr_;
std::string port_; StringRef port_;
// The ALPN identifier negotiated for this connection. // The ALPN identifier negotiated for this connection.
std::string alpn_; StringRef alpn_;
// The client address used in "for" parameter of Forwarded header // The client address used in "for" parameter of Forwarded header
// field. // field.
std::string forwarded_for_; StringRef forwarded_for_;
// lowercased TLS SNI which client sent. // lowercased TLS SNI which client sent.
std::string sni_; StringRef sni_;
std::function<int(ClientHandler &)> read_, write_; std::function<int(ClientHandler &)> read_, write_;
std::function<int(ClientHandler &)> on_read_, on_write_; std::function<int(ClientHandler &)> on_read_, on_write_;
// Address of frontend listening socket // Address of frontend listening socket
......
This diff is collapsed.
This diff is collapsed.
...@@ -37,40 +37,45 @@ ...@@ -37,40 +37,45 @@
namespace shrpx { namespace shrpx {
void test_shrpx_config_parse_header(void) { void test_shrpx_config_parse_header(void) {
auto p = parse_header(StringRef::from_lit("a: b")); BlockAllocator balloc(4096, 4096);
auto p = parse_header(balloc, StringRef::from_lit("a: b"));
CU_ASSERT("a" == p.name); CU_ASSERT("a" == p.name);
CU_ASSERT("b" == p.value); CU_ASSERT("b" == p.value);
p = parse_header(StringRef::from_lit("a: b")); p = parse_header(balloc, StringRef::from_lit("a: b"));
CU_ASSERT("a" == p.name); CU_ASSERT("a" == p.name);
CU_ASSERT("b" == p.value); CU_ASSERT("b" == p.value);
p = parse_header(StringRef::from_lit(":a: b")); p = parse_header(balloc, StringRef::from_lit(":a: b"));
CU_ASSERT(p.name.empty()); CU_ASSERT(p.name.empty());
p = parse_header(StringRef::from_lit("a: :b")); p = parse_header(balloc, StringRef::from_lit("a: :b"));
CU_ASSERT("a" == p.name); CU_ASSERT("a" == p.name);
CU_ASSERT(":b" == p.value); CU_ASSERT(":b" == p.value);
p = parse_header(StringRef::from_lit(": b")); p = parse_header(balloc, StringRef::from_lit(": b"));
CU_ASSERT(p.name.empty()); CU_ASSERT(p.name.empty());
p = parse_header(StringRef::from_lit("alpha: bravo charlie")); p = parse_header(balloc, StringRef::from_lit("alpha: bravo charlie"));
CU_ASSERT("alpha" == p.name); CU_ASSERT("alpha" == p.name);
CU_ASSERT("bravo charlie" == p.value); CU_ASSERT("bravo charlie" == p.value);
p = parse_header(StringRef::from_lit("a,: b")); p = parse_header(balloc, StringRef::from_lit("a,: b"));
CU_ASSERT(p.name.empty()); CU_ASSERT(p.name.empty());
p = parse_header(StringRef::from_lit("a: b\x0a")); p = parse_header(balloc, StringRef::from_lit("a: b\x0a"));
CU_ASSERT(p.name.empty()); CU_ASSERT(p.name.empty());
} }
void test_shrpx_config_parse_log_format(void) { void test_shrpx_config_parse_log_format(void) {
auto res = parse_log_format(StringRef::from_lit( BlockAllocator balloc(4096, 4096);
R"($remote_addr - $remote_user [$time_local] )"
R"("$request" $status $body_bytes_sent )" auto res = parse_log_format(
R"("${http_referer}" $http_host "$http_user_agent")")); balloc, StringRef::from_lit(
R"($remote_addr - $remote_user [$time_local] )"
R"("$request" $status $body_bytes_sent )"
R"("${http_referer}" $http_host "$http_user_agent")"));
CU_ASSERT(16 == res.size()); CU_ASSERT(16 == res.size());
CU_ASSERT(SHRPX_LOGF_REMOTE_ADDR == res[0].type); CU_ASSERT(SHRPX_LOGF_REMOTE_ADDR == res[0].type);
...@@ -115,35 +120,35 @@ void test_shrpx_config_parse_log_format(void) { ...@@ -115,35 +120,35 @@ void test_shrpx_config_parse_log_format(void) {
CU_ASSERT(SHRPX_LOGF_LITERAL == res[15].type); CU_ASSERT(SHRPX_LOGF_LITERAL == res[15].type);
CU_ASSERT("\"" == res[15].value); CU_ASSERT("\"" == res[15].value);
res = parse_log_format(StringRef::from_lit("$")); res = parse_log_format(balloc, StringRef::from_lit("$"));
CU_ASSERT(1 == res.size()); CU_ASSERT(1 == res.size());
CU_ASSERT(SHRPX_LOGF_LITERAL == res[0].type); CU_ASSERT(SHRPX_LOGF_LITERAL == res[0].type);
CU_ASSERT("$" == res[0].value); CU_ASSERT("$" == res[0].value);
res = parse_log_format(StringRef::from_lit("${")); res = parse_log_format(balloc, StringRef::from_lit("${"));
CU_ASSERT(1 == res.size()); CU_ASSERT(1 == res.size());
CU_ASSERT(SHRPX_LOGF_LITERAL == res[0].type); CU_ASSERT(SHRPX_LOGF_LITERAL == res[0].type);
CU_ASSERT("${" == res[0].value); CU_ASSERT("${" == res[0].value);
res = parse_log_format(StringRef::from_lit("${a")); res = parse_log_format(balloc, StringRef::from_lit("${a"));
CU_ASSERT(1 == res.size()); CU_ASSERT(1 == res.size());
CU_ASSERT(SHRPX_LOGF_LITERAL == res[0].type); CU_ASSERT(SHRPX_LOGF_LITERAL == res[0].type);
CU_ASSERT("${a" == res[0].value); CU_ASSERT("${a" == res[0].value);
res = parse_log_format(StringRef::from_lit("${a ")); res = parse_log_format(balloc, StringRef::from_lit("${a "));
CU_ASSERT(1 == res.size()); CU_ASSERT(1 == res.size());
CU_ASSERT(SHRPX_LOGF_LITERAL == res[0].type); CU_ASSERT(SHRPX_LOGF_LITERAL == res[0].type);
CU_ASSERT("${a " == res[0].value); CU_ASSERT("${a " == res[0].value);
res = parse_log_format(StringRef::from_lit("$$remote_addr")); res = parse_log_format(balloc, StringRef::from_lit("$$remote_addr"));
CU_ASSERT(2 == res.size()); CU_ASSERT(2 == res.size());
...@@ -168,8 +173,8 @@ void test_shrpx_config_read_tls_ticket_key_file(void) { ...@@ -168,8 +173,8 @@ void test_shrpx_config_read_tls_ticket_key_file(void) {
close(fd1); close(fd1);
close(fd2); close(fd2);
auto ticket_keys = auto ticket_keys = read_tls_ticket_key_file(
read_tls_ticket_key_file({file1, file2}, EVP_aes_128_cbc(), EVP_sha256()); {StringRef{file1}, StringRef{file2}}, EVP_aes_128_cbc(), EVP_sha256());
unlink(file1); unlink(file1);
unlink(file2); unlink(file2);
CU_ASSERT(ticket_keys.get() != nullptr); CU_ASSERT(ticket_keys.get() != nullptr);
...@@ -211,8 +216,8 @@ void test_shrpx_config_read_tls_ticket_key_file_aes_256(void) { ...@@ -211,8 +216,8 @@ void test_shrpx_config_read_tls_ticket_key_file_aes_256(void) {
close(fd1); close(fd1);
close(fd2); close(fd2);
auto ticket_keys = auto ticket_keys = read_tls_ticket_key_file(
read_tls_ticket_key_file({file1, file2}, EVP_aes_256_cbc(), EVP_sha256()); {StringRef{file1}, StringRef{file2}}, EVP_aes_256_cbc(), EVP_sha256());
unlink(file1); unlink(file1);
unlink(file2); unlink(file2);
CU_ASSERT(ticket_keys.get() != nullptr); CU_ASSERT(ticket_keys.get() != nullptr);
......
...@@ -224,8 +224,8 @@ int ConnectionHandler::create_single_worker() { ...@@ -224,8 +224,8 @@ int ConnectionHandler::create_single_worker() {
#ifdef HAVE_NEVERBLEED #ifdef HAVE_NEVERBLEED
nb_.get(), nb_.get(),
#endif // HAVE_NEVERBLEED #endif // HAVE_NEVERBLEED
StringRef{tlsconf.cacert}, StringRef{memcachedconf.cert_file}, tlsconf.cacert, memcachedconf.cert_file, memcachedconf.private_key_file,
StringRef{memcachedconf.private_key_file}, nullptr); nullptr);
all_ssl_ctx_.push_back(session_cache_ssl_ctx); all_ssl_ctx_.push_back(session_cache_ssl_ctx);
} }
...@@ -280,8 +280,8 @@ int ConnectionHandler::create_worker_thread(size_t num) { ...@@ -280,8 +280,8 @@ int ConnectionHandler::create_worker_thread(size_t num) {
#ifdef HAVE_NEVERBLEED #ifdef HAVE_NEVERBLEED
nb_.get(), nb_.get(),
#endif // HAVE_NEVERBLEED #endif // HAVE_NEVERBLEED
StringRef{tlsconf.cacert}, StringRef{memcachedconf.cert_file}, tlsconf.cacert, memcachedconf.cert_file,
StringRef{memcachedconf.private_key_file}, nullptr); memcachedconf.private_key_file, nullptr);
all_ssl_ctx_.push_back(session_cache_ssl_ctx); all_ssl_ctx_.push_back(session_cache_ssl_ctx);
} }
auto worker = make_unique<Worker>( auto worker = make_unique<Worker>(
...@@ -835,8 +835,8 @@ SSL_CTX *ConnectionHandler::create_tls_ticket_key_memcached_ssl_ctx() { ...@@ -835,8 +835,8 @@ SSL_CTX *ConnectionHandler::create_tls_ticket_key_memcached_ssl_ctx() {
#ifdef HAVE_NEVERBLEED #ifdef HAVE_NEVERBLEED
nb_.get(), nb_.get(),
#endif // HAVE_NEVERBLEED #endif // HAVE_NEVERBLEED
StringRef{tlsconf.cacert}, StringRef{memcachedconf.cert_file}, tlsconf.cacert, memcachedconf.cert_file, memcachedconf.private_key_file,
StringRef{memcachedconf.private_key_file}, nullptr); nullptr);
all_ssl_ctx_.push_back(ssl_ctx); all_ssl_ctx_.push_back(ssl_ctx);
......
...@@ -46,12 +46,11 @@ StringRef create_error_html(BlockAllocator &balloc, unsigned int http_status) { ...@@ -46,12 +46,11 @@ StringRef create_error_html(BlockAllocator &balloc, unsigned int http_status) {
} }
auto status_string = http2::get_status_string(balloc, http_status); auto status_string = http2::get_status_string(balloc, http_status);
const auto &server_name = httpconf.server_name;
return concat_string_ref( return concat_string_ref(
balloc, StringRef::from_lit(R"(<!DOCTYPE html><html lang="en"><title>)"), balloc, StringRef::from_lit(R"(<!DOCTYPE html><html lang="en"><title>)"),
status_string, StringRef::from_lit("</title><body><h1>"), status_string, status_string, StringRef::from_lit("</title><body><h1>"), status_string,
StringRef::from_lit("</h1><footer>"), StringRef{server_name}, StringRef::from_lit("</h1><footer>"), httpconf.server_name,
StringRef::from_lit("</footer></body></html>")); StringRef::from_lit("</footer></body></html>"));
} }
......
...@@ -357,7 +357,7 @@ int Http2DownstreamConnection::push_request_headers() { ...@@ -357,7 +357,7 @@ int Http2DownstreamConnection::push_request_headers() {
if (xffconf.add) { if (xffconf.add) {
StringRef xff_value; StringRef xff_value;
auto addr = StringRef{upstream->get_client_handler()->get_ipaddr()}; const auto &addr = upstream->get_client_handler()->get_ipaddr();
if (xff) { if (xff) {
xff_value = concat_string_ref(balloc, xff->value, xff_value = concat_string_ref(balloc, xff->value,
StringRef::from_lit(", "), addr); StringRef::from_lit(", "), addr);
......
...@@ -108,15 +108,16 @@ int on_stream_close_callback(nghttp2_session *session, int32_t stream_id, ...@@ -108,15 +108,16 @@ int on_stream_close_callback(nghttp2_session *session, int32_t stream_id,
int Http2Upstream::upgrade_upstream(HttpsUpstream *http) { int Http2Upstream::upgrade_upstream(HttpsUpstream *http) {
int rv; int rv;
auto http2_settings = http->get_downstream()->get_http2_settings().str(); auto &balloc = http->get_downstream()->get_block_allocator();
util::to_base64(http2_settings);
auto settings_payload = auto http2_settings = http->get_downstream()->get_http2_settings();
base64::decode(std::begin(http2_settings), std::end(http2_settings)); http2_settings = util::to_base64(balloc, http2_settings);
auto settings_payload = base64::decode(balloc, std::begin(http2_settings),
std::end(http2_settings));
rv = nghttp2_session_upgrade2( rv = nghttp2_session_upgrade2(
session_, reinterpret_cast<const uint8_t *>(settings_payload.c_str()), session_, settings_payload.byte(), settings_payload.size(),
settings_payload.size(),
http->get_downstream()->request().method == HTTP_HEAD, nullptr); http->get_downstream()->request().method == HTTP_HEAD, nullptr);
if (rv != 0) { if (rv != 0) {
if (LOG_ENABLED(INFO)) { if (LOG_ENABLED(INFO)) {
...@@ -1429,8 +1430,8 @@ int Http2Upstream::send_reply(Downstream *downstream, const uint8_t *body, ...@@ -1429,8 +1430,8 @@ int Http2Upstream::send_reply(Downstream *downstream, const uint8_t *body,
} }
if (!resp.fs.header(http2::HD_SERVER)) { if (!resp.fs.header(http2::HD_SERVER)) {
nva.push_back(http2::make_nv_ls_nocopy( nva.push_back(
"server", StringRef{get_config()->http.server_name})); http2::make_nv_ls_nocopy("server", get_config()->http.server_name));
} }
for (auto &p : httpconf.add_response_headers) { for (auto &p : httpconf.add_response_headers) {
...@@ -1481,8 +1482,7 @@ int Http2Upstream::error_reply(Downstream *downstream, ...@@ -1481,8 +1482,7 @@ int Http2Upstream::error_reply(Downstream *downstream,
auto nva = std::array<nghttp2_nv, 5>{ auto nva = std::array<nghttp2_nv, 5>{
{http2::make_nv_ls_nocopy(":status", response_status), {http2::make_nv_ls_nocopy(":status", response_status),
http2::make_nv_ll("content-type", "text/html; charset=UTF-8"), http2::make_nv_ll("content-type", "text/html; charset=UTF-8"),
http2::make_nv_ls_nocopy("server", http2::make_nv_ls_nocopy("server", get_config()->http.server_name),
StringRef{get_config()->http.server_name}),
http2::make_nv_ls_nocopy("content-length", content_length), http2::make_nv_ls_nocopy("content-length", content_length),
http2::make_nv_ls_nocopy("date", date)}}; http2::make_nv_ls_nocopy("date", date)}};
...@@ -1629,8 +1629,7 @@ int Http2Upstream::on_downstream_header_complete(Downstream *downstream) { ...@@ -1629,8 +1629,7 @@ int Http2Upstream::on_downstream_header_complete(Downstream *downstream) {
http2::copy_headers_to_nva_nocopy(nva, resp.fs.headers()); http2::copy_headers_to_nva_nocopy(nva, resp.fs.headers());
if (!get_config()->http2_proxy && !httpconf.no_server_rewrite) { if (!get_config()->http2_proxy && !httpconf.no_server_rewrite) {
nva.push_back( nva.push_back(http2::make_nv_ls_nocopy("server", httpconf.server_name));
http2::make_nv_ls_nocopy("server", StringRef{httpconf.server_name}));
} else { } else {
auto server = resp.fs.header(http2::HD_SERVER); auto server = resp.fs.header(http2::HD_SERVER);
if (server) { if (server) {
......
...@@ -960,13 +960,14 @@ std::unique_ptr<Downstream> HttpsUpstream::pop_downstream() { ...@@ -960,13 +960,14 @@ std::unique_ptr<Downstream> HttpsUpstream::pop_downstream() {
} }
namespace { namespace {
void write_altsvc(DefaultMemchunks *buf, const AltSvc &altsvc) { void write_altsvc(DefaultMemchunks *buf, BlockAllocator &balloc,
buf->append(util::percent_encode_token(altsvc.protocol_id)); const AltSvc &altsvc) {
buf->append(util::percent_encode_token(balloc, altsvc.protocol_id));
buf->append("=\""); buf->append("=\"");
buf->append(util::quote_string(altsvc.host)); buf->append(util::quote_string(balloc, altsvc.host));
buf->append(":"); buf->append(':');
buf->append(altsvc.service); buf->append(altsvc.service);
buf->append("\""); buf->append('"');
} }
} // namespace } // namespace
...@@ -1073,10 +1074,10 @@ int HttpsUpstream::on_downstream_header_complete(Downstream *downstream) { ...@@ -1073,10 +1074,10 @@ int HttpsUpstream::on_downstream_header_complete(Downstream *downstream) {
buf->append("Alt-Svc: "); buf->append("Alt-Svc: ");
auto &altsvcs = httpconf.altsvcs; auto &altsvcs = httpconf.altsvcs;
write_altsvc(buf, altsvcs[0]); write_altsvc(buf, downstream->get_block_allocator(), altsvcs[0]);
for (size_t i = 1; i < altsvcs.size(); ++i) { for (size_t i = 1; i < altsvcs.size(); ++i) {
buf->append(", "); buf->append(", ");
write_altsvc(buf, altsvcs[i]); write_altsvc(buf, downstream->get_block_allocator(), altsvcs[i]);
} }
buf->append("\r\n"); buf->append("\r\n");
} }
......
...@@ -290,7 +290,7 @@ void upstream_accesslog(const std::vector<LogFragment> &lfv, ...@@ -290,7 +290,7 @@ void upstream_accesslog(const std::vector<LogFragment> &lfv,
break; break;
case SHRPX_LOGF_HTTP: case SHRPX_LOGF_HTTP:
if (req) { if (req) {
auto hd = req->fs.header(StringRef(lf.value)); auto hd = req->fs.header(lf.value);
if (hd) { if (hd) {
std::tie(p, avail) = copy((*hd).value, avail, p); std::tie(p, avail) = copy((*hd).value, avail, p);
break; break;
......
...@@ -137,10 +137,10 @@ enum LogFragmentType { ...@@ -137,10 +137,10 @@ enum LogFragmentType {
}; };
struct LogFragment { struct LogFragment {
LogFragment(LogFragmentType type, ImmutableString value = ImmutableString()) LogFragment(LogFragmentType type, StringRef value = StringRef::from_lit(""))
: type(type), value(std::move(value)) {} : type(type), value(std::move(value)) {}
LogFragmentType type; LogFragmentType type;
ImmutableString value; StringRef value;
}; };
struct LogSpec { struct LogSpec {
......
...@@ -100,7 +100,7 @@ MemcachedConnection::MemcachedConnection(const Address *addr, ...@@ -100,7 +100,7 @@ MemcachedConnection::MemcachedConnection(const Address *addr,
connectcb, readcb, timeoutcb, this, 0, 0., PROTO_MEMCACHED), connectcb, readcb, timeoutcb, this, 0, 0., PROTO_MEMCACHED),
do_read_(&MemcachedConnection::noop), do_read_(&MemcachedConnection::noop),
do_write_(&MemcachedConnection::noop), do_write_(&MemcachedConnection::noop),
sni_name_(sni_name.str()), sni_name_(sni_name),
connect_blocker_(gen, loop, [] {}, [] {}), connect_blocker_(gen, loop, [] {}, [] {}),
parse_state_{}, parse_state_{},
addr_(addr), addr_(addr),
...@@ -268,7 +268,7 @@ int MemcachedConnection::tls_handshake() { ...@@ -268,7 +268,7 @@ int MemcachedConnection::tls_handshake() {
auto &tlsconf = get_config()->tls; auto &tlsconf = get_config()->tls;
if (!tlsconf.insecure && if (!tlsconf.insecure &&
ssl::check_cert(conn_.tls.ssl, addr_, StringRef(sni_name_)) != 0) { ssl::check_cert(conn_.tls.ssl, addr_, sni_name_) != 0) {
connect_blocker_.on_failure(); connect_blocker_.on_failure();
return -1; return -1;
} }
......
...@@ -135,7 +135,7 @@ private: ...@@ -135,7 +135,7 @@ private:
std::deque<std::unique_ptr<MemcachedRequest>> sendq_; std::deque<std::unique_ptr<MemcachedRequest>> sendq_;
std::deque<MemcachedSendbuf> sendbufv_; std::deque<MemcachedSendbuf> sendbufv_;
std::function<int(MemcachedConnection &)> do_read_, do_write_; std::function<int(MemcachedConnection &)> do_read_, do_write_;
std::string sni_name_; StringRef sni_name_;
ssl::TLSSessionCache tls_session_cache_; ssl::TLSSessionCache tls_session_cache_;
ConnectBlocker connect_blocker_; ConnectBlocker connect_blocker_;
MemcachedParseState parse_state_; MemcachedParseState parse_state_;
......
...@@ -103,7 +103,7 @@ int verify_callback(int preverify_ok, X509_STORE_CTX *ctx) { ...@@ -103,7 +103,7 @@ int verify_callback(int preverify_ok, X509_STORE_CTX *ctx) {
} // namespace } // namespace
int set_alpn_prefs(std::vector<unsigned char> &out, int set_alpn_prefs(std::vector<unsigned char> &out,
const std::vector<std::string> &protos) { const std::vector<StringRef> &protos) {
size_t len = 0; size_t len = 0;
for (const auto &proto : protos) { for (const auto &proto : protos) {
...@@ -125,8 +125,7 @@ int set_alpn_prefs(std::vector<unsigned char> &out, ...@@ -125,8 +125,7 @@ int set_alpn_prefs(std::vector<unsigned char> &out,
for (const auto &proto : protos) { for (const auto &proto : protos) {
*ptr++ = proto.size(); *ptr++ = proto.size();
memcpy(ptr, proto.c_str(), proto.size()); ptr = std::copy(std::begin(proto), std::end(proto), ptr);
ptr += proto.size();
} }
return 0; return 0;
...@@ -243,6 +242,7 @@ int tls_session_new_cb(SSL *ssl, SSL_SESSION *session) { ...@@ -243,6 +242,7 @@ int tls_session_new_cb(SSL *ssl, SSL_SESSION *session) {
auto handler = static_cast<ClientHandler *>(conn->data); auto handler = static_cast<ClientHandler *>(conn->data);
auto worker = handler->get_worker(); auto worker = handler->get_worker();
auto dispatcher = worker->get_session_cache_memcached_dispatcher(); auto dispatcher = worker->get_session_cache_memcached_dispatcher();
auto &balloc = handler->get_block_allocator();
const unsigned char *id; const unsigned char *id;
unsigned int idlen; unsigned int idlen;
...@@ -256,7 +256,8 @@ int tls_session_new_cb(SSL *ssl, SSL_SESSION *session) { ...@@ -256,7 +256,8 @@ int tls_session_new_cb(SSL *ssl, SSL_SESSION *session) {
auto req = make_unique<MemcachedRequest>(); auto req = make_unique<MemcachedRequest>();
req->op = MEMCACHED_OP_ADD; req->op = MEMCACHED_OP_ADD;
req->key = MEMCACHED_SESSION_CACHE_KEY_PREFIX.str(); req->key = MEMCACHED_SESSION_CACHE_KEY_PREFIX.str();
req->key += util::format_hex(id, idlen); req->key +=
util::format_hex(balloc, StringRef{id, static_cast<size_t>(idlen)});
auto sessionlen = i2d_SSL_SESSION(session, nullptr); auto sessionlen = i2d_SSL_SESSION(session, nullptr);
req->value.resize(sessionlen); req->value.resize(sessionlen);
...@@ -295,6 +296,7 @@ SSL_SESSION *tls_session_get_cb(SSL *ssl, ...@@ -295,6 +296,7 @@ SSL_SESSION *tls_session_get_cb(SSL *ssl,
auto handler = static_cast<ClientHandler *>(conn->data); auto handler = static_cast<ClientHandler *>(conn->data);
auto worker = handler->get_worker(); auto worker = handler->get_worker();
auto dispatcher = worker->get_session_cache_memcached_dispatcher(); auto dispatcher = worker->get_session_cache_memcached_dispatcher();
auto &balloc = handler->get_block_allocator();
if (conn->tls.cached_session) { if (conn->tls.cached_session) {
if (LOG_ENABLED(INFO)) { if (LOG_ENABLED(INFO)) {
...@@ -318,7 +320,8 @@ SSL_SESSION *tls_session_get_cb(SSL *ssl, ...@@ -318,7 +320,8 @@ SSL_SESSION *tls_session_get_cb(SSL *ssl,
auto req = make_unique<MemcachedRequest>(); auto req = make_unique<MemcachedRequest>();
req->op = MEMCACHED_OP_GET; req->op = MEMCACHED_OP_GET;
req->key = MEMCACHED_SESSION_CACHE_KEY_PREFIX.str(); req->key = MEMCACHED_SESSION_CACHE_KEY_PREFIX.str();
req->key += util::format_hex(id, idlen); req->key +=
util::format_hex(balloc, StringRef{id, static_cast<size_t>(idlen)});
req->cb = [conn](MemcachedRequest *, MemcachedResult res) { req->cb = [conn](MemcachedRequest *, MemcachedResult res) {
if (LOG_ENABLED(INFO)) { if (LOG_ENABLED(INFO)) {
LOG(INFO) << "Memcached: returned status code " << res.status_code; LOG(INFO) << "Memcached: returned status code " << res.status_code;
...@@ -465,8 +468,7 @@ int alpn_select_proto_cb(SSL *ssl, const unsigned char **out, ...@@ -465,8 +468,7 @@ int alpn_select_proto_cb(SSL *ssl, const unsigned char **out,
auto proto_len = *p; auto proto_len = *p;
if (proto_id + proto_len <= end && if (proto_id + proto_len <= end &&
util::streq(StringRef{target_proto_id}, util::streq(target_proto_id, StringRef{proto_id, proto_len})) {
StringRef{proto_id, proto_len})) {
*out = reinterpret_cast<const unsigned char *>(proto_id); *out = reinterpret_cast<const unsigned char *>(proto_id);
*outlen = proto_len; *outlen = proto_len;
...@@ -493,7 +495,7 @@ constexpr TLSProtocol TLS_PROTOS[] = { ...@@ -493,7 +495,7 @@ constexpr TLSProtocol TLS_PROTOS[] = {
TLSProtocol{StringRef::from_lit("TLSv1.1"), SSL_OP_NO_TLSv1_1}, TLSProtocol{StringRef::from_lit("TLSv1.1"), SSL_OP_NO_TLSv1_1},
TLSProtocol{StringRef::from_lit("TLSv1.0"), SSL_OP_NO_TLSv1}}; TLSProtocol{StringRef::from_lit("TLSv1.0"), SSL_OP_NO_TLSv1}};
long int create_tls_proto_mask(const std::vector<std::string> &tls_proto_list) { long int create_tls_proto_mask(const std::vector<StringRef> &tls_proto_list) {
long int res = 0; long int res = 0;
for (auto &supported : TLS_PROTOS) { for (auto &supported : TLS_PROTOS) {
...@@ -829,16 +831,16 @@ SSL *create_ssl(SSL_CTX *ssl_ctx) { ...@@ -829,16 +831,16 @@ SSL *create_ssl(SSL_CTX *ssl_ctx) {
ClientHandler *accept_connection(Worker *worker, int fd, sockaddr *addr, ClientHandler *accept_connection(Worker *worker, int fd, sockaddr *addr,
int addrlen, const UpstreamAddr *faddr) { int addrlen, const UpstreamAddr *faddr) {
char host[NI_MAXHOST]; std::array<char, NI_MAXHOST> host;
char service[NI_MAXSERV]; std::array<char, NI_MAXSERV> service;
int rv; int rv;
if (addr->sa_family == AF_UNIX) { if (addr->sa_family == AF_UNIX) {
std::copy_n("localhost", sizeof("localhost"), host); std::copy_n("localhost", sizeof("localhost"), std::begin(host));
service[0] = '\0'; service[0] = '\0';
} else { } else {
rv = getnameinfo(addr, addrlen, host, sizeof(host), service, rv = getnameinfo(addr, addrlen, host.data(), host.size(), service.data(),
sizeof(service), NI_NUMERICHOST | NI_NUMERICSERV); service.size(), NI_NUMERICHOST | NI_NUMERICSERV);
if (rv != 0) { if (rv != 0) {
LOG(ERROR) << "getnameinfo() failed: " << gai_strerror(rv); LOG(ERROR) << "getnameinfo() failed: " << gai_strerror(rv);
...@@ -867,8 +869,8 @@ ClientHandler *accept_connection(Worker *worker, int fd, sockaddr *addr, ...@@ -867,8 +869,8 @@ ClientHandler *accept_connection(Worker *worker, int fd, sockaddr *addr,
} }
} }
return new ClientHandler(worker, fd, ssl, host, service, addr->sa_family, return new ClientHandler(worker, fd, ssl, StringRef{host.data()},
faddr); StringRef{service.data()}, addr->sa_family, faddr);
} }
bool tls_hostname_match(const StringRef &pattern, const StringRef &hostname) { bool tls_hostname_match(const StringRef &pattern, const StringRef &hostname) {
...@@ -1316,10 +1318,10 @@ int cert_lookup_tree_add_cert_from_x509(CertLookupTree *lt, size_t idx, ...@@ -1316,10 +1318,10 @@ int cert_lookup_tree_add_cert_from_x509(CertLookupTree *lt, size_t idx,
return 0; return 0;
} }
bool in_proto_list(const std::vector<std::string> &protos, bool in_proto_list(const std::vector<StringRef> &protos,
const StringRef &needle) { const StringRef &needle) {
for (auto &proto : protos) { for (auto &proto : protos) {
if (util::streq(StringRef{proto}, needle)) { if (util::streq(proto, needle)) {
return true; return true;
} }
} }
...@@ -1443,8 +1445,8 @@ SSL_CTX *setup_downstream_client_ssl_context( ...@@ -1443,8 +1445,8 @@ SSL_CTX *setup_downstream_client_ssl_context(
#ifdef HAVE_NEVERBLEED #ifdef HAVE_NEVERBLEED
nb, nb,
#endif // HAVE_NEVERBLEED #endif // HAVE_NEVERBLEED
StringRef{tlsconf.cacert}, StringRef{tlsconf.client.cert_file}, tlsconf.cacert, tlsconf.client.cert_file, tlsconf.client.private_key_file,
StringRef{tlsconf.client.private_key_file}, select_next_proto_cb); select_next_proto_cb);
} }
void setup_downstream_http2_alpn(SSL *ssl) { void setup_downstream_http2_alpn(SSL *ssl) {
......
...@@ -101,11 +101,6 @@ ClientHandler *accept_connection(Worker *worker, int fd, sockaddr *addr, ...@@ -101,11 +101,6 @@ ClientHandler *accept_connection(Worker *worker, int fd, sockaddr *addr,
int check_cert(SSL *ssl, const Address *addr, const StringRef &host); int check_cert(SSL *ssl, const Address *addr, const StringRef &host);
int check_cert(SSL *ssl, const DownstreamAddr *addr); int check_cert(SSL *ssl, const DownstreamAddr *addr);
// Retrieves DNS and IP address in subjectAltNames and commonName from
// the |cert|.
void get_altnames(X509 *cert, std::vector<std::string> &dns_names,
std::vector<std::string> &ip_addrs, std::string &common_name);
struct WildcardRevPrefix { struct WildcardRevPrefix {
WildcardRevPrefix(const StringRef &prefix, size_t idx) WildcardRevPrefix(const StringRef &prefix, size_t idx)
: prefix(std::begin(prefix), std::end(prefix)), idx(idx) {} : prefix(std::begin(prefix), std::end(prefix)), idx(idx) {}
...@@ -172,7 +167,7 @@ int cert_lookup_tree_add_cert_from_x509(CertLookupTree *lt, size_t idx, ...@@ -172,7 +167,7 @@ int cert_lookup_tree_add_cert_from_x509(CertLookupTree *lt, size_t idx,
// Returns true if |proto| is included in the // Returns true if |proto| is included in the
// protocol list |protos|. // protocol list |protos|.
bool in_proto_list(const std::vector<std::string> &protos, bool in_proto_list(const std::vector<StringRef> &protos,
const StringRef &proto); const StringRef &proto);
// Returns true if security requirement for HTTP/2 is fulfilled. // Returns true if security requirement for HTTP/2 is fulfilled.
...@@ -181,10 +176,10 @@ bool check_http2_requirement(SSL *ssl); ...@@ -181,10 +176,10 @@ bool check_http2_requirement(SSL *ssl);
// Returns SSL/TLS option mask to disable SSL/TLS protocol version not // Returns SSL/TLS option mask to disable SSL/TLS protocol version not
// included in |tls_proto_list|. The returned mask can be directly // included in |tls_proto_list|. The returned mask can be directly
// passed to SSL_CTX_set_options(). // passed to SSL_CTX_set_options().
long int create_tls_proto_mask(const std::vector<std::string> &tls_proto_list); long int create_tls_proto_mask(const std::vector<StringRef> &tls_proto_list);
int set_alpn_prefs(std::vector<unsigned char> &out, int set_alpn_prefs(std::vector<unsigned char> &out,
const std::vector<std::string> &protos); const std::vector<StringRef> &protos);
// Setups server side SSL_CTX. This function inspects get_config() // Setups server side SSL_CTX. This function inspects get_config()
// and if upstream_no_tls is true, returns nullptr. Otherwise // and if upstream_no_tls is true, returns nullptr. Otherwise
......
...@@ -182,7 +182,8 @@ void Worker::replace_downstream_config( ...@@ -182,7 +182,8 @@ void Worker::replace_downstream_config(
auto &dst = downstream_addr_groups_[i]; auto &dst = downstream_addr_groups_[i];
dst = std::make_shared<DownstreamAddrGroup>(); dst = std::make_shared<DownstreamAddrGroup>();
dst->pattern = src.pattern; dst->pattern =
ImmutableString{std::begin(src.pattern), std::end(src.pattern)};
auto shared_addr = std::make_shared<SharedDownstreamAddr>(); auto shared_addr = std::make_shared<SharedDownstreamAddr>();
...@@ -198,13 +199,14 @@ void Worker::replace_downstream_config( ...@@ -198,13 +199,14 @@ void Worker::replace_downstream_config(
auto &dst_addr = shared_addr->addrs[j]; auto &dst_addr = shared_addr->addrs[j];
dst_addr.addr = src_addr.addr; dst_addr.addr = src_addr.addr;
dst_addr.host = src_addr.host; dst_addr.host = make_string_ref(shared_addr->balloc, src_addr.host);
dst_addr.hostport = src_addr.hostport; dst_addr.hostport =
make_string_ref(shared_addr->balloc, src_addr.hostport);
dst_addr.port = src_addr.port; dst_addr.port = src_addr.port;
dst_addr.host_unix = src_addr.host_unix; dst_addr.host_unix = src_addr.host_unix;
dst_addr.proto = src_addr.proto; dst_addr.proto = src_addr.proto;
dst_addr.tls = src_addr.tls; dst_addr.tls = src_addr.tls;
dst_addr.sni = src_addr.sni; dst_addr.sni = make_string_ref(shared_addr->balloc, src_addr.sni);
dst_addr.fall = src_addr.fall; dst_addr.fall = src_addr.fall;
dst_addr.rise = src_addr.rise; dst_addr.rise = src_addr.rise;
......
...@@ -48,6 +48,7 @@ ...@@ -48,6 +48,7 @@
#include "shrpx_ssl.h" #include "shrpx_ssl.h"
#include "shrpx_live_check.h" #include "shrpx_live_check.h"
#include "shrpx_connect_blocker.h" #include "shrpx_connect_blocker.h"
#include "allocator.h"
using namespace nghttp2; using namespace nghttp2;
...@@ -75,15 +76,15 @@ struct DownstreamAddr { ...@@ -75,15 +76,15 @@ struct DownstreamAddr {
Address addr; Address addr;
// backend address. If |host_unix| is true, this is UNIX domain // backend address. If |host_unix| is true, this is UNIX domain
// socket path. // socket path.
ImmutableString host; StringRef host;
ImmutableString hostport; StringRef hostport;
// backend port. 0 if |host_unix| is true. // backend port. 0 if |host_unix| is true.
uint16_t port; uint16_t port;
// true if |host| contains UNIX domain socket path. // true if |host| contains UNIX domain socket path.
bool host_unix; bool host_unix;
// sni field to send remote server if TLS is enabled. // sni field to send remote server if TLS is enabled.
ImmutableString sni; StringRef sni;
std::unique_ptr<ConnectBlocker> connect_blocker; std::unique_ptr<ConnectBlocker> connect_blocker;
std::unique_ptr<LiveCheck> live_check; std::unique_ptr<LiveCheck> live_check;
...@@ -128,8 +129,18 @@ struct WeightedPri { ...@@ -128,8 +129,18 @@ struct WeightedPri {
struct SharedDownstreamAddr { struct SharedDownstreamAddr {
SharedDownstreamAddr() SharedDownstreamAddr()
: next{0}, http1_pri{}, http2_pri{}, affinity{AFFINITY_NONE} {} : balloc(1024, 1024),
next{0},
http1_pri{},
http2_pri{},
affinity{AFFINITY_NONE} {}
SharedDownstreamAddr(const SharedDownstreamAddr &) = delete;
SharedDownstreamAddr(SharedDownstreamAddr &&) = delete;
SharedDownstreamAddr &operator=(const SharedDownstreamAddr &) = delete;
SharedDownstreamAddr &operator=(SharedDownstreamAddr &&) = delete;
BlockAllocator balloc;
std::vector<DownstreamAddr> addrs; std::vector<DownstreamAddr> addrs;
// Bunch of session affinity hash. Only used if affinity == // Bunch of session affinity hash. Only used if affinity ==
// AFFINITY_IP. // AFFINITY_IP.
...@@ -162,6 +173,11 @@ struct SharedDownstreamAddr { ...@@ -162,6 +173,11 @@ struct SharedDownstreamAddr {
struct DownstreamAddrGroup { struct DownstreamAddrGroup {
DownstreamAddrGroup() : retired{false} {}; DownstreamAddrGroup() : retired{false} {};
DownstreamAddrGroup(const DownstreamAddrGroup &) = delete;
DownstreamAddrGroup(DownstreamAddrGroup &&) = delete;
DownstreamAddrGroup &operator=(const DownstreamAddrGroup &) = delete;
DownstreamAddrGroup &operator=(DownstreamAddrGroup &&) = delete;
ImmutableString pattern; ImmutableString pattern;
std::shared_ptr<SharedDownstreamAddr> shared_addr; std::shared_ptr<SharedDownstreamAddr> shared_addr;
// true if this group is no longer used for new request. If this is // true if this group is no longer used for new request. If this is
......
...@@ -131,11 +131,10 @@ bool in_attr_char(char c) { ...@@ -131,11 +131,10 @@ bool in_attr_char(char c) {
std::find(std::begin(bad), std::end(bad), c) == std::end(bad); std::find(std::begin(bad), std::end(bad), c) == std::end(bad);
} }
std::string percent_encode_token(const std::string &target) { StringRef percent_encode_token(BlockAllocator &balloc,
std::string dest; const StringRef &target) {
auto iov = make_byte_ref(balloc, target.size() * 3 + 1);
dest.resize(target.size() * 3); auto p = iov.base;
auto p = std::begin(dest);
for (auto first = std::begin(target); first != std::end(target); ++first) { for (auto first = std::begin(target); first != std::end(target); ++first) {
uint8_t c = *first; uint8_t c = *first;
...@@ -149,8 +148,10 @@ std::string percent_encode_token(const std::string &target) { ...@@ -149,8 +148,10 @@ std::string percent_encode_token(const std::string &target) {
*p++ = UPPER_XDIGITS[c >> 4]; *p++ = UPPER_XDIGITS[c >> 4];
*p++ = UPPER_XDIGITS[(c & 0x0f)]; *p++ = UPPER_XDIGITS[(c & 0x0f)];
} }
dest.resize(p - std::begin(dest));
return dest; *p = '\0';
return StringRef{iov.base, p};
} }
uint32_t hex_to_uint(char c) { uint32_t hex_to_uint(char c) {
...@@ -166,25 +167,27 @@ uint32_t hex_to_uint(char c) { ...@@ -166,25 +167,27 @@ uint32_t hex_to_uint(char c) {
return c; return c;
} }
std::string quote_string(const std::string &target) { StringRef quote_string(BlockAllocator &balloc, const StringRef &target) {
auto cnt = std::count(std::begin(target), std::end(target), '"'); auto cnt = std::count(std::begin(target), std::end(target), '"');
if (cnt == 0) { if (cnt == 0) {
return target; return make_string_ref(balloc, target);
} }
std::string res; auto iov = make_byte_ref(balloc, target.size() + cnt + 1);
res.reserve(target.size() + cnt); auto p = iov.base;
for (auto c : target) { for (auto c : target) {
if (c == '"') { if (c == '"') {
res += "\\\""; *p++ = '\\';
*p++ = '"';
} else { } else {
res += c; *p++ = c;
} }
} }
*p = '\0';
return res; return StringRef{iov.base, p};
} }
namespace { namespace {
...@@ -376,6 +379,21 @@ std::string format_hex(const unsigned char *s, size_t len) { ...@@ -376,6 +379,21 @@ std::string format_hex(const unsigned char *s, size_t len) {
return res; return res;
} }
StringRef format_hex(BlockAllocator &balloc, const StringRef &s) {
auto iov = make_byte_ref(balloc, s.size() * 2 + 1);
auto p = iov.base;
for (auto cc : s) {
uint8_t c = cc;
*p++ = LOWER_XDIGITS[c >> 4];
*p++ = LOWER_XDIGITS[c & 0xf];
}
*p = '\0';
return StringRef{iov.base, p};
}
void to_token68(std::string &base64str) { void to_token68(std::string &base64str) {
std::transform(std::begin(base64str), std::end(base64str), std::transform(std::begin(base64str), std::end(base64str),
std::begin(base64str), [](char c) { std::begin(base64str), [](char c) {
...@@ -392,22 +410,32 @@ void to_token68(std::string &base64str) { ...@@ -392,22 +410,32 @@ void to_token68(std::string &base64str) {
std::end(base64str)); std::end(base64str));
} }
void to_base64(std::string &token68str) { StringRef to_base64(BlockAllocator &balloc, const StringRef &token68str) {
std::transform(std::begin(token68str), std::end(token68str), // At most 3 padding '='
std::begin(token68str), [](char c) { auto len = token68str.size() + 3;
switch (c) { auto iov = make_byte_ref(balloc, len + 1);
case '-': auto p = iov.base;
return '+';
case '_': p = std::transform(std::begin(token68str), std::end(token68str), p,
return '/'; [](char c) {
default: switch (c) {
return c; case '-':
} return '+';
}); case '_':
if (token68str.size() & 0x3) { return '/';
token68str.append(4 - (token68str.size() & 0x3), '='); default:
return c;
}
});
auto rem = token68str.size() & 0x3;
if (rem) {
p = std::fill_n(p, 4 - rem, '=');
} }
return;
*p = '\0';
return StringRef{iov.base, p};
} }
namespace { namespace {
...@@ -1119,29 +1147,30 @@ std::string dtos(double n) { ...@@ -1119,29 +1147,30 @@ std::string dtos(double n) {
return utos(static_cast<int64_t>(n)) + "." + (f.size() == 1 ? "0" : "") + f; return utos(static_cast<int64_t>(n)) + "." + (f.size() == 1 ? "0" : "") + f;
} }
std::string make_http_hostport(const StringRef &host, uint16_t port) { StringRef make_http_hostport(BlockAllocator &balloc, const StringRef &host,
uint16_t port) {
if (port != 80 && port != 443) { if (port != 80 && port != 443) {
return make_hostport(host, port); return make_hostport(balloc, host, port);
} }
auto ipv6 = ipv6_numeric_addr(host.c_str()); auto ipv6 = ipv6_numeric_addr(host.c_str());
std::string hostport; auto iov = make_byte_ref(balloc, host.size() + (ipv6 ? 2 : 0) + 1);
hostport.resize(host.size() + (ipv6 ? 2 : 0)); auto p = iov.base;
auto p = &hostport[0];
if (ipv6) { if (ipv6) {
*p++ = '['; *p++ = '[';
} }
p = std::copy_n(host.c_str(), host.size(), p); p = std::copy(std::begin(host), std::end(host), p);
if (ipv6) { if (ipv6) {
*p++ = ']'; *p++ = ']';
} }
return hostport; *p = '\0';
return StringRef{iov.base, p};
} }
std::string make_hostport(const StringRef &host, uint16_t port) { std::string make_hostport(const StringRef &host, uint16_t port) {
...@@ -1169,6 +1198,34 @@ std::string make_hostport(const StringRef &host, uint16_t port) { ...@@ -1169,6 +1198,34 @@ std::string make_hostport(const StringRef &host, uint16_t port) {
return hostport; return hostport;
} }
StringRef make_hostport(BlockAllocator &balloc, const StringRef &host,
uint16_t port) {
auto ipv6 = ipv6_numeric_addr(host.c_str());
auto serv = utos(port);
auto iov =
make_byte_ref(balloc, host.size() + (ipv6 ? 2 : 0) + 1 + serv.size());
auto p = iov.base;
if (ipv6) {
*p++ = '[';
}
p = std::copy(std::begin(host), std::end(host), p);
if (ipv6) {
*p++ = ']';
}
*p++ = ':';
p = std::copy(std::begin(serv), std::end(serv), p);
*p = '\0';
return StringRef{iov.base, p};
}
namespace { namespace {
void hexdump8(FILE *out, const uint8_t *first, const uint8_t *last) { void hexdump8(FILE *out, const uint8_t *first, const uint8_t *last) {
auto stop = std::min(first + 8, last); auto stop = std::min(first + 8, last);
......
...@@ -129,11 +129,11 @@ std::string percent_decode(InputIt first, InputIt last) { ...@@ -129,11 +129,11 @@ std::string percent_decode(InputIt first, InputIt last) {
StringRef percent_decode(BlockAllocator &balloc, const StringRef &src); StringRef percent_decode(BlockAllocator &balloc, const StringRef &src);
// Percent encode |target| if character is not in token or '%'. // Percent encode |target| if character is not in token or '%'.
std::string percent_encode_token(const std::string &target); StringRef percent_encode_token(BlockAllocator &balloc, const StringRef &target);
// Returns quotedString version of |target|. Currently, this function // Returns quotedString version of |target|. Currently, this function
// just replace '"' with '\"'. // just replace '"' with '\"'.
std::string quote_string(const std::string &target); StringRef quote_string(BlockAllocator &balloc, const StringRef &target);
std::string format_hex(const unsigned char *s, size_t len); std::string format_hex(const unsigned char *s, size_t len);
...@@ -145,6 +145,8 @@ template <size_t N> std::string format_hex(const std::array<uint8_t, N> &s) { ...@@ -145,6 +145,8 @@ template <size_t N> std::string format_hex(const std::array<uint8_t, N> &s) {
return format_hex(s.data(), s.size()); return format_hex(s.data(), s.size());
} }
StringRef format_hex(BlockAllocator &balloc, const StringRef &s);
std::string http_date(time_t t); std::string http_date(time_t t);
// Returns given time |t| from epoch in Common Log format (e.g., // Returns given time |t| from epoch in Common Log format (e.g.,
...@@ -424,7 +426,8 @@ template <typename T> std::string utox(T n) { ...@@ -424,7 +426,8 @@ template <typename T> std::string utox(T n) {
} }
void to_token68(std::string &base64str); void to_token68(std::string &base64str);
void to_base64(std::string &token68str);
StringRef to_base64(BlockAllocator &balloc, const StringRef &token68str);
void show_candidates(const char *unkopt, option *options); void show_candidates(const char *unkopt, option *options);
...@@ -630,12 +633,16 @@ std::string format_duration(double t); ...@@ -630,12 +633,16 @@ std::string format_duration(double t);
// Creates "host:port" string using given |host| and |port|. If // Creates "host:port" string using given |host| and |port|. If
// |host| is numeric IPv6 address (e.g., ::1), it is enclosed by "[" // |host| is numeric IPv6 address (e.g., ::1), it is enclosed by "["
// and "]". If |port| is 80 or 443, port part is omitted. // and "]". If |port| is 80 or 443, port part is omitted.
std::string make_http_hostport(const StringRef &host, uint16_t port); StringRef make_http_hostport(BlockAllocator &balloc, const StringRef &host,
uint16_t port);
// Just like make_http_hostport(), but doesn't treat 80 and 443 // Just like make_http_hostport(), but doesn't treat 80 and 443
// specially. // specially.
std::string make_hostport(const StringRef &host, uint16_t port); std::string make_hostport(const StringRef &host, uint16_t port);
StringRef make_hostport(BlockAllocator &balloc, const StringRef &host,
uint16_t port);
// Dumps |src| of length |len| in the format similar to `hexdump -C`. // Dumps |src| of length |len| in the format similar to `hexdump -C`.
void hexdump(FILE *out, const uint8_t *src, size_t len); void hexdump(FILE *out, const uint8_t *src, size_t len);
...@@ -665,16 +672,17 @@ uint64_t get_uint64(const uint8_t *data); ...@@ -665,16 +672,17 @@ uint64_t get_uint64(const uint8_t *data);
int read_mime_types(std::map<std::string, std::string> &res, int read_mime_types(std::map<std::string, std::string> &res,
const char *filename); const char *filename);
template <typename Generator> // Fills random alpha and digit byte to the range [|first|, |last|).
std::string random_alpha_digit(Generator &gen, size_t len) { // Returns the one beyond the |last|.
std::string res; template <typename OutputIt, typename Generator>
res.reserve(len); OutputIt random_alpha_digit(OutputIt first, OutputIt last, Generator &gen) {
constexpr uint8_t s[] =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
std::uniform_int_distribution<> dis(0, 26 * 2 + 10 - 1); std::uniform_int_distribution<> dis(0, 26 * 2 + 10 - 1);
for (; len > 0; --len) { for (; first != last; ++first) {
res += "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"[dis( *first = s[dis(gen)];
gen)];
} }
return res; return first;
} }
template <typename OutputIterator, typename CharT, size_t N> template <typename OutputIterator, typename CharT, size_t N>
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <cstring> #include <cstring>
#include <iostream> #include <iostream>
#include <random>
#include <CUnit/CUnit.h> #include <CUnit/CUnit.h>
...@@ -113,13 +114,12 @@ void test_util_inp_strlower(void) { ...@@ -113,13 +114,12 @@ void test_util_inp_strlower(void) {
} }
void test_util_to_base64(void) { void test_util_to_base64(void) {
std::string x = "AAA--B_"; BlockAllocator balloc(4096, 4096);
util::to_base64(x);
CU_ASSERT("AAA++B/=" == x);
x = "AAA--B_B"; CU_ASSERT("AAA++B/=" ==
util::to_base64(x); util::to_base64(balloc, StringRef::from_lit("AAA--B_")));
CU_ASSERT("AAA++B/B" == x); CU_ASSERT("AAA++B/B" ==
util::to_base64(balloc, StringRef::from_lit("AAA--B_B")));
} }
void test_util_to_token68(void) { void test_util_to_token68(void) {
...@@ -133,10 +133,15 @@ void test_util_to_token68(void) { ...@@ -133,10 +133,15 @@ void test_util_to_token68(void) {
} }
void test_util_percent_encode_token(void) { void test_util_percent_encode_token(void) {
CU_ASSERT("h2" == util::percent_encode_token("h2")); BlockAllocator balloc(4096, 4096);
CU_ASSERT("h3~" == util::percent_encode_token("h3~")); CU_ASSERT("h2" ==
CU_ASSERT("100%25" == util::percent_encode_token("100%")); util::percent_encode_token(balloc, StringRef::from_lit("h2")));
CU_ASSERT("http%202" == util::percent_encode_token("http 2")); CU_ASSERT("h3~" ==
util::percent_encode_token(balloc, StringRef::from_lit("h3~")));
CU_ASSERT("100%25" ==
util::percent_encode_token(balloc, StringRef::from_lit("100%")));
CU_ASSERT("http%202" ==
util::percent_encode_token(balloc, StringRef::from_lit("http 2")));
} }
void test_util_percent_encode_path(void) { void test_util_percent_encode_path(void) {
...@@ -169,9 +174,12 @@ void test_util_percent_decode(void) { ...@@ -169,9 +174,12 @@ void test_util_percent_decode(void) {
} }
void test_util_quote_string(void) { void test_util_quote_string(void) {
CU_ASSERT("alpha" == util::quote_string("alpha")); BlockAllocator balloc(4096, 4096);
CU_ASSERT("" == util::quote_string("")); CU_ASSERT("alpha" ==
CU_ASSERT("\\\"alpha\\\"" == util::quote_string("\"alpha\"")); util::quote_string(balloc, StringRef::from_lit("alpha")));
CU_ASSERT("" == util::quote_string(balloc, StringRef::from_lit("")));
CU_ASSERT("\\\"alpha\\\"" ==
util::quote_string(balloc, StringRef::from_lit("\"alpha\"")));
} }
void test_util_utox(void) { void test_util_utox(void) {
...@@ -494,12 +502,15 @@ void test_util_parse_config_str_list(void) { ...@@ -494,12 +502,15 @@ void test_util_parse_config_str_list(void) {
} }
void test_util_make_http_hostport(void) { void test_util_make_http_hostport(void) {
CU_ASSERT("localhost" == BlockAllocator balloc(4096, 4096);
util::make_http_hostport(StringRef::from_lit("localhost"), 80));
CU_ASSERT("localhost" == util::make_http_hostport(
balloc, StringRef::from_lit("localhost"), 80));
CU_ASSERT("[::1]" == CU_ASSERT("[::1]" ==
util::make_http_hostport(StringRef::from_lit("::1"), 443)); util::make_http_hostport(balloc, StringRef::from_lit("::1"), 443));
CU_ASSERT("localhost:3000" == CU_ASSERT(
util::make_http_hostport(StringRef::from_lit("localhost"), 3000)); "localhost:3000" ==
util::make_http_hostport(balloc, StringRef::from_lit("localhost"), 3000));
} }
void test_util_make_hostport(void) { void test_util_make_hostport(void) {
...@@ -507,6 +518,12 @@ void test_util_make_hostport(void) { ...@@ -507,6 +518,12 @@ void test_util_make_hostport(void) {
util::make_hostport(StringRef::from_lit("localhost"), 80)); util::make_hostport(StringRef::from_lit("localhost"), 80));
CU_ASSERT("[::1]:443" == CU_ASSERT("[::1]:443" ==
util::make_hostport(StringRef::from_lit("::1"), 443)); util::make_hostport(StringRef::from_lit("::1"), 443));
BlockAllocator balloc(4096, 4096);
CU_ASSERT("localhost:80" ==
util::make_hostport(balloc, StringRef::from_lit("localhost"), 80));
CU_ASSERT("[::1]:443" ==
util::make_hostport(balloc, StringRef::from_lit("::1"), 443));
} }
void test_util_strifind(void) { void test_util_strifind(void) {
...@@ -528,4 +545,27 @@ void test_util_strifind(void) { ...@@ -528,4 +545,27 @@ void test_util_strifind(void) {
StringRef::from_lit("http1"))); StringRef::from_lit("http1")));
} }
void test_util_random_alpha_digit(void) {
std::random_device rd;
std::mt19937 gen(rd());
std::array<uint8_t, 19> data;
auto p = util::random_alpha_digit(std::begin(data), std::end(data), gen);
CU_ASSERT(std::end(data) == p);
for (auto b : data) {
CU_ASSERT(('A' <= b && b <= 'Z') || ('a' <= b && b <= 'z') ||
('0' <= b && b <= '9'));
}
}
void test_util_format_hex(void) {
BlockAllocator balloc(4096, 4096);
CU_ASSERT("0ff0" ==
util::format_hex(balloc, StringRef::from_lit("\x0f\xf0")));
CU_ASSERT("" == util::format_hex(balloc, StringRef::from_lit("")));
}
} // namespace shrpx } // namespace shrpx
...@@ -62,6 +62,8 @@ void test_util_parse_config_str_list(void); ...@@ -62,6 +62,8 @@ void test_util_parse_config_str_list(void);
void test_util_make_http_hostport(void); void test_util_make_http_hostport(void);
void test_util_make_hostport(void); void test_util_make_hostport(void);
void test_util_strifind(void); void test_util_strifind(void);
void test_util_random_alpha_digit(void);
void test_util_format_hex(void);
} // namespace shrpx } // namespace shrpx
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment