diff --git a/examples/shrpx_http.cc b/examples/shrpx_http.cc index 8d775cf..4037677 100644 --- a/examples/shrpx_http.cc +++ b/examples/shrpx_http.cc @@ -89,7 +89,7 @@ std::string create_error_html(int status_code) << "
" << get_config()->server_name << " at port " << get_config()->port << "
" - << ""; + << "\n"; return ss.str(); } diff --git a/examples/shrpx_https_upstream.cc b/examples/shrpx_https_upstream.cc index c33c6a9..707b82b 100644 --- a/examples/shrpx_https_upstream.cc +++ b/examples/shrpx_https_upstream.cc @@ -40,16 +40,15 @@ namespace shrpx { namespace { const size_t SHRPX_HTTPS_UPSTREAM_OUTPUT_UPPER_THRES = 512*1024; +const size_t SHRPX_HTTPS_MAX_HEADER_LENGTH = 64*1024; } // namespace HttpsUpstream::HttpsUpstream(ClientHandler *handler) : handler_(handler), htp_(htparser_new()), + current_header_length_(0), ioctrl_(handler->get_bev()) { - if(ENABLE_LOG) { - LOG(INFO) << "HttpsUpstream ctor"; - } htparser_init(htp_, htp_type_request); htparser_set_userdata(htp_, this); } @@ -63,6 +62,11 @@ HttpsUpstream::~HttpsUpstream() } } +void HttpsUpstream::reset_current_header_length() +{ + current_header_length_ = 0; +} + namespace { int htp_msg_begin(htparser *htp) { @@ -71,6 +75,7 @@ int htp_msg_begin(htparser *htp) } HttpsUpstream *upstream; upstream = reinterpret_cast(htparser_get_userdata(htp)); + upstream->reset_current_header_length(); Downstream *downstream = new Downstream(upstream, 0, 0); upstream->add_downstream(downstream); return 0; @@ -207,16 +212,10 @@ htparse_hooks htp_hooks = { }; } // namespace -std::set cache; - // on_read() does not consume all available data in input buffer if // one http request is fully received. int HttpsUpstream::on_read() { - if(cache.count(this) == 0) { - LOG(INFO) << "HttpsUpstream::on_read"; - cache.insert(this); - } bufferevent *bev = handler_->get_bev(); evbuffer *input = bufferevent_get_input(bev); unsigned char *mem = evbuffer_pullup(input, -1); @@ -224,18 +223,25 @@ int HttpsUpstream::on_read() reinterpret_cast(mem), evbuffer_get_length(input)); evbuffer_drain(input, nread); + // Well, actually header length + some body bytes + current_header_length_ += nread; htpparse_error htperr = htparser_get_error(htp_); if(htperr == htparse_error_user) { - pause_read(SHRPX_MSG_BLOCK); - if(ENABLE_LOG) { - LOG(INFO) << " remaining bytes " << evbuffer_get_length(input); + if(current_header_length_ > SHRPX_HTTPS_MAX_HEADER_LENGTH) { + LOG(WARNING) << "Request Header too long:" << current_header_length_ + << " bytes"; + get_client_handler()->set_should_close_after_write(true); + error_reply(400); + } else { + pause_read(SHRPX_MSG_BLOCK); } } else if(htperr != htparse_error_none) { if(ENABLE_LOG) { LOG(INFO) << " http parse failure: " << htparser_get_strerror(htp_); } - return SHRPX_ERR_HTTP_PARSE; + get_client_handler()->set_should_close_after_write(true); + error_reply(400); } return 0; } @@ -298,7 +304,7 @@ void https_downstream_readcb(bufferevent *bev, void *ptr) if(downstream->get_response_state() == Downstream::HEADER_COMPLETE) { delete upstream->get_client_handler(); } else { - upstream->error_reply(downstream, 502); + upstream->error_reply(502); assert(downstream == upstream->get_top_downstream()); upstream->pop_downstream(); delete downstream; @@ -344,7 +350,7 @@ void https_downstream_eventcb(bufferevent *bev, short events, void *ptr) if(ENABLE_LOG) { LOG(INFO) << " Treated as error"; } - upstream->error_reply(downstream, 502); + upstream->error_reply(502); } upstream->pop_downstream(); delete downstream; @@ -360,7 +366,7 @@ void https_downstream_eventcb(bufferevent *bev, short events, void *ptr) } else { status = 502; } - upstream->error_reply(downstream, status); + upstream->error_reply(status); } upstream->pop_downstream(); delete downstream; @@ -369,20 +375,22 @@ void https_downstream_eventcb(bufferevent *bev, short events, void *ptr) } } // namespace -void HttpsUpstream::error_reply(Downstream *downstream, int status_code) +void HttpsUpstream::error_reply(int status_code) { std::string html = http::create_error_html(status_code); std::stringstream ss; ss << "HTTP/1.1 " << http::get_status_string(status_code) << "\r\n" << "Server: " << get_config()->server_name << "\r\n" << "Content-Length: " << html.size() << "\r\n" - << "Content-Type: " << "text/html; charset=UTF-8\r\n" - << "\r\n"; + << "Content-Type: " << "text/html; charset=UTF-8\r\n"; + if(get_client_handler()->get_should_close_after_write()) { + ss << "Connection: close\r\n"; + } + ss << "\r\n"; std::string header = ss.str(); evbuffer *output = bufferevent_get_output(handler_->get_bev()); evbuffer_add(output, header.c_str(), header.size()); evbuffer_add(output, html.c_str(), html.size()); - downstream->set_response_state(Downstream::MSG_COMPLETE); } bufferevent_data_cb HttpsUpstream::get_downstream_readcb() @@ -454,7 +462,7 @@ int HttpsUpstream::on_downstream_header_complete(Downstream *downstream) hdrs += "\r\n"; } } - if(downstream->get_request_connection_close()) { + if(get_client_handler()->get_should_close_after_write()) { hdrs += "Connection: close\r\n"; } hdrs += "\r\n"; diff --git a/examples/shrpx_https_upstream.h b/examples/shrpx_https_upstream.h index 3fa76f9..ae679be 100644 --- a/examples/shrpx_https_upstream.h +++ b/examples/shrpx_https_upstream.h @@ -58,7 +58,7 @@ public: void pop_downstream(); Downstream* get_top_downstream(); Downstream* get_last_downstream(); - void error_reply(Downstream *downstream, int status_code); + void error_reply(int status_code); void pause_read(IOCtrlReason reason); void resume_read(IOCtrlReason reason); @@ -68,9 +68,11 @@ public: const uint8_t *data, size_t len); virtual int on_downstream_body_complete(Downstream *downstream); + void reset_current_header_length(); private: ClientHandler *handler_; htparser *htp_; + size_t current_header_length_; std::deque downstream_queue_; IOControl ioctrl_; };