Commit dbbf3a4a authored by Tatsuhiro Tsujikawa's avatar Tatsuhiro Tsujikawa

nghttpx: Refactor TLS hostname match

parent f25fd09b
...@@ -69,6 +69,8 @@ int main(int argc, char *argv[]) { ...@@ -69,6 +69,8 @@ int main(int argc, char *argv[]) {
shrpx::test_shrpx_ssl_create_lookup_tree) || shrpx::test_shrpx_ssl_create_lookup_tree) ||
!CU_add_test(pSuite, "ssl_cert_lookup_tree_add_cert_from_file", !CU_add_test(pSuite, "ssl_cert_lookup_tree_add_cert_from_file",
shrpx::test_shrpx_ssl_cert_lookup_tree_add_cert_from_file) || shrpx::test_shrpx_ssl_cert_lookup_tree_add_cert_from_file) ||
!CU_add_test(pSuite, "ssl_tls_hostname_match",
shrpx::test_shrpx_ssl_tls_hostname_match) ||
!CU_add_test(pSuite, "http2_add_header", shrpx::test_http2_add_header) || !CU_add_test(pSuite, "http2_add_header", shrpx::test_http2_add_header) ||
!CU_add_test(pSuite, "http2_get_header", shrpx::test_http2_get_header) || !CU_add_test(pSuite, "http2_get_header", shrpx::test_http2_get_header) ||
!CU_add_test(pSuite, "http2_copy_headers_to_nva", !CU_add_test(pSuite, "http2_copy_headers_to_nva",
......
...@@ -780,27 +780,35 @@ ClientHandler *accept_connection(Worker *worker, int fd, sockaddr *addr, ...@@ -780,27 +780,35 @@ ClientHandler *accept_connection(Worker *worker, int fd, sockaddr *addr,
return new ClientHandler(worker, fd, ssl, host, service); return new ClientHandler(worker, fd, ssl, host, service);
} }
namespace { bool tls_hostname_match(const char *pattern, size_t plen, const char *hostname,
bool tls_hostname_match(const char *pattern, const char *hostname) { size_t hlen) {
const char *ptWildcard = strchr(pattern, '*'); auto pend = pattern + plen;
if (ptWildcard == nullptr) { auto ptWildcard = std::find(pattern, pend, '*');
return util::strieq(pattern, hostname); if (ptWildcard == pend) {
return util::strieq(pattern, plen, hostname, hlen);
} }
const char *ptLeftLabelEnd = strchr(pattern, '.');
bool wildcardEnabled = true; auto ptLeftLabelEnd = std::find(pattern, pend, '.');
auto wildcardEnabled = true;
// Do case-insensitive match. At least 2 dots are required to enable // Do case-insensitive match. At least 2 dots are required to enable
// wildcard match. Also wildcard must be in the left-most label. // wildcard match. Also wildcard must be in the left-most label.
// Don't attempt to match a presented identifier where the wildcard // Don't attempt to match a presented identifier where the wildcard
// character is embedded within an A-label. // character is embedded within an A-label.
if (ptLeftLabelEnd == 0 || strchr(ptLeftLabelEnd + 1, '.') == 0 || if (ptLeftLabelEnd == pend ||
ptLeftLabelEnd < ptWildcard || util::istarts_with(pattern, "xn--")) { std::find(ptLeftLabelEnd + 1, pend, '.') == pend ||
ptLeftLabelEnd < ptWildcard ||
util::istarts_with(pattern, plen, "xn--")) {
wildcardEnabled = false; wildcardEnabled = false;
} }
if (!wildcardEnabled) { if (!wildcardEnabled) {
return util::strieq(pattern, hostname); return util::strieq(pattern, plen, hostname, hlen);
} }
const char *hnLeftLabelEnd = strchr(hostname, '.');
if (hnLeftLabelEnd == 0 || !util::strieq(ptLeftLabelEnd, hnLeftLabelEnd)) { auto hend = hostname + hlen;
auto hnLeftLabelEnd = std::find(hostname, hend, '.');
if (hnLeftLabelEnd == hend ||
!util::strieq(ptLeftLabelEnd, pend, hnLeftLabelEnd, hend)) {
return false; return false;
} }
// Perform wildcard match. Here '*' must match at least one // Perform wildcard match. Here '*' must match at least one
...@@ -812,107 +820,143 @@ bool tls_hostname_match(const char *pattern, const char *hostname) { ...@@ -812,107 +820,143 @@ bool tls_hostname_match(const char *pattern, const char *hostname) {
util::iends_with(hostname, hnLeftLabelEnd, ptWildcard + 1, util::iends_with(hostname, hnLeftLabelEnd, ptWildcard + 1,
ptLeftLabelEnd); ptLeftLabelEnd);
} }
} // namespace
namespace { namespace {
int verify_hostname(const char *hostname, const Address *addr, ssize_t get_common_name(unsigned char **out_ptr, X509 *cert) {
const std::vector<std::string> &dns_names, auto subjectname = X509_get_subject_name(cert);
const std::vector<std::string> &ip_addrs, if (!subjectname) {
const std::string &common_name) { LOG(WARN) << "Could not get X509 name object from the certificate.";
if (util::numeric_host(hostname)) { return -1;
if (ip_addrs.empty()) { }
return util::strieq(common_name.c_str(), hostname) ? 0 : -1; int lastpos = -1;
} for (;;) {
const void *saddr; lastpos = X509_NAME_get_index_by_NID(subjectname, NID_commonName, lastpos);
switch (addr->su.storage.ss_family) { if (lastpos == -1) {
case AF_INET:
saddr = &addr->su.in.sin_addr;
break;
case AF_INET6:
saddr = &addr->su.in6.sin6_addr;
break; break;
default:
return -1;
} }
for (size_t i = 0; i < ip_addrs.size(); ++i) { auto entry = X509_NAME_get_entry(subjectname, lastpos);
if (addr->len == ip_addrs[i].size() &&
memcmp(saddr, ip_addrs[i].c_str(), addr->len) == 0) { auto outlen = ASN1_STRING_to_UTF8(out_ptr, X509_NAME_ENTRY_get_data(entry));
return 0; if (outlen < 0) {
} continue;
} }
} else { if (std::find(*out_ptr, *out_ptr + outlen, '\0') != *out_ptr + outlen) {
if (dns_names.empty()) { // Embedded NULL is not permitted.
return tls_hostname_match(common_name.c_str(), hostname) ? 0 : -1; continue;
} }
for (size_t i = 0; i < dns_names.size(); ++i) { return outlen;
if (tls_hostname_match(dns_names[i].c_str(), hostname)) { }
return -1;
}
} // namespace
namespace {
int verify_numeric_hostname(X509 *cert, const char *hostname, size_t hlen,
const Address *addr) {
const void *saddr;
switch (addr->su.storage.ss_family) {
case AF_INET:
saddr = &addr->su.in.sin_addr;
break;
case AF_INET6:
saddr = &addr->su.in6.sin6_addr;
break;
default:
return -1;
}
auto altnames = static_cast<GENERAL_NAMES *>(
X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr));
if (altnames) {
auto altnames_deleter = defer(GENERAL_NAMES_free, altnames);
auto n = sk_GENERAL_NAME_num(altnames);
for (size_t i = 0; i < n; ++i) {
auto altname = sk_GENERAL_NAME_value(altnames, i);
if (altname->type != GEN_IPADD) {
continue;
}
auto ip_addr = altname->d.iPAddress->data;
if (!ip_addr) {
continue;
}
auto ip_addrlen = altname->d.iPAddress->length;
if (addr->len == ip_addrlen && memcmp(saddr, ip_addr, ip_addrlen) == 0) {
return 0; return 0;
} }
} }
} }
unsigned char *cn;
auto cnlen = get_common_name(&cn, cert);
if (cnlen == -1) {
return -1;
}
// cn is not NULL terminated
auto rv = util::streq(hostname, hlen, cn, cnlen);
OPENSSL_free(cn);
if (rv) {
return 0;
}
return -1; return -1;
} }
} // namespace } // namespace
void get_altnames(X509 *cert, std::vector<std::string> &dns_names, namespace {
std::vector<std::string> &ip_addrs, int verify_hostname(X509 *cert, const char *hostname, size_t hlen,
std::string &common_name) { const Address *addr) {
GENERAL_NAMES *altnames = static_cast<GENERAL_NAMES *>( if (util::numeric_host(hostname)) {
return verify_numeric_hostname(cert, hostname, hlen, addr);
}
auto altnames = static_cast<GENERAL_NAMES *>(
X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr)); X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr));
if (altnames) { if (altnames) {
auto altnames_deleter = defer(GENERAL_NAMES_free, altnames); auto altnames_deleter = defer(GENERAL_NAMES_free, altnames);
size_t n = sk_GENERAL_NAME_num(altnames); auto n = sk_GENERAL_NAME_num(altnames);
for (size_t i = 0; i < n; ++i) { for (size_t i = 0; i < n; ++i) {
const GENERAL_NAME *altname = sk_GENERAL_NAME_value(altnames, i); auto altname = sk_GENERAL_NAME_value(altnames, i);
if (altname->type == GEN_DNS) { if (altname->type != GEN_DNS) {
const char *name; continue;
name = reinterpret_cast<char *>(ASN1_STRING_data(altname->d.ia5)); }
if (!name) {
continue; auto name = reinterpret_cast<char *>(ASN1_STRING_data(altname->d.ia5));
} if (!name) {
size_t len = ASN1_STRING_length(altname->d.ia5); continue;
if (std::find(name, name + len, '\0') != name + len) { }
// Embedded NULL is not permitted.
continue; auto len = ASN1_STRING_length(altname->d.ia5);
} if (std::find(name, name + len, '\0') != name + len) {
dns_names.push_back(std::string(name, len)); // Embedded NULL is not permitted.
} else if (altname->type == GEN_IPADD) { continue;
const unsigned char *ip_addr = altname->d.iPAddress->data; }
if (!ip_addr) {
continue; if (tls_hostname_match(name, len, hostname, hlen)) {
} return 0;
size_t len = altname->d.iPAddress->length;
ip_addrs.push_back(
std::string(reinterpret_cast<const char *>(ip_addr), len));
} }
} }
} }
X509_NAME *subjectname = X509_get_subject_name(cert);
if (!subjectname) { unsigned char *cn;
LOG(WARN) << "Could not get X509 name object from the certificate."; auto cnlen = get_common_name(&cn, cert);
return; if (cnlen == -1) {
return -1;
} }
int lastpos = -1;
while (1) { auto rv = util::strieq(hostname, hlen, cn, cnlen);
lastpos = X509_NAME_get_index_by_NID(subjectname, NID_commonName, lastpos); OPENSSL_free(cn);
if (lastpos == -1) {
break; if (rv) {
} return 0;
X509_NAME_ENTRY *entry = X509_NAME_get_entry(subjectname, lastpos);
unsigned char *out;
int outlen = ASN1_STRING_to_UTF8(&out, X509_NAME_ENTRY_get_data(entry));
if (outlen < 0) {
continue;
}
if (std::find(out, out + outlen, '\0') != out + outlen) {
// Embedded NULL is not permitted.
continue;
}
common_name.assign(&out[0], &out[outlen]);
OPENSSL_free(out);
break;
} }
return -1;
} }
} // namespace
int check_cert(SSL *ssl, const DownstreamAddr *addr) { int check_cert(SSL *ssl, const DownstreamAddr *addr) {
auto cert = SSL_get_peer_certificate(ssl); auto cert = SSL_get_peer_certificate(ssl);
...@@ -921,21 +965,16 @@ int check_cert(SSL *ssl, const DownstreamAddr *addr) { ...@@ -921,21 +965,16 @@ int check_cert(SSL *ssl, const DownstreamAddr *addr) {
return -1; return -1;
} }
auto cert_deleter = defer(X509_free, cert); auto cert_deleter = defer(X509_free, cert);
long verify_res = SSL_get_verify_result(ssl); auto verify_res = SSL_get_verify_result(ssl);
if (verify_res != X509_V_OK) { if (verify_res != X509_V_OK) {
LOG(ERROR) << "Certificate verification failed: " LOG(ERROR) << "Certificate verification failed: "
<< X509_verify_cert_error_string(verify_res); << X509_verify_cert_error_string(verify_res);
return -1; return -1;
} }
std::string common_name;
std::vector<std::string> dns_names;
std::vector<std::string> ip_addrs;
get_altnames(cert, dns_names, ip_addrs, common_name);
auto hostname = get_config()->backend_tls_sni_name auto hostname = get_config()->backend_tls_sni_name
? get_config()->backend_tls_sni_name.get() ? get_config()->backend_tls_sni_name.get()
: addr->host.get(); : addr->host.get();
if (verify_hostname(hostname, &addr->addr, dns_names, ip_addrs, if (verify_hostname(cert, hostname, strlen(hostname), &addr->addr) != 0) {
common_name) != 0) {
LOG(ERROR) << "Certificate verification failed: hostname does not match"; LOG(ERROR) << "Certificate verification failed: hostname does not match";
return -1; return -1;
} }
...@@ -969,7 +1008,7 @@ void cert_lookup_tree_add_cert(CertNode *node, SSL_CTX *ssl_ctx, char *hostname, ...@@ -969,7 +1008,7 @@ void cert_lookup_tree_add_cert(CertNode *node, SSL_CTX *ssl_ctx, char *hostname,
// some restrictions for wildcard hostname. We just ignore // some restrictions for wildcard hostname. We just ignore
// these rules here but do the proper check when we do the // these rules here but do the proper check when we do the
// match. // match.
node->wildcard_certs.emplace_back(hostname, ssl_ctx); node->wildcard_certs.push_back({ssl_ctx, hostname, len});
return; return;
} }
...@@ -986,7 +1025,7 @@ void cert_lookup_tree_add_cert(CertNode *node, SSL_CTX *ssl_ctx, char *hostname, ...@@ -986,7 +1025,7 @@ void cert_lookup_tree_add_cert(CertNode *node, SSL_CTX *ssl_ctx, char *hostname,
new_node->ssl_ctx = ssl_ctx; new_node->ssl_ctx = ssl_ctx;
} else { } else {
new_node->ssl_ctx = nullptr; new_node->ssl_ctx = nullptr;
new_node->wildcard_certs.emplace_back(hostname, ssl_ctx); new_node->wildcard_certs.push_back({ssl_ctx, hostname, len});
} }
node->next.push_back(std::move(new_node)); node->next.push_back(std::move(new_node));
return; return;
...@@ -1073,9 +1112,11 @@ SSL_CTX *cert_lookup_tree_lookup(CertNode *node, const char *hostname, ...@@ -1073,9 +1112,11 @@ SSL_CTX *cert_lookup_tree_lookup(CertNode *node, const char *hostname,
// one character. // one character.
return nullptr; return nullptr;
} }
for (const auto &wildcert : node->wildcard_certs) { for (const auto &wildcert : node->wildcard_certs) {
if (tls_hostname_match(wildcert.first, hostname)) { if (tls_hostname_match(wildcert.hostname, wildcert.hostnamelen, hostname,
return wildcert.second; len)) {
return wildcert.ssl_ctx;
} }
} }
auto c = util::lowcase(hostname[j]); auto c = util::lowcase(hostname[j]);
...@@ -1111,14 +1152,43 @@ int cert_lookup_tree_add_cert_from_file(CertLookupTree *lt, SSL_CTX *ssl_ctx, ...@@ -1111,14 +1152,43 @@ int cert_lookup_tree_add_cert_from_file(CertLookupTree *lt, SSL_CTX *ssl_ctx,
return -1; return -1;
} }
auto cert_deleter = defer(X509_free, cert); auto cert_deleter = defer(X509_free, cert);
std::string common_name;
std::vector<std::string> dns_names; auto altnames = static_cast<GENERAL_NAMES *>(
std::vector<std::string> ip_addrs; X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr));
get_altnames(cert, dns_names, ip_addrs, common_name); if (altnames) {
for (auto &dns_name : dns_names) { auto altnames_deleter = defer(GENERAL_NAMES_free, altnames);
lt->add_cert(ssl_ctx, dns_name.c_str(), dns_name.size()); auto n = sk_GENERAL_NAME_num(altnames);
} for (size_t i = 0; i < n; ++i) {
lt->add_cert(ssl_ctx, common_name.c_str(), common_name.size()); auto altname = sk_GENERAL_NAME_value(altnames, i);
if (altname->type != GEN_DNS) {
continue;
}
auto name = reinterpret_cast<char *>(ASN1_STRING_data(altname->d.ia5));
if (!name) {
continue;
}
auto len = ASN1_STRING_length(altname->d.ia5);
if (std::find(name, name + len, '\0') != name + len) {
// Embedded NULL is not permitted.
continue;
}
lt->add_cert(ssl_ctx, name, len);
}
}
unsigned char *cn;
auto cnlen = get_common_name(&cn, cert);
if (cnlen == -1) {
return 0;
}
lt->add_cert(ssl_ctx, reinterpret_cast<char *>(cn), cnlen);
OPENSSL_free(cn);
return 0; return 0;
} }
......
...@@ -108,10 +108,16 @@ void get_altnames(X509 *cert, std::vector<std::string> &dns_names, ...@@ -108,10 +108,16 @@ void get_altnames(X509 *cert, std::vector<std::string> &dns_names,
// them. If there is a match, its SSL_CTX is returned. If none // them. If there is a match, its SSL_CTX is returned. If none
// matches, query is continued to the next character. // matches, query is continued to the next character.
struct WildcardCert {
SSL_CTX *ssl_ctx;
char *hostname;
size_t hostnamelen;
};
struct CertNode { struct CertNode {
// list of wildcard domain name and its SSL_CTX pair, the wildcard // list of wildcard domain name and its SSL_CTX pair, the wildcard
// '*' appears in this position. // '*' appears in this position.
std::vector<std::pair<char *, SSL_CTX *>> wildcard_certs; std::vector<WildcardCert> wildcard_certs;
// Next CertNode index of CertLookupTree::nodes // Next CertNode index of CertLookupTree::nodes
std::vector<std::unique_ptr<CertNode>> next; std::vector<std::unique_ptr<CertNode>> next;
// SSL_CTX for exact match // SSL_CTX for exact match
...@@ -198,6 +204,13 @@ SSL *create_ssl(SSL_CTX *ssl_ctx); ...@@ -198,6 +204,13 @@ SSL *create_ssl(SSL_CTX *ssl_ctx);
// Returns true if SSL/TLS is enabled on downstream // Returns true if SSL/TLS is enabled on downstream
bool downstream_tls_enabled(); bool downstream_tls_enabled();
// Performs TLS hostname match. |pattern| of length |plen| can
// contain wildcard character '*', which matches prefix of target
// hostname. There are several restrictions to make wildcard work.
// The matching algorithm is based on RFC 6125.
bool tls_hostname_match(const char *pattern, size_t plen, const char *hostname,
size_t hlen);
} // namespace ssl } // namespace ssl
} // namespace shrpx } // namespace shrpx
......
...@@ -115,4 +115,37 @@ void test_shrpx_ssl_cert_lookup_tree_add_cert_from_file(void) { ...@@ -115,4 +115,37 @@ void test_shrpx_ssl_cert_lookup_tree_add_cert_from_file(void) {
SSL_CTX_free(ssl_ctx); SSL_CTX_free(ssl_ctx);
} }
template <size_t N, size_t M>
bool tls_hostname_match_wrapper(const char(&pattern)[N],
const char(&hostname)[M]) {
return ssl::tls_hostname_match(pattern, N, hostname, M);
}
void test_shrpx_ssl_tls_hostname_match(void) {
CU_ASSERT(tls_hostname_match_wrapper("example.com", "example.com"));
CU_ASSERT(tls_hostname_match_wrapper("example.com", "EXAMPLE.com"));
// check wildcard
CU_ASSERT(tls_hostname_match_wrapper("*.example.com", "www.example.com"));
CU_ASSERT(tls_hostname_match_wrapper("*w.example.com", "www.example.com"));
CU_ASSERT(tls_hostname_match_wrapper("www*.example.com", "www1.example.com"));
CU_ASSERT(
tls_hostname_match_wrapper("www*.example.com", "WWW12.EXAMPLE.com"));
// at least 2 dots are required after '*'
CU_ASSERT(!tls_hostname_match_wrapper("*.com", "example.com"));
CU_ASSERT(!tls_hostname_match_wrapper("*", "example.com"));
// '*' must be in left most label
CU_ASSERT(
!tls_hostname_match_wrapper("blog.*.example.com", "blog.my.example.com"));
// prefix is wrong
CU_ASSERT(
!tls_hostname_match_wrapper("client*.example.com", "server.example.com"));
// '*' must match at least one character
CU_ASSERT(!tls_hostname_match_wrapper("www*.example.com", "www.example.com"));
CU_ASSERT(!tls_hostname_match_wrapper("example.com", "nghttp2.org"));
CU_ASSERT(!tls_hostname_match_wrapper("www.example.com", "example.com"));
CU_ASSERT(!tls_hostname_match_wrapper("example.com", "www.example.com"));
}
} // namespace shrpx } // namespace shrpx
...@@ -33,6 +33,7 @@ namespace shrpx { ...@@ -33,6 +33,7 @@ namespace shrpx {
void test_shrpx_ssl_create_lookup_tree(void); void test_shrpx_ssl_create_lookup_tree(void);
void test_shrpx_ssl_cert_lookup_tree_add_cert_from_file(void); void test_shrpx_ssl_cert_lookup_tree_add_cert_from_file(void);
void test_shrpx_ssl_tls_hostname_match(void);
} // namespace shrpx } // namespace shrpx
......
...@@ -258,6 +258,15 @@ bool strieq(InputIt1 a, size_t alen, InputIt2 b, size_t blen) { ...@@ -258,6 +258,15 @@ bool strieq(InputIt1 a, size_t alen, InputIt2 b, size_t blen) {
return std::equal(a, a + alen, b, CaseCmp()); return std::equal(a, a + alen, b, CaseCmp());
} }
template <typename InputIt1, typename InputIt2>
bool strieq(InputIt1 first1, InputIt1 last1, InputIt2 first2, InputIt2 last2) {
if (std::distance(first1, last1) != std::distance(first2, last2)) {
return false;
}
return std::equal(first1, last1, first2, CaseCmp());
}
inline bool strieq(const std::string &a, const std::string &b) { inline bool strieq(const std::string &a, const std::string &b) {
return strieq(std::begin(a), a.size(), std::begin(b), b.size()); return strieq(std::begin(a), a.size(), std::begin(b), b.size());
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment