diff --git a/src/shrpx.cc b/src/shrpx.cc index 8f3a8cc..22cfe34 100644 --- a/src/shrpx.cc +++ b/src/shrpx.cc @@ -356,6 +356,9 @@ void fill_default_config() mod_config()->client_proxy = false; mod_config()->client = false; mod_config()->client_mode = false; + + mod_config()->insecure = false; + mod_config()->cacert = 0; } } // namespace @@ -465,6 +468,15 @@ void print_help(std::ostream& out) << get_config()->backlog << "\n" << " --ciphers= Set allowed cipher list. The format of the\n" << " string is described in OpenSSL ciphers(1).\n" + << " -k, --insecure When used with -p or --client, don't verify\n" + << " backend server's certificate.\n" + << " --cacert= When used with -p or --client, set path to\n" + << " trusted CA certificate file.\n" + << " The file must be in PEM format. It can\n" + << " contain multiple certificates. If the\n" + << " linked OpenSSL is configured to load system\n" + << " wide certificates, they are loaded\n" + << " at startup regardless of this option.\n" << " -h, --help Print this help.\n" << std::endl; } @@ -482,6 +494,7 @@ int main(int argc, char **argv) static option long_options[] = { {"backend", required_argument, 0, 'b' }, {"frontend", required_argument, 0, 'f' }, + {"insecure", no_argument, 0, 'k' }, {"workers", required_argument, 0, 'n' }, {"spdy-max-concurrent-streams", required_argument, 0, 'c' }, {"log-level", required_argument, 0, 'L' }, @@ -506,11 +519,12 @@ int main(int argc, char **argv) {"ciphers", required_argument, &flag, 16 }, {"client", no_argument, &flag, 17 }, {"backend-spdy-window-bits", required_argument, &flag, 18 }, + {"cacert", required_argument, &flag, 19 }, {"help", no_argument, 0, 'h' }, {0, 0, 0, 0 } }; int option_index = 0; - int c = getopt_long(argc, argv, "DL:sb:c:f:n:hp", long_options, + int c = getopt_long(argc, argv, "DL:ksb:c:f:n:hp", long_options, &option_index); if(c == -1) { break; @@ -531,6 +545,9 @@ int main(int argc, char **argv) case 'f': cmdcfgs.push_back(std::make_pair(SHRPX_OPT_FRONTEND, optarg)); break; + case 'k': + cmdcfgs.push_back(std::make_pair(SHRPX_OPT_INSECURE, "yes")); + break; case 'n': cmdcfgs.push_back(std::make_pair(SHRPX_OPT_WORKERS, optarg)); break; @@ -626,6 +643,10 @@ int main(int argc, char **argv) cmdcfgs.push_back(std::make_pair(SHRPX_OPT_BACKEND_SPDY_WINDOW_BITS, optarg)); break; + case 19: + // --cacert + cmdcfgs.push_back(std::make_pair(SHRPX_OPT_CACERT, optarg)); + break; default: break; } diff --git a/src/shrpx_config.cc b/src/shrpx_config.cc index 69c6916..7ab7a9b 100644 --- a/src/shrpx_config.cc +++ b/src/shrpx_config.cc @@ -71,6 +71,8 @@ const char SHRPX_OPT_SYSLOG_FACILITY[] = "syslog-facility"; const char SHRPX_OPT_BACKLOG[] = "backlog"; const char SHRPX_OPT_CIPHERS[] = "ciphers"; const char SHRPX_OPT_CLIENT[] = "client"; +const char SHRPX_OPT_INSECURE[] = "insecure"; +const char SHRPX_OPT_CACERT[] = "cacert"; Config::Config() : verbose(false), @@ -103,7 +105,9 @@ Config::Config() backlog(0), ciphers(0), client(false), - client_mode(false) + client_mode(false), + insecure(false), + cacert(0) {} namespace { @@ -268,6 +272,10 @@ int parse_config(const char *opt, const char *optarg) set_config_str(&mod_config()->ciphers, optarg); } else if(util::strieq(opt, SHRPX_OPT_CLIENT)) { mod_config()->client = util::strieq(optarg, "yes"); + } else if(util::strieq(opt, SHRPX_OPT_INSECURE)) { + mod_config()->insecure = util::strieq(optarg, "yes"); + } else if(util::strieq(opt, SHRPX_OPT_CACERT)) { + set_config_str(&mod_config()->cacert, optarg); } else if(util::strieq(opt, "conf")) { LOG(WARNING) << "conf is ignored"; } else { diff --git a/src/shrpx_config.h b/src/shrpx_config.h index 4413f27..27e4655 100644 --- a/src/shrpx_config.h +++ b/src/shrpx_config.h @@ -63,6 +63,8 @@ extern const char SHRPX_OPT_SYSLOG_FACILITY[]; extern const char SHRPX_OPT_BACKLOG[]; extern const char SHRPX_OPT_CIPHERS[]; extern const char SHRPX_OPT_CLIENT[]; +extern const char SHRPX_OPT_INSECURE[]; +extern const char SHRPX_OPT_CACERT[]; union sockaddr_union { sockaddr sa; @@ -112,6 +114,8 @@ struct Config { bool client; // true if --client or --client-proxy are enabled. bool client_mode; + bool insecure; + char *cacert; Config(); }; diff --git a/src/shrpx_spdy_session.cc b/src/shrpx_spdy_session.cc index 7694a45..25bc85b 100644 --- a/src/shrpx_spdy_session.cc +++ b/src/shrpx_spdy_session.cc @@ -37,6 +37,7 @@ #include "shrpx_error.h" #include "shrpx_spdy_downstream_connection.h" #include "shrpx_client_handler.h" +#include "shrpx_ssl.h" #include "util.h" using namespace spdylay; @@ -215,7 +216,8 @@ void eventcb(bufferevent *bev, short events, void *ptr) LOG(INFO) << "Downstream spdy connection established. " << spdy; } spdy->connected(); - if(spdy->on_connect() != 0) { + if((!get_config()->insecure && spdy->check_cert() != 0) || + spdy->on_connect() != 0) { spdy->disconnect(); return; } @@ -233,6 +235,11 @@ void eventcb(bufferevent *bev, short events, void *ptr) } } // namespace +int SpdySession::check_cert() +{ + return ssl::check_cert(ssl_); +} + int SpdySession::initiate_connection() { int rv; diff --git a/src/shrpx_spdy_session.h b/src/shrpx_spdy_session.h index 5787044..4f7ffa8 100644 --- a/src/shrpx_spdy_session.h +++ b/src/shrpx_spdy_session.h @@ -51,6 +51,8 @@ public: int init_notification(); + int check_cert(); + int disconnect(); int initiate_connection(); void connected(); diff --git a/src/shrpx_ssl.cc b/src/shrpx_ssl.cc index 89f58d3..6337544 100644 --- a/src/shrpx_ssl.cc +++ b/src/shrpx_ssl.cc @@ -29,7 +29,12 @@ #include #include +#include +#include + #include +#include +#include #include #include @@ -40,6 +45,9 @@ #include "shrpx_client_handler.h" #include "shrpx_config.h" #include "shrpx_accesslog.h" +#include "util.h" + +using namespace spdylay; namespace shrpx { @@ -182,6 +190,20 @@ SSL_CTX* create_ssl_client_context() SSL_CTX_set_mode(ssl_ctx, SSL_MODE_AUTO_RETRY); SSL_CTX_set_mode(ssl_ctx, SSL_MODE_RELEASE_BUFFERS); + if(SSL_CTX_set_default_verify_paths(ssl_ctx) != 1) { + LOG(WARNING) << "Could not load system trusted ca certificates: " + << ERR_error_string(ERR_get_error(), NULL); + } + + if(get_config()->cacert) { + if(SSL_CTX_load_verify_locations(ssl_ctx, get_config()->cacert, 0) != 1) { + LOG(FATAL) << "Could not load trusted ca certificates from " + << get_config()->cacert << ": " + << ERR_error_string(ERR_get_error(), NULL); + DIE(); + } + } + SSL_CTX_set_next_proto_select_cb(ssl_ctx, select_next_proto_cb, 0); return ssl_ctx; } @@ -228,6 +250,184 @@ ClientHandler* accept_ssl_connection(event_base *evbase, SSL_CTX *ssl_ctx, } } +namespace { +bool numeric_host(const char *hostname) +{ + struct addrinfo hints; + struct addrinfo* res; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_UNSPEC; + hints.ai_flags = AI_NUMERICHOST; + if(getaddrinfo(hostname, 0, &hints, &res)) { + return false; + } + freeaddrinfo(res); + return true; +} +} // namespace + +namespace { +bool tls_hostname_match(const char *pattern, const char *hostname) +{ + const char *ptWildcard = strchr(pattern, '*'); + if(ptWildcard == 0) { + return util::strieq(pattern, hostname); + } + const char *ptLeftLabelEnd = strchr(pattern, '.'); + bool wildcardEnabled = true; + // Do case-insensitive match. At least 2 dots are required to enable + // wildcard match. Also wildcard must be in the left-most label. + // Don't attempt to match a presented identifier where the wildcard + // character is embedded within an A-label. + if(ptLeftLabelEnd == 0 || strchr(ptLeftLabelEnd+1, '.') == 0 || + ptLeftLabelEnd < ptWildcard || util::istartsWith(pattern, "xn--")) { + wildcardEnabled = false; + } + if(!wildcardEnabled) { + return util::strieq(pattern, hostname); + } + const char *hnLeftLabelEnd = strchr(hostname, '.'); + if(hnLeftLabelEnd == 0 || !util::strieq(ptLeftLabelEnd, hnLeftLabelEnd)) { + return false; + } + // Perform wildcard match. Here '*' must match at least one + // character. + if(hnLeftLabelEnd - hostname < ptLeftLabelEnd - pattern) { + return false; + } + return util::istartsWith(hostname, hnLeftLabelEnd, pattern, ptWildcard) && + util::iendsWith(hostname, hnLeftLabelEnd, ptWildcard+1, ptLeftLabelEnd); +} +} // namespace + +namespace { +int verify_hostname(const char *hostname, + const sockaddr_union *su, + size_t salen, + const std::vector& dns_names, + const std::vector& ip_addrs, + const std::string& common_name) +{ + if(numeric_host(hostname)) { + if(ip_addrs.empty()) { + return util::strieq(common_name.c_str(), hostname) ? 0 : -1; + } + const void *saddr; + switch(su->storage.ss_family) { + case AF_INET: + saddr = &su->in.sin_addr; + break; + case AF_INET6: + saddr = &su->in6.sin6_addr; + break; + default: + return -1; + } + for(size_t i = 0; i < ip_addrs.size(); ++i) { + if(salen == ip_addrs[i].size() && + memcmp(saddr, ip_addrs[i].c_str(), salen) == 0) { + return 0; + } + } + } else { + if(dns_names.empty()) { + return tls_hostname_match(common_name.c_str(), hostname) ? 0 : -1; + } + for(size_t i = 0; i < dns_names.size(); ++i) { + if(tls_hostname_match(dns_names[i].c_str(), hostname)) { + return 0; + } + } + } + return -1; +} +} // namespace + +int check_cert(SSL *ssl) +{ + X509 *cert = SSL_get_peer_certificate(ssl); + if(!cert) { + LOG(ERROR) << "No certificate found"; + return -1; + } + util::auto_delete cert_deleter(cert, X509_free); + long verify_res = SSL_get_verify_result(ssl); + if(verify_res != X509_V_OK) { + LOG(ERROR) << "Certificate verification failed: " + << X509_verify_cert_error_string(verify_res); + return -1; + } + std::string common_name; + std::vector dns_names; + std::vector ip_addrs; + GENERAL_NAMES* altnames; + altnames = reinterpret_cast + (X509_get_ext_d2i(cert, NID_subject_alt_name, 0, 0)); + if(altnames) { + util::auto_delete altnames_deleter(altnames, + GENERAL_NAMES_free); + size_t n = sk_GENERAL_NAME_num(altnames); + for(size_t i = 0; i < n; ++i) { + const GENERAL_NAME *altname = sk_GENERAL_NAME_value(altnames, i); + if(altname->type == GEN_DNS) { + const char *name; + name = reinterpret_cast(ASN1_STRING_data(altname->d.ia5)); + if(!name) { + continue; + } + size_t len = ASN1_STRING_length(altname->d.ia5); + if(std::find(name, name+len, '\0') != name+len) { + // Embedded NULL is not permitted. + continue; + } + dns_names.push_back(std::string(name, len)); + } else if(altname->type == GEN_IPADD) { + const unsigned char *ip_addr = altname->d.iPAddress->data; + if(!ip_addr) { + continue; + } + size_t len = altname->d.iPAddress->length; + ip_addrs.push_back(std::string(reinterpret_cast(ip_addr), + len)); + } + } + } + X509_NAME *subjectname = X509_get_subject_name(cert); + if(!subjectname) { + LOG(ERROR) << "Could not get X509 name object from the certificate."; + return -1; + } + int lastpos = -1; + while(1) { + lastpos = X509_NAME_get_index_by_NID(subjectname, NID_commonName, + lastpos); + if(lastpos == -1) { + break; + } + X509_NAME_ENTRY *entry = X509_NAME_get_entry(subjectname, lastpos); + unsigned char *out; + int outlen = ASN1_STRING_to_UTF8(&out, X509_NAME_ENTRY_get_data(entry)); + if(outlen < 0) { + continue; + } + if(std::find(out, out+outlen, '\0') != out+outlen) { + // Embedded NULL is not permitted. + continue; + } + common_name.assign(&out[0], &out[outlen]); + OPENSSL_free(out); + break; + } + if(verify_hostname(get_config()->downstream_host, + &get_config()->downstream_addr, + get_config()->downstream_addrlen, + dns_names, ip_addrs, common_name) != 0) { + LOG(ERROR) << "Certificate verification failed: hostname does not match"; + return -1; + } + return 0; +} + namespace { pthread_mutex_t *ssl_locks; } // namespace diff --git a/src/shrpx_ssl.h b/src/shrpx_ssl.h index 170e5a9..87e476c 100644 --- a/src/shrpx_ssl.h +++ b/src/shrpx_ssl.h @@ -46,6 +46,8 @@ ClientHandler* accept_ssl_connection(event_base *evbase, SSL_CTX *ssl_ctx, evutil_socket_t fd, sockaddr *addr, int addrlen); +int check_cert(SSL *ssl); + void setup_ssl_lock(); void teardown_ssl_lock(); diff --git a/src/util.h b/src/util.h index 651ea81..99ab72e 100644 --- a/src/util.h +++ b/src/util.h @@ -36,6 +36,47 @@ namespace spdylay { namespace util { +template +class auto_delete { +private: + T obj_; + void (*deleter_)(T); +public: + auto_delete(T obj, void (*deleter)(T)):obj_(obj), deleter_(deleter) {} + + ~auto_delete() + { + deleter_(obj_); + } +}; + +template +class auto_delete_d { +private: + T obj_; +public: + auto_delete_d(T obj):obj_(obj) {} + + ~auto_delete_d() + { + delete obj_; + } +}; + +template +class auto_delete_r { +private: + T obj_; + R (*deleter_)(T); +public: + auto_delete_r(T obj, R (*deleter)(T)):obj_(obj), deleter_(deleter) {} + + ~auto_delete_r() + { + (void)deleter_(obj_); + } +}; + extern const std::string DEFAULT_STRIP_CHARSET; template @@ -233,8 +274,23 @@ bool endsWith return std::equal(first2, last2, last1-(last2-first2)); } +template +bool iendsWith +(InputIterator1 first1, + InputIterator1 last1, + InputIterator2 first2, + InputIterator2 last2) +{ + if(last1-first1 < last2-first2) { + return false; + } + return std::equal(first2, last2, last1-(last2-first2), CaseCmp()); +} + bool endsWith(const std::string& a, const std::string& b); +bool strieq(const std::string& a, const std::string& b); + bool strieq(const char *a, const char *b); bool strifind(const char *a, const char *b);