Commit dbbf3a4a authored by Tatsuhiro Tsujikawa's avatar Tatsuhiro Tsujikawa

nghttpx: Refactor TLS hostname match

parent f25fd09b
......@@ -69,6 +69,8 @@ int main(int argc, char *argv[]) {
shrpx::test_shrpx_ssl_create_lookup_tree) ||
!CU_add_test(pSuite, "ssl_cert_lookup_tree_add_cert_from_file",
shrpx::test_shrpx_ssl_cert_lookup_tree_add_cert_from_file) ||
!CU_add_test(pSuite, "ssl_tls_hostname_match",
shrpx::test_shrpx_ssl_tls_hostname_match) ||
!CU_add_test(pSuite, "http2_add_header", shrpx::test_http2_add_header) ||
!CU_add_test(pSuite, "http2_get_header", shrpx::test_http2_get_header) ||
!CU_add_test(pSuite, "http2_copy_headers_to_nva",
......
......@@ -780,27 +780,35 @@ ClientHandler *accept_connection(Worker *worker, int fd, sockaddr *addr,
return new ClientHandler(worker, fd, ssl, host, service);
}
namespace {
bool tls_hostname_match(const char *pattern, const char *hostname) {
const char *ptWildcard = strchr(pattern, '*');
if (ptWildcard == nullptr) {
return util::strieq(pattern, hostname);
bool tls_hostname_match(const char *pattern, size_t plen, const char *hostname,
size_t hlen) {
auto pend = pattern + plen;
auto ptWildcard = std::find(pattern, pend, '*');
if (ptWildcard == pend) {
return util::strieq(pattern, plen, hostname, hlen);
}
const char *ptLeftLabelEnd = strchr(pattern, '.');
bool wildcardEnabled = true;
auto ptLeftLabelEnd = std::find(pattern, pend, '.');
auto wildcardEnabled = true;
// Do case-insensitive match. At least 2 dots are required to enable
// wildcard match. Also wildcard must be in the left-most label.
// Don't attempt to match a presented identifier where the wildcard
// character is embedded within an A-label.
if (ptLeftLabelEnd == 0 || strchr(ptLeftLabelEnd + 1, '.') == 0 ||
ptLeftLabelEnd < ptWildcard || util::istarts_with(pattern, "xn--")) {
if (ptLeftLabelEnd == pend ||
std::find(ptLeftLabelEnd + 1, pend, '.') == pend ||
ptLeftLabelEnd < ptWildcard ||
util::istarts_with(pattern, plen, "xn--")) {
wildcardEnabled = false;
}
if (!wildcardEnabled) {
return util::strieq(pattern, hostname);
return util::strieq(pattern, plen, hostname, hlen);
}
const char *hnLeftLabelEnd = strchr(hostname, '.');
if (hnLeftLabelEnd == 0 || !util::strieq(ptLeftLabelEnd, hnLeftLabelEnd)) {
auto hend = hostname + hlen;
auto hnLeftLabelEnd = std::find(hostname, hend, '.');
if (hnLeftLabelEnd == hend ||
!util::strieq(ptLeftLabelEnd, pend, hnLeftLabelEnd, hend)) {
return false;
}
// Perform wildcard match. Here '*' must match at least one
......@@ -812,107 +820,143 @@ bool tls_hostname_match(const char *pattern, const char *hostname) {
util::iends_with(hostname, hnLeftLabelEnd, ptWildcard + 1,
ptLeftLabelEnd);
}
} // namespace
namespace {
int verify_hostname(const char *hostname, const Address *addr,
const std::vector<std::string> &dns_names,
const std::vector<std::string> &ip_addrs,
const std::string &common_name) {
if (util::numeric_host(hostname)) {
if (ip_addrs.empty()) {
return util::strieq(common_name.c_str(), hostname) ? 0 : -1;
}
const void *saddr;
switch (addr->su.storage.ss_family) {
case AF_INET:
saddr = &addr->su.in.sin_addr;
break;
case AF_INET6:
saddr = &addr->su.in6.sin6_addr;
ssize_t get_common_name(unsigned char **out_ptr, X509 *cert) {
auto subjectname = X509_get_subject_name(cert);
if (!subjectname) {
LOG(WARN) << "Could not get X509 name object from the certificate.";
return -1;
}
int lastpos = -1;
for (;;) {
lastpos = X509_NAME_get_index_by_NID(subjectname, NID_commonName, lastpos);
if (lastpos == -1) {
break;
default:
return -1;
}
for (size_t i = 0; i < ip_addrs.size(); ++i) {
if (addr->len == ip_addrs[i].size() &&
memcmp(saddr, ip_addrs[i].c_str(), addr->len) == 0) {
return 0;
}
auto entry = X509_NAME_get_entry(subjectname, lastpos);
auto outlen = ASN1_STRING_to_UTF8(out_ptr, X509_NAME_ENTRY_get_data(entry));
if (outlen < 0) {
continue;
}
} else {
if (dns_names.empty()) {
return tls_hostname_match(common_name.c_str(), hostname) ? 0 : -1;
if (std::find(*out_ptr, *out_ptr + outlen, '\0') != *out_ptr + outlen) {
// Embedded NULL is not permitted.
continue;
}
for (size_t i = 0; i < dns_names.size(); ++i) {
if (tls_hostname_match(dns_names[i].c_str(), hostname)) {
return outlen;
}
return -1;
}
} // namespace
namespace {
int verify_numeric_hostname(X509 *cert, const char *hostname, size_t hlen,
const Address *addr) {
const void *saddr;
switch (addr->su.storage.ss_family) {
case AF_INET:
saddr = &addr->su.in.sin_addr;
break;
case AF_INET6:
saddr = &addr->su.in6.sin6_addr;
break;
default:
return -1;
}
auto altnames = static_cast<GENERAL_NAMES *>(
X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr));
if (altnames) {
auto altnames_deleter = defer(GENERAL_NAMES_free, altnames);
auto n = sk_GENERAL_NAME_num(altnames);
for (size_t i = 0; i < n; ++i) {
auto altname = sk_GENERAL_NAME_value(altnames, i);
if (altname->type != GEN_IPADD) {
continue;
}
auto ip_addr = altname->d.iPAddress->data;
if (!ip_addr) {
continue;
}
auto ip_addrlen = altname->d.iPAddress->length;
if (addr->len == ip_addrlen && memcmp(saddr, ip_addr, ip_addrlen) == 0) {
return 0;
}
}
}
unsigned char *cn;
auto cnlen = get_common_name(&cn, cert);
if (cnlen == -1) {
return -1;
}
// cn is not NULL terminated
auto rv = util::streq(hostname, hlen, cn, cnlen);
OPENSSL_free(cn);
if (rv) {
return 0;
}
return -1;
}
} // namespace
void get_altnames(X509 *cert, std::vector<std::string> &dns_names,
std::vector<std::string> &ip_addrs,
std::string &common_name) {
GENERAL_NAMES *altnames = static_cast<GENERAL_NAMES *>(
namespace {
int verify_hostname(X509 *cert, const char *hostname, size_t hlen,
const Address *addr) {
if (util::numeric_host(hostname)) {
return verify_numeric_hostname(cert, hostname, hlen, addr);
}
auto altnames = static_cast<GENERAL_NAMES *>(
X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr));
if (altnames) {
auto altnames_deleter = defer(GENERAL_NAMES_free, altnames);
size_t n = sk_GENERAL_NAME_num(altnames);
auto n = sk_GENERAL_NAME_num(altnames);
for (size_t i = 0; i < n; ++i) {
const GENERAL_NAME *altname = sk_GENERAL_NAME_value(altnames, i);
if (altname->type == GEN_DNS) {
const char *name;
name = reinterpret_cast<char *>(ASN1_STRING_data(altname->d.ia5));
if (!name) {
continue;
}
size_t len = ASN1_STRING_length(altname->d.ia5);
if (std::find(name, name + len, '\0') != name + len) {
// Embedded NULL is not permitted.
continue;
}
dns_names.push_back(std::string(name, len));
} else if (altname->type == GEN_IPADD) {
const unsigned char *ip_addr = altname->d.iPAddress->data;
if (!ip_addr) {
continue;
}
size_t len = altname->d.iPAddress->length;
ip_addrs.push_back(
std::string(reinterpret_cast<const char *>(ip_addr), len));
auto altname = sk_GENERAL_NAME_value(altnames, i);
if (altname->type != GEN_DNS) {
continue;
}
auto name = reinterpret_cast<char *>(ASN1_STRING_data(altname->d.ia5));
if (!name) {
continue;
}
auto len = ASN1_STRING_length(altname->d.ia5);
if (std::find(name, name + len, '\0') != name + len) {
// Embedded NULL is not permitted.
continue;
}
if (tls_hostname_match(name, len, hostname, hlen)) {
return 0;
}
}
}
X509_NAME *subjectname = X509_get_subject_name(cert);
if (!subjectname) {
LOG(WARN) << "Could not get X509 name object from the certificate.";
return;
unsigned char *cn;
auto cnlen = get_common_name(&cn, cert);
if (cnlen == -1) {
return -1;
}
int lastpos = -1;
while (1) {
lastpos = X509_NAME_get_index_by_NID(subjectname, NID_commonName, lastpos);
if (lastpos == -1) {
break;
}
X509_NAME_ENTRY *entry = X509_NAME_get_entry(subjectname, lastpos);
unsigned char *out;
int outlen = ASN1_STRING_to_UTF8(&out, X509_NAME_ENTRY_get_data(entry));
if (outlen < 0) {
continue;
}
if (std::find(out, out + outlen, '\0') != out + outlen) {
// Embedded NULL is not permitted.
continue;
}
common_name.assign(&out[0], &out[outlen]);
OPENSSL_free(out);
break;
auto rv = util::strieq(hostname, hlen, cn, cnlen);
OPENSSL_free(cn);
if (rv) {
return 0;
}
return -1;
}
} // namespace
int check_cert(SSL *ssl, const DownstreamAddr *addr) {
auto cert = SSL_get_peer_certificate(ssl);
......@@ -921,21 +965,16 @@ int check_cert(SSL *ssl, const DownstreamAddr *addr) {
return -1;
}
auto cert_deleter = defer(X509_free, cert);
long verify_res = SSL_get_verify_result(ssl);
auto verify_res = SSL_get_verify_result(ssl);
if (verify_res != X509_V_OK) {
LOG(ERROR) << "Certificate verification failed: "
<< X509_verify_cert_error_string(verify_res);
return -1;
}
std::string common_name;
std::vector<std::string> dns_names;
std::vector<std::string> ip_addrs;
get_altnames(cert, dns_names, ip_addrs, common_name);
auto hostname = get_config()->backend_tls_sni_name
? get_config()->backend_tls_sni_name.get()
: addr->host.get();
if (verify_hostname(hostname, &addr->addr, dns_names, ip_addrs,
common_name) != 0) {
if (verify_hostname(cert, hostname, strlen(hostname), &addr->addr) != 0) {
LOG(ERROR) << "Certificate verification failed: hostname does not match";
return -1;
}
......@@ -969,7 +1008,7 @@ void cert_lookup_tree_add_cert(CertNode *node, SSL_CTX *ssl_ctx, char *hostname,
// some restrictions for wildcard hostname. We just ignore
// these rules here but do the proper check when we do the
// match.
node->wildcard_certs.emplace_back(hostname, ssl_ctx);
node->wildcard_certs.push_back({ssl_ctx, hostname, len});
return;
}
......@@ -986,7 +1025,7 @@ void cert_lookup_tree_add_cert(CertNode *node, SSL_CTX *ssl_ctx, char *hostname,
new_node->ssl_ctx = ssl_ctx;
} else {
new_node->ssl_ctx = nullptr;
new_node->wildcard_certs.emplace_back(hostname, ssl_ctx);
new_node->wildcard_certs.push_back({ssl_ctx, hostname, len});
}
node->next.push_back(std::move(new_node));
return;
......@@ -1073,9 +1112,11 @@ SSL_CTX *cert_lookup_tree_lookup(CertNode *node, const char *hostname,
// one character.
return nullptr;
}
for (const auto &wildcert : node->wildcard_certs) {
if (tls_hostname_match(wildcert.first, hostname)) {
return wildcert.second;
if (tls_hostname_match(wildcert.hostname, wildcert.hostnamelen, hostname,
len)) {
return wildcert.ssl_ctx;
}
}
auto c = util::lowcase(hostname[j]);
......@@ -1111,14 +1152,43 @@ int cert_lookup_tree_add_cert_from_file(CertLookupTree *lt, SSL_CTX *ssl_ctx,
return -1;
}
auto cert_deleter = defer(X509_free, cert);
std::string common_name;
std::vector<std::string> dns_names;
std::vector<std::string> ip_addrs;
get_altnames(cert, dns_names, ip_addrs, common_name);
for (auto &dns_name : dns_names) {
lt->add_cert(ssl_ctx, dns_name.c_str(), dns_name.size());
}
lt->add_cert(ssl_ctx, common_name.c_str(), common_name.size());
auto altnames = static_cast<GENERAL_NAMES *>(
X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr));
if (altnames) {
auto altnames_deleter = defer(GENERAL_NAMES_free, altnames);
auto n = sk_GENERAL_NAME_num(altnames);
for (size_t i = 0; i < n; ++i) {
auto altname = sk_GENERAL_NAME_value(altnames, i);
if (altname->type != GEN_DNS) {
continue;
}
auto name = reinterpret_cast<char *>(ASN1_STRING_data(altname->d.ia5));
if (!name) {
continue;
}
auto len = ASN1_STRING_length(altname->d.ia5);
if (std::find(name, name + len, '\0') != name + len) {
// Embedded NULL is not permitted.
continue;
}
lt->add_cert(ssl_ctx, name, len);
}
}
unsigned char *cn;
auto cnlen = get_common_name(&cn, cert);
if (cnlen == -1) {
return 0;
}
lt->add_cert(ssl_ctx, reinterpret_cast<char *>(cn), cnlen);
OPENSSL_free(cn);
return 0;
}
......
......@@ -108,10 +108,16 @@ void get_altnames(X509 *cert, std::vector<std::string> &dns_names,
// them. If there is a match, its SSL_CTX is returned. If none
// matches, query is continued to the next character.
struct WildcardCert {
SSL_CTX *ssl_ctx;
char *hostname;
size_t hostnamelen;
};
struct CertNode {
// list of wildcard domain name and its SSL_CTX pair, the wildcard
// '*' appears in this position.
std::vector<std::pair<char *, SSL_CTX *>> wildcard_certs;
std::vector<WildcardCert> wildcard_certs;
// Next CertNode index of CertLookupTree::nodes
std::vector<std::unique_ptr<CertNode>> next;
// SSL_CTX for exact match
......@@ -198,6 +204,13 @@ SSL *create_ssl(SSL_CTX *ssl_ctx);
// Returns true if SSL/TLS is enabled on downstream
bool downstream_tls_enabled();
// Performs TLS hostname match. |pattern| of length |plen| can
// contain wildcard character '*', which matches prefix of target
// hostname. There are several restrictions to make wildcard work.
// The matching algorithm is based on RFC 6125.
bool tls_hostname_match(const char *pattern, size_t plen, const char *hostname,
size_t hlen);
} // namespace ssl
} // namespace shrpx
......
......@@ -115,4 +115,37 @@ void test_shrpx_ssl_cert_lookup_tree_add_cert_from_file(void) {
SSL_CTX_free(ssl_ctx);
}
template <size_t N, size_t M>
bool tls_hostname_match_wrapper(const char(&pattern)[N],
const char(&hostname)[M]) {
return ssl::tls_hostname_match(pattern, N, hostname, M);
}
void test_shrpx_ssl_tls_hostname_match(void) {
CU_ASSERT(tls_hostname_match_wrapper("example.com", "example.com"));
CU_ASSERT(tls_hostname_match_wrapper("example.com", "EXAMPLE.com"));
// check wildcard
CU_ASSERT(tls_hostname_match_wrapper("*.example.com", "www.example.com"));
CU_ASSERT(tls_hostname_match_wrapper("*w.example.com", "www.example.com"));
CU_ASSERT(tls_hostname_match_wrapper("www*.example.com", "www1.example.com"));
CU_ASSERT(
tls_hostname_match_wrapper("www*.example.com", "WWW12.EXAMPLE.com"));
// at least 2 dots are required after '*'
CU_ASSERT(!tls_hostname_match_wrapper("*.com", "example.com"));
CU_ASSERT(!tls_hostname_match_wrapper("*", "example.com"));
// '*' must be in left most label
CU_ASSERT(
!tls_hostname_match_wrapper("blog.*.example.com", "blog.my.example.com"));
// prefix is wrong
CU_ASSERT(
!tls_hostname_match_wrapper("client*.example.com", "server.example.com"));
// '*' must match at least one character
CU_ASSERT(!tls_hostname_match_wrapper("www*.example.com", "www.example.com"));
CU_ASSERT(!tls_hostname_match_wrapper("example.com", "nghttp2.org"));
CU_ASSERT(!tls_hostname_match_wrapper("www.example.com", "example.com"));
CU_ASSERT(!tls_hostname_match_wrapper("example.com", "www.example.com"));
}
} // namespace shrpx
......@@ -33,6 +33,7 @@ namespace shrpx {
void test_shrpx_ssl_create_lookup_tree(void);
void test_shrpx_ssl_cert_lookup_tree_add_cert_from_file(void);
void test_shrpx_ssl_tls_hostname_match(void);
} // namespace shrpx
......
......@@ -258,6 +258,15 @@ bool strieq(InputIt1 a, size_t alen, InputIt2 b, size_t blen) {
return std::equal(a, a + alen, b, CaseCmp());
}
template <typename InputIt1, typename InputIt2>
bool strieq(InputIt1 first1, InputIt1 last1, InputIt2 first2, InputIt2 last2) {
if (std::distance(first1, last1) != std::distance(first2, last2)) {
return false;
}
return std::equal(first1, last1, first2, CaseCmp());
}
inline bool strieq(const std::string &a, const std::string &b) {
return strieq(std::begin(a), a.size(), std::begin(b), b.size());
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment