Commit ccfa13cd authored by Tatsuhiro Tsujikawa's avatar Tatsuhiro Tsujikawa

nghttpx: Rewrite location header field

We thought that this kind of rewrite can be achieved by the configuration
of the backend severs, but in some configuration, however, it may get
complicated. So we decided to implement at least location rewrite in
nghttpx.

This commit also contains a fix to the bug which prevents the http2
backend request from concatenating header fields with the same value.
parent bb70cdf6
......@@ -115,7 +115,7 @@ void sanitize_header_value(std::string& s, size_t offset)
}
}
void copy_url_component(std::string& dest, http_parser_url *u, int field,
void copy_url_component(std::string& dest, const http_parser_url *u, int field,
const char* url)
{
if(u->field_set & (1 << field)) {
......@@ -439,6 +439,74 @@ void dump_nv(FILE *out, const nghttp2_nv *nva, size_t nvlen)
fflush(out);
}
std::string rewrite_location_uri(const std::string& uri,
const http_parser_url& u,
const std::string& request_host,
const std::string& upstream_scheme,
uint16_t upstream_port,
uint16_t downstream_port)
{
// We just rewrite host and optionally port. We don't rewrite https
// link. Not sure it happens in practice.
if(u.field_set & (1 << UF_SCHEMA)) {
auto field = &u.field_data[UF_SCHEMA];
if(!util::streq("http", &uri[field->off], field->len)) {
return "";
}
}
if((u.field_set & (1 << UF_HOST)) == 0) {
return "";
}
std::string host;
copy_url_component(host, &u, UF_HOST, uri.c_str());
if(u.field_set & (1 << UF_PORT)) {
host += ":";
host += util::utos(u.port);
if(host != request_host) {
// :authority or host have "host", but host in location header
// field may have "host:port".
auto field = &u.field_data[UF_HOST];
if(!util::streq(request_host.c_str(), request_host.size(),
&uri[field->off], field->len) ||
downstream_port != u.port) {
return "";
}
}
} else if(host != request_host) {
return "";
}
std::string res = upstream_scheme;
res += "://";
auto field = &u.field_data[UF_HOST];
res.append(&uri[field->off], field->len);
if(upstream_scheme == "http") {
if(upstream_port != 80) {
res += ":";
res += util::utos(upstream_port);
}
} else if(upstream_scheme == "https") {
if(upstream_port != 443) {
res += ":";
res += util::utos(upstream_port);
}
}
if(u.field_set & (1 << UF_PATH)) {
field = &u.field_data[UF_PATH];
res.append(&uri[field->off], field->len);
}
if(u.field_set & (1 << UF_QUERY)) {
field = &u.field_data[UF_QUERY];
res += "?";
res.append(&uri[field->off], field->len);
}
if(u.field_set & (1 << UF_FRAGMENT)) {
field = &u.field_data[UF_FRAGMENT];
res += "#";
res.append(&uri[field->off], field->len);
}
return res;
}
} // namespace http2
} // namespace nghttp2
......@@ -55,7 +55,7 @@ void sanitize_header_value(std::string& s, size_t offset);
// Copies the |field| component value from |u| and |url| to the
// |dest|. If |u| does not have |field|, then this function does
// nothing.
void copy_url_component(std::string& dest, http_parser_url *u, int field,
void copy_url_component(std::string& dest, const http_parser_url *u, int field,
const char* url);
// Returns true if the header field |name| with length |namelen| bytes
......@@ -170,6 +170,24 @@ void dump_nv(FILE *out, const char **nv);
// Dumps name/value pairs in |nva| to |out|.
void dump_nv(FILE *out, const nghttp2_nv *nva, size_t nvlen);
// Rewrites redirection URI which usually appears in location header
// field. The |uri| is the URI in the location header field. The |u|
// stores the result of parsed |uri|. The |request_host| is the host
// or :authority header field value in the request. The
// |upstream_scheme| is either "https" or "http" in the upstream
// interface. The |downstream_port| is the port in the downstream
// connection.
//
// This function returns the new rewritten URI on success. If the
// location URI is not subject to the rewrite, this function returns
// emtpy string.
std::string rewrite_location_uri(const std::string& uri,
const http_parser_url& u,
const std::string& request_host,
const std::string& upstream_scheme,
uint16_t upstream_port,
uint16_t downstream_port);
} // namespace http2
} // namespace nghttp2
......
......@@ -30,6 +30,8 @@
#include <CUnit/CUnit.h>
#include "http-parser/http_parser.h"
#include "http2.h"
#include "util.h"
......@@ -222,4 +224,52 @@ void test_http2_check_header_value(void)
CU_ASSERT(!http2::check_header_value(&nv3));
}
namespace {
void check_rewrite_location_uri(const std::string& new_uri,
const std::string& uri,
const std::string& req_host,
const std::string& upstream_scheme,
uint16_t upstream_port,
uint16_t downstream_port)
{
http_parser_url u;
CU_ASSERT(0 == http_parser_parse_url(uri.c_str(), uri.size(), 0, &u));
CU_ASSERT(new_uri ==
http2::rewrite_location_uri(uri, u, req_host,
upstream_scheme, upstream_port,
downstream_port));
}
} // namespace
void test_http2_rewrite_location_uri(void)
{
check_rewrite_location_uri("https://localhost:3000/alpha?bravo#charlie",
"http://localhost:3001/alpha?bravo#charlie",
"localhost:3001", "https", 3000, 3001);
check_rewrite_location_uri("https://localhost/",
"http://localhost:3001/",
"localhost:3001", "https", 443, 3001);
check_rewrite_location_uri("http://localhost/",
"http://localhost:3001/",
"localhost:3001", "http", 80, 3001);
check_rewrite_location_uri("http://localhost:443/",
"http://localhost:3001/",
"localhost:3001", "http", 443, 3001);
check_rewrite_location_uri("https://localhost:80/",
"http://localhost:3001/",
"localhost:3001", "https", 80, 3001);
check_rewrite_location_uri("",
"http://localhost:3001/",
"127.0.0.1", "https", 3000, 3001);
check_rewrite_location_uri("https://localhost:3000/",
"http://localhost:3001/",
"localhost", "https", 3000, 3001);
check_rewrite_location_uri("",
"https://localhost:3001/",
"localhost", "https", 3000, 3001);
check_rewrite_location_uri("https://localhost:3000/",
"http://localhost/",
"localhost", "https", 3000, 80);
}
} // namespace shrpx
......@@ -36,6 +36,7 @@ void test_http2_concat_norm_headers(void);
void test_http2_copy_norm_headers_to_nva(void);
void test_http2_build_http1_headers_from_norm_headers(void);
void test_http2_check_header_value(void);
void test_http2_rewrite_location_uri(void);
} // namespace shrpx
......
......@@ -86,6 +86,8 @@ int main(int argc, char* argv[])
shrpx::test_http2_build_http1_headers_from_norm_headers) ||
!CU_add_test(pSuite, "http2_check_header_value",
shrpx::test_http2_check_header_value) ||
!CU_add_test(pSuite, "http2_rewrite_location_uri",
shrpx::test_http2_rewrite_location_uri) ||
!CU_add_test(pSuite, "downstream_normalize_request_headers",
shrpx::test_downstream_normalize_request_headers) ||
!CU_add_test(pSuite, "downstream_normalize_response_headers",
......@@ -98,6 +100,8 @@ int main(int argc, char* argv[])
shrpx::test_downstream_crumble_request_cookie) ||
!CU_add_test(pSuite, "downstream_assemble_request_cookie",
shrpx::test_downstream_assemble_request_cookie) ||
!CU_add_test(pSuite, "downstream_rewrite_norm_location_response_header",
shrpx::test_downstream_rewrite_norm_location_response_header) ||
!CU_add_test(pSuite, "util_streq", shrpx::test_util_streq) ||
!CU_add_test(pSuite, "util_inp_strlower",
shrpx::test_util_inp_strlower) ||
......
......@@ -460,4 +460,13 @@ bool ClientHandler::get_http2_upgrade_allowed() const
return !ssl_;
}
std::string ClientHandler::get_upstream_scheme() const
{
if(ssl_) {
return "https";
} else {
return "http";
}
}
} // namespace shrpx
......@@ -75,6 +75,8 @@ public:
// terminated. This function returns 0 if it succeeds, or -1.
int perform_http2_upgrade(HttpsUpstream *http);
bool get_http2_upgrade_allowed() const;
// Returns upstream scheme, either "http" or "https"
std::string get_upstream_scheme() const;
private:
std::set<DownstreamConnection*> dconn_pool_;
std::unique_ptr<Upstream> upstream_;
......
......@@ -26,6 +26,8 @@
#include <cassert>
#include "http-parser/http_parser.h"
#include "shrpx_upstream.h"
#include "shrpx_client_handler.h"
#include "shrpx_config.h"
......@@ -174,6 +176,19 @@ Headers::const_iterator get_norm_header(const Headers& headers,
}
} // namespace
namespace {
Headers::iterator get_norm_header(Headers& headers,
const std::string& name)
{
auto i = std::lower_bound(std::begin(headers), std::end(headers),
std::make_pair(name, std::string()), name_less);
if(i != std::end(headers) && (*i).first == name) {
return i;
}
return std::end(headers);
}
} // namespace
const Headers& Downstream::get_request_headers() const
{
return request_headers_;
......@@ -253,6 +268,11 @@ Headers::const_iterator Downstream::get_norm_request_header
return get_norm_header(request_headers_, name);
}
void Downstream::concat_norm_request_headers()
{
request_headers_ = http2::concat_norm_headers(std::move(request_headers_));
}
void Downstream::add_request_header(std::string name, std::string value)
{
request_header_key_prev_ = true;
......@@ -467,6 +487,42 @@ Headers::const_iterator Downstream::get_norm_response_header
return get_norm_header(response_headers_, name);
}
void Downstream::rewrite_norm_location_response_header
(const std::string& upstream_scheme,
uint16_t upstream_port,
uint16_t downstream_port)
{
auto hd = get_norm_header(response_headers_, "location");
if(hd == std::end(response_headers_)) {
return;
}
http_parser_url u;
int rv = http_parser_parse_url((*hd).second.c_str(), (*hd).second.size(),
0, &u);
if(rv != 0) {
return;
}
std::string new_uri;
if(!request_http2_authority_.empty()) {
new_uri = http2::rewrite_location_uri((*hd).second, u,
request_http2_authority_,
upstream_scheme, upstream_port,
downstream_port);
}
if(new_uri.empty()) {
auto host = get_norm_request_header("host");
if(host == std::end(request_headers_)) {
return;
}
new_uri = http2::rewrite_location_uri((*hd).second, u, (*host).second,
upstream_scheme, upstream_port,
downstream_port);
}
if(!new_uri.empty()) {
(*hd).second = std::move(new_uri);
}
}
void Downstream::add_response_header(std::string name, std::string value)
{
response_header_key_prev_ = true;
......
......@@ -94,6 +94,10 @@ public:
// called after calling normalize_request_headers().
Headers::const_iterator get_norm_request_header
(const std::string& name) const;
// Concatenates request header fields with same name by NULL as
// delimiter. See http2::concat_norm_headers(). This function must
// be called after calling normalize_request_headers().
void concat_norm_request_headers();
void add_request_header(std::string name, std::string value);
void set_last_request_header_value(std::string value);
......@@ -151,6 +155,13 @@ public:
// called after calling normalize_response_headers().
Headers::const_iterator get_norm_response_header
(const std::string& name) const;
// Rewrites the location response header field. This function must
// be called after calling normalize_response_headers() and
// normalize_request_headers().
void rewrite_norm_location_response_header
(const std::string& upstream_scheme,
uint16_t upstream_port,
uint16_t downstream_port);
void add_response_header(std::string name, std::string value);
void set_last_response_header_value(std::string value);
......
......@@ -146,4 +146,24 @@ void test_downstream_assemble_request_cookie(void)
}
void test_downstream_rewrite_norm_location_response_header(void)
{
{
Downstream d(nullptr, 0, 0);
d.add_request_header("host", "localhost:3000");
d.add_response_header("location", "http://localhost:3000/");
d.rewrite_norm_location_response_header("https", 443, 3000);
auto location = d.get_norm_response_header("location");
CU_ASSERT("https://localhost/" == (*location).second);
}
{
Downstream d(nullptr, 0, 0);
d.set_request_http2_authority("localhost");
d.add_response_header("location", "http://localhost/");
d.rewrite_norm_location_response_header("https", 443, 80);
auto location = d.get_norm_response_header("location");
CU_ASSERT("https://localhost/" == (*location).second);
}
}
} // namespace shrpx
......@@ -33,6 +33,7 @@ void test_downstream_get_norm_request_header(void);
void test_downstream_get_norm_response_header(void);
void test_downstream_crumble_request_cookie(void);
void test_downstream_assemble_request_cookie(void);
void test_downstream_rewrite_norm_location_response_header(void);
} // namespace shrpx
......
......@@ -236,7 +236,7 @@ int Http2DownstreamConnection::push_request_headers()
downstream_->crumble_request_cookie();
}
downstream_->normalize_request_headers();
downstream_->concat_norm_response_headers();
downstream_->concat_norm_request_headers();
auto end_headers = std::end(downstream_->get_request_headers());
// 6 means:
......
......@@ -945,6 +945,9 @@ int Http2Upstream::on_downstream_header_complete(Downstream *downstream)
DLOG(INFO, downstream) << "HTTP response header completed";
}
downstream->normalize_response_headers();
downstream->rewrite_norm_location_response_header
(get_client_handler()->get_upstream_scheme(), get_config()->port,
get_config()->downstream_port);
downstream->concat_norm_response_headers();
auto end_headers = std::end(downstream->get_response_headers());
size_t nheader = downstream->get_response_headers().size();
......
......@@ -656,6 +656,9 @@ int HttpsUpstream::on_downstream_header_complete(Downstream *downstream)
hdrs += http2::get_status_string(downstream->get_response_http_status());
hdrs += "\r\n";
downstream->normalize_response_headers();
downstream->rewrite_norm_location_response_header
(get_client_handler()->get_upstream_scheme(), get_config()->port,
get_config()->downstream_port);
auto end_headers = std::end(downstream->get_response_headers());
http2::build_http1_headers_from_norm_headers
(hdrs, downstream->get_response_headers());
......
......@@ -839,6 +839,10 @@ int SpdyUpstream::on_downstream_header_complete(Downstream *downstream)
if(LOG_ENABLED(INFO)) {
DLOG(INFO, downstream) << "HTTP response header completed";
}
downstream->normalize_response_headers();
downstream->rewrite_norm_location_response_header
(get_client_handler()->get_upstream_scheme(), get_config()->port,
get_config()->downstream_port);
size_t nheader = downstream->get_response_headers().size();
// 6 means :status, :version and possible via header field.
auto nv = util::make_unique<const char*[]>(nheader * 2 + 6 + 1);
......
......@@ -172,24 +172,6 @@ bool strieq(const char *a, const uint8_t *b, size_t bn)
return !*a && b == blast;
}
bool streq(const char *a, const uint8_t *b, size_t bn)
{
if(!a || !b) {
return false;
}
const uint8_t *blast = b + bn;
for(; *a && b != blast && *a == *b; ++a, ++b);
return !*a && b == blast;
}
bool streq(const uint8_t *a, size_t alen, const uint8_t *b, size_t blen)
{
if(alen != blen) {
return false;
}
return memcmp(a, b, alen) == 0;
}
int strcompare(const char *a, const uint8_t *b, size_t bn)
{
assert(a && b);
......
......@@ -299,9 +299,25 @@ bool strieq(const char *a, const char *b);
bool strieq(const char *a, const uint8_t *b, size_t n);
bool streq(const char *a, const uint8_t *b, size_t bn);
template<typename A, typename B>
bool streq(const A *a, const B *b, size_t bn)
{
if(!a || !b) {
return false;
}
auto blast = b + bn;
for(; *a && b != blast && *a == *b; ++a, ++b);
return !*a && b == blast;
}
bool streq(const uint8_t *a, size_t alen, const uint8_t *b, size_t blen);
template<typename A, typename B>
bool streq(const A *a, size_t alen, const B *b, size_t blen)
{
if(alen != blen) {
return false;
}
return memcmp(a, b, alen) == 0;
}
bool strifind(const char *a, const char *b);
......
......@@ -53,7 +53,9 @@ void test_util_streq(void)
(const uint8_t*)"alpha", 4));
CU_ASSERT(!util::streq((const uint8_t*)"alpha", 5,
(const uint8_t*)"alphA", 5));
CU_ASSERT(util::streq(nullptr, 0, nullptr, 0));
char *a = nullptr;
char *b = nullptr;
CU_ASSERT(util::streq(a, 0, b, 0));
}
void test_util_inp_strlower(void)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment