From 327df6682e516317c9affce4254ddd6975872fdb Mon Sep 17 00:00:00 2001 From: echel0n Date: Mon, 16 Jun 2014 21:54:00 -0700 Subject: [PATCH] Updated tornado to latest stable code, fixes issues with auto-reload --- SickBeard.py | 2 +- sickbeard/__init__.py | 8 +- sickbeard/webserve.py | 1 - sickbeard/webserveInit.py | 6 +- tornado/__init__.py | 4 +- tornado/auth.py | 94 ++- tornado/autoreload.py | 2 +- tornado/concurrent.py | 129 +++- tornado/curl_httpclient.py | 8 +- tornado/gen.py | 470 +++++++++---- tornado/http1connection.py | 650 ++++++++++++++++++ tornado/httpclient.py | 74 +- tornado/httpserver.py | 561 +++++---------- tornado/httputil.py | 401 ++++++++++- tornado/ioloop.py | 184 +++-- tornado/iostream.py | 640 +++++++++++++---- tornado/log.py | 10 +- tornado/netutil.py | 51 +- tornado/options.py | 18 +- tornado/platform/asyncio.py | 40 +- tornado/platform/auto.py | 4 + tornado/platform/common.py | 3 +- tornado/platform/kqueue.py | 2 +- tornado/platform/select.py | 2 +- tornado/platform/twisted.py | 36 +- tornado/process.py | 14 +- tornado/simple_httpclient.py | 267 ++++--- tornado/stack_context.py | 12 + tornado/tcpclient.py | 179 +++++ tornado/tcpserver.py | 22 +- tornado/template.py | 11 +- tornado/test/__main__.py | 14 + tornado/test/auth_test.py | 27 + tornado/test/concurrent_test.py | 8 - tornado/test/curl_httpclient_test.py | 20 + tornado/test/gen_test.py | 138 +++- .../test/gettext_translations/extract_me.py | 11 + tornado/test/httpclient_test.py | 11 +- tornado/test/httpserver_test.py | 440 ++++++++++-- tornado/test/import_test.py | 1 + tornado/test/ioloop_test.py | 122 +++- tornado/test/iostream_test.py | 385 ++++++++++- tornado/test/log_test.py | 49 ++ tornado/test/netutil_test.py | 26 +- tornado/test/runtests.py | 10 +- tornado/test/simple_httpclient_test.py | 145 +++- tornado/test/stack_context_test.py | 30 +- tornado/test/tcpclient_test.py | 278 ++++++++ tornado/test/template_test.py | 21 + tornado/test/testing_test.py | 63 +- tornado/test/util.py | 17 +- tornado/test/util_test.py | 16 +- tornado/test/web_test.py | 319 ++++++++- tornado/test/websocket_test.py | 133 +++- tornado/test/wsgi_test.py | 25 +- tornado/testing.py | 100 ++- tornado/util.py | 48 +- tornado/web.py | 522 ++++++++------ tornado/websocket.py | 225 ++++-- tornado/wsgi.py | 250 ++++--- 60 files changed, 5723 insertions(+), 1636 deletions(-) create mode 100644 tornado/http1connection.py create mode 100644 tornado/tcpclient.py create mode 100644 tornado/test/__main__.py create mode 100644 tornado/test/gettext_translations/extract_me.py create mode 100644 tornado/test/tcpclient_test.py diff --git a/SickBeard.py b/SickBeard.py index aa0a4c6b..79267797 100755 --- a/SickBeard.py +++ b/SickBeard.py @@ -389,9 +389,9 @@ def main(): io_loop.add_timeout(datetime.timedelta(seconds=5), startup) # autoreload. + tornado.autoreload.add_reload_hook(autoreload_shutdown) if sickbeard.AUTO_UPDATE: tornado.autoreload.start(io_loop) - tornado.autoreload.add_reload_hook(autoreload_shutdown) # start IOLoop. io_loop.start() diff --git a/sickbeard/__init__.py b/sickbeard/__init__.py index f4c76688..98c684d5 100644 --- a/sickbeard/__init__.py +++ b/sickbeard/__init__.py @@ -1824,7 +1824,9 @@ def getEpList(epIDs, showid=None): def autoreload_shutdown(): logger.log('SickRage is now auto-reloading, please stand by ...') - webserveInit.server.stop() + + # halt all tasks halt() - saveAll() - cleanup_tornado_sockets(IOLoop.current()) + + # save settings + saveAll() \ No newline at end of file diff --git a/sickbeard/webserve.py b/sickbeard/webserve.py index c2557ef8..6dc31c1e 100644 --- a/sickbeard/webserve.py +++ b/sickbeard/webserve.py @@ -3304,7 +3304,6 @@ class Home(IndexHandler): # auto-reload tornado.autoreload.start(IOLoop.current()) - tornado.autoreload.add_reload_hook(sickbeard.autoreload_shutdown) updated = sickbeard.versionCheckScheduler.action.update() # @UndefinedVariable diff --git a/sickbeard/webserveInit.py b/sickbeard/webserveInit.py index 6f8e0e12..5472c3a5 100644 --- a/sickbeard/webserveInit.py +++ b/sickbeard/webserveInit.py @@ -139,7 +139,10 @@ def initWebServer(options={}): logger.log(u"Starting SickRage on " + protocol + "://" + str(options['host']) + ":" + str( options['port']) + "/") - server.listen(options['port'], options['host']) + try: + server.listen(options['port'], options['host']) + except: + pass def shutdown(): global server @@ -147,7 +150,6 @@ def shutdown(): logger.log('Shutting down tornado') try: IOLoop.current().stop() - server.stop() except RuntimeError: pass except: diff --git a/tornado/__init__.py b/tornado/__init__.py index 65d07e17..81900d20 100644 --- a/tornado/__init__.py +++ b/tornado/__init__.py @@ -25,5 +25,5 @@ from __future__ import absolute_import, division, print_function, with_statement # is zero for an official release, positive for a development branch, # or negative for a release candidate or beta (after the base version # number has been incremented) -version = "3.2.2" -version_info = (3, 2, 2, 0) +version = "4.0.dev1" +version_info = (4, 0, 0, -100) diff --git a/tornado/auth.py b/tornado/auth.py index 9baac9ba..f8dadb66 100644 --- a/tornado/auth.py +++ b/tornado/auth.py @@ -34,15 +34,29 @@ See the individual service classes below for complete documentation. Example usage for Google OpenID:: - class GoogleLoginHandler(tornado.web.RequestHandler, - tornado.auth.GoogleMixin): + class GoogleOAuth2LoginHandler(tornado.web.RequestHandler, + tornado.auth.GoogleOAuth2Mixin): @tornado.gen.coroutine def get(self): - if self.get_argument("openid.mode", None): - user = yield self.get_authenticated_user() - # Save the user with e.g. set_secure_cookie() + if self.get_argument('code', False): + user = yield self.get_authenticated_user( + redirect_uri='http://your.site.com/auth/google', + code=self.get_argument('code')) + # Save the user with e.g. set_secure_cookie else: - yield self.authenticate_redirect() + yield self.authorize_redirect( + redirect_uri='http://your.site.com/auth/google', + client_id=self.settings['google_oauth']['key'], + scope=['profile', 'email'], + response_type='code', + extra_params={'approval_prompt': 'auto'}) + +.. versionchanged:: 4.0 + All of the callback interfaces in this module are now guaranteed + to run their callback with an argument of ``None`` on error. + Previously some functions would do this while others would simply + terminate the request on their own. This change also ensures that + errors are more consistently reported through the ``Future`` interfaces. """ from __future__ import absolute_import, division, print_function, with_statement @@ -61,6 +75,7 @@ from tornado import httpclient from tornado import escape from tornado.httputil import url_concat from tornado.log import gen_log +from tornado.stack_context import ExceptionStackContext from tornado.util import bytes_type, u, unicode_type, ArgReplacer try: @@ -73,6 +88,11 @@ try: except ImportError: import urllib as urllib_parse # py2 +try: + long # py2 +except NameError: + long = int # py3 + class AuthError(Exception): pass @@ -103,7 +123,14 @@ def _auth_return_future(f): if callback is not None: future.add_done_callback( functools.partial(_auth_future_to_callback, callback)) - f(*args, **kwargs) + def handle_exception(typ, value, tb): + if future.done(): + return False + else: + future.set_exc_info((typ, value, tb)) + return True + with ExceptionStackContext(handle_exception): + f(*args, **kwargs) return future return wrapper @@ -161,7 +188,7 @@ class OpenIdMixin(object): url = self._OPENID_ENDPOINT if http_client is None: http_client = self.get_auth_http_client() - http_client.fetch(url, self.async_callback( + http_client.fetch(url, functools.partial( self._on_authentication_verified, callback), method="POST", body=urllib_parse.urlencode(args)) @@ -333,7 +360,7 @@ class OAuthMixin(object): http_client.fetch( self._oauth_request_token_url(callback_uri=callback_uri, extra_params=extra_params), - self.async_callback( + functools.partial( self._on_request_token, self._OAUTH_AUTHORIZE_URL, callback_uri, @@ -341,7 +368,7 @@ class OAuthMixin(object): else: http_client.fetch( self._oauth_request_token_url(), - self.async_callback( + functools.partial( self._on_request_token, self._OAUTH_AUTHORIZE_URL, callback_uri, callback)) @@ -378,7 +405,7 @@ class OAuthMixin(object): if http_client is None: http_client = self.get_auth_http_client() http_client.fetch(self._oauth_access_token_url(token), - self.async_callback(self._on_access_token, callback)) + functools.partial(self._on_access_token, callback)) def _oauth_request_token_url(self, callback_uri=None, extra_params=None): consumer_token = self._oauth_consumer_token() @@ -455,7 +482,7 @@ class OAuthMixin(object): access_token = _oauth_parse_response(response.body) self._oauth_get_user_future(access_token).add_done_callback( - self.async_callback(self._on_oauth_get_user, access_token, future)) + functools.partial(self._on_oauth_get_user, access_token, future)) def _oauth_consumer_token(self): """Subclasses must override this to return their OAuth consumer keys. @@ -640,7 +667,7 @@ class TwitterMixin(OAuthMixin): """ http = self.get_auth_http_client() http.fetch(self._oauth_request_token_url(callback_uri=callback_uri), - self.async_callback( + functools.partial( self._on_request_token, self._OAUTH_AUTHENTICATE_URL, None, callback)) @@ -698,7 +725,7 @@ class TwitterMixin(OAuthMixin): if args: url += "?" + urllib_parse.urlencode(args) http = self.get_auth_http_client() - http_callback = self.async_callback(self._on_twitter_request, callback) + http_callback = functools.partial(self._on_twitter_request, callback) if post_args is not None: http.fetch(url, method="POST", body=urllib_parse.urlencode(post_args), callback=http_callback) @@ -815,7 +842,7 @@ class FriendFeedMixin(OAuthMixin): args.update(oauth) if args: url += "?" + urllib_parse.urlencode(args) - callback = self.async_callback(self._on_friendfeed_request, callback) + callback = functools.partial(self._on_friendfeed_request, callback) http = self.get_auth_http_client() if post_args is not None: http.fetch(url, method="POST", body=urllib_parse.urlencode(post_args), @@ -856,6 +883,10 @@ class FriendFeedMixin(OAuthMixin): class GoogleMixin(OpenIdMixin, OAuthMixin): """Google Open ID / OAuth authentication. + *Deprecated:* New applications should use `GoogleOAuth2Mixin` + below instead of this class. As of May 19, 2014, Google has stopped + supporting registration-free authentication. + No application registration is necessary to use Google for authentication or to access Google resources on behalf of a user. @@ -926,7 +957,7 @@ class GoogleMixin(OpenIdMixin, OAuthMixin): http = self.get_auth_http_client() token = dict(key=token, secret="") http.fetch(self._oauth_access_token_url(token), - self.async_callback(self._on_access_token, callback)) + functools.partial(self._on_access_token, callback)) else: chain_future(OpenIdMixin.get_authenticated_user(self), callback) @@ -945,6 +976,19 @@ class GoogleMixin(OpenIdMixin, OAuthMixin): class GoogleOAuth2Mixin(OAuth2Mixin): """Google authentication using OAuth2. + In order to use, register your application with Google and copy the + relevant parameters to your application settings. + + * Go to the Google Dev Console at http://console.developers.google.com + * Select a project, or create a new one. + * In the sidebar on the left, select APIs & Auth. + * In the list of APIs, find the Google+ API service and set it to ON. + * In the sidebar on the left, select Credentials. + * In the OAuth section of the page, select Create New Client ID. + * Set the Redirect URI to point to your auth handler + * Copy the "Client secret" and "Client ID" to the application settings as + {"google_oauth": {"key": CLIENT_ID, "secret": CLIENT_SECRET}} + .. versionadded:: 3.2 """ _OAUTH_AUTHORIZE_URL = "https://accounts.google.com/o/oauth2/auth" @@ -958,7 +1002,7 @@ class GoogleOAuth2Mixin(OAuth2Mixin): Example usage:: - class GoogleOAuth2LoginHandler(LoginHandler, + class GoogleOAuth2LoginHandler(tornado.web.RequestHandler, tornado.auth.GoogleOAuth2Mixin): @tornado.gen.coroutine def get(self): @@ -985,7 +1029,7 @@ class GoogleOAuth2Mixin(OAuth2Mixin): }) http.fetch(self._OAUTH_ACCESS_TOKEN_URL, - self.async_callback(self._on_access_token, callback), + functools.partial(self._on_access_token, callback), method="POST", headers={'Content-Type': 'application/x-www-form-urlencoded'}, body=body) def _on_access_token(self, future, response): @@ -1026,7 +1070,7 @@ class FacebookMixin(object): @tornado.web.asynchronous def get(self): if self.get_argument("session", None): - self.get_authenticated_user(self.async_callback(self._on_auth)) + self.get_authenticated_user(self._on_auth) return yield self.authenticate_redirect() @@ -1112,7 +1156,7 @@ class FacebookMixin(object): session = escape.json_decode(self.get_argument("session")) self.facebook_request( method="facebook.users.getInfo", - callback=self.async_callback( + callback=functools.partial( self._on_get_user_info, callback, session), session_key=session["session_key"], uids=session["uid"], @@ -1138,7 +1182,7 @@ class FacebookMixin(object): def get(self): self.facebook_request( method="stream.get", - callback=self.async_callback(self._on_stream), + callback=self._on_stream, session_key=self.current_user["session_key"]) def _on_stream(self, stream): @@ -1162,7 +1206,7 @@ class FacebookMixin(object): url = "http://api.facebook.com/restserver.php?" + \ urllib_parse.urlencode(args) http = self.get_auth_http_client() - http.fetch(url, callback=self.async_callback( + http.fetch(url, callback=functools.partial( self._parse_response, callback)) def _on_get_user_info(self, callback, session, users): @@ -1260,7 +1304,7 @@ class FacebookGraphMixin(OAuth2Mixin): fields.update(extra_fields) http.fetch(self._oauth_request_token_url(**args), - self.async_callback(self._on_access_token, redirect_uri, client_id, + functools.partial(self._on_access_token, redirect_uri, client_id, client_secret, callback, fields)) def _on_access_token(self, redirect_uri, client_id, client_secret, @@ -1277,7 +1321,7 @@ class FacebookGraphMixin(OAuth2Mixin): self.facebook_request( path="/me", - callback=self.async_callback( + callback=functools.partial( self._on_get_user_info, future, session, fields), access_token=session["access_token"], fields=",".join(fields) @@ -1344,7 +1388,7 @@ class FacebookGraphMixin(OAuth2Mixin): if all_args: url += "?" + urllib_parse.urlencode(all_args) - callback = self.async_callback(self._on_facebook_request, callback) + callback = functools.partial(self._on_facebook_request, callback) http = self.get_auth_http_client() if post_args is not None: http.fetch(url, method="POST", body=urllib_parse.urlencode(post_args), diff --git a/tornado/autoreload.py b/tornado/autoreload.py index 79cccb49..3982579a 100644 --- a/tornado/autoreload.py +++ b/tornado/autoreload.py @@ -14,7 +14,7 @@ # License for the specific language governing permissions and limitations # under the License. -"""xAutomatically restart the server when a source file is modified. +"""Automatically restart the server when a source file is modified. Most applications should not access this module directly. Instead, pass the keyword argument ``autoreload=True`` to the diff --git a/tornado/concurrent.py b/tornado/concurrent.py index a9002b16..702aa352 100644 --- a/tornado/concurrent.py +++ b/tornado/concurrent.py @@ -40,52 +40,132 @@ class ReturnValueIgnoredError(Exception): pass -class _DummyFuture(object): +class Future(object): + """Placeholder for an asynchronous result. + + A ``Future`` encapsulates the result of an asynchronous + operation. In synchronous applications ``Futures`` are used + to wait for the result from a thread or process pool; in + Tornado they are normally used with `.IOLoop.add_future` or by + yielding them in a `.gen.coroutine`. + + `tornado.concurrent.Future` is similar to + `concurrent.futures.Future`, but not thread-safe (and therefore + faster for use with single-threaded event loops). + + In addition to ``exception`` and ``set_exception``, methods ``exc_info`` + and ``set_exc_info`` are supported to capture tracebacks in Python 2. + The traceback is automatically available in Python 3, but in the + Python 2 futures backport this information is discarded. + This functionality was previously available in a separate class + ``TracebackFuture``, which is now a deprecated alias for this class. + + .. versionchanged:: 4.0 + `tornado.concurrent.Future` is always a thread-unsafe ``Future`` + with support for the ``exc_info`` methods. Previously it would + be an alias for the thread-safe `concurrent.futures.Future` + if that package was available and fall back to the thread-unsafe + implementation if it was not. + + """ def __init__(self): self._done = False self._result = None self._exception = None + self._exc_info = None self._callbacks = [] def cancel(self): + """Cancel the operation, if possible. + + Tornado ``Futures`` do not support cancellation, so this method always + returns False. + """ return False def cancelled(self): + """Returns True if the operation has been cancelled. + + Tornado ``Futures`` do not support cancellation, so this method + always returns False. + """ return False def running(self): + """Returns True if this operation is currently running.""" return not self._done def done(self): + """Returns True if the future has finished running.""" return self._done def result(self, timeout=None): - self._check_done() - if self._exception: + """If the operation succeeded, return its result. If it failed, + re-raise its exception. + """ + if self._result is not None: + return self._result + if self._exc_info is not None: + raise_exc_info(self._exc_info) + elif self._exception is not None: raise self._exception + self._check_done() return self._result def exception(self, timeout=None): - self._check_done() - if self._exception: + """If the operation raised an exception, return the `Exception` + object. Otherwise returns None. + """ + if self._exception is not None: return self._exception else: + self._check_done() return None def add_done_callback(self, fn): + """Attaches the given callback to the `Future`. + + It will be invoked with the `Future` as its argument when the Future + has finished running and its result is available. In Tornado + consider using `.IOLoop.add_future` instead of calling + `add_done_callback` directly. + """ if self._done: fn(self) else: self._callbacks.append(fn) def set_result(self, result): + """Sets the result of a ``Future``. + + It is undefined to call any of the ``set`` methods more than once + on the same object. + """ self._result = result self._set_done() def set_exception(self, exception): + """Sets the exception of a ``Future.``""" self._exception = exception self._set_done() + def exc_info(self): + """Returns a tuple in the same format as `sys.exc_info` or None. + + .. versionadded:: 4.0 + """ + return self._exc_info + + def set_exc_info(self, exc_info): + """Sets the exception information of a ``Future.`` + + Preserves tracebacks on Python 2. + + .. versionadded:: 4.0 + """ + self._exc_info = exc_info + self.set_exception(exc_info[1]) + def _check_done(self): if not self._done: raise Exception("DummyFuture does not support blocking for results") @@ -97,38 +177,16 @@ class _DummyFuture(object): cb(self) self._callbacks = None +TracebackFuture = Future + if futures is None: - Future = _DummyFuture + FUTURES = Future else: - Future = futures.Future + FUTURES = (futures.Future, Future) -class TracebackFuture(Future): - """Subclass of `Future` which can store a traceback with - exceptions. - - The traceback is automatically available in Python 3, but in the - Python 2 futures backport this information is discarded. - """ - def __init__(self): - super(TracebackFuture, self).__init__() - self.__exc_info = None - - def exc_info(self): - return self.__exc_info - - def set_exc_info(self, exc_info): - """Traceback-aware replacement for - `~concurrent.futures.Future.set_exception`. - """ - self.__exc_info = exc_info - self.set_exception(exc_info[1]) - - def result(self, timeout=None): - if self.__exc_info is not None: - raise_exc_info(self.__exc_info) - else: - return super(TracebackFuture, self).result(timeout=timeout) +def is_future(x): + return isinstance(x, FUTURES) class DummyExecutor(object): @@ -254,10 +312,13 @@ def return_future(f): def chain_future(a, b): """Chain two futures together so that when one completes, so does the other. - The result (success or failure) of ``a`` will be copied to ``b``. + The result (success or failure) of ``a`` will be copied to ``b``, unless + ``b`` has already been completed or cancelled by the time ``a`` finishes. """ def copy(future): assert future is a + if b.done(): + return if (isinstance(a, TracebackFuture) and isinstance(b, TracebackFuture) and a.exc_info() is not None): b.set_exc_info(a.exc_info()) diff --git a/tornado/curl_httpclient.py b/tornado/curl_httpclient.py index 0df7a7ee..c190ac91 100644 --- a/tornado/curl_httpclient.py +++ b/tornado/curl_httpclient.py @@ -87,7 +87,6 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): for curl in self._curls: curl.close() self._multi.close() - self._closed = True super(CurlAsyncHTTPClient, self).close() def fetch_impl(self, request, callback): @@ -268,6 +267,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): info["callback"](HTTPResponse( request=info["request"], code=code, headers=info["headers"], buffer=buffer, effective_url=effective_url, error=error, + reason=info['headers'].get("X-Http-Reason", None), request_time=time.time() - info["curl_start_time"], time_info=time_info)) except Exception: @@ -470,7 +470,11 @@ def _curl_header_callback(headers, header_line): header_line = header_line.strip() if header_line.startswith("HTTP/"): headers.clear() - return + try: + (__, __, reason) = httputil.parse_response_start_line(header_line) + header_line = "X-Http-Reason: %s" % reason + except httputil.HTTPInputError: + return if not header_line: return headers.parse_line(header_line) diff --git a/tornado/gen.py b/tornado/gen.py index aa931b45..9548b5f5 100644 --- a/tornado/gen.py +++ b/tornado/gen.py @@ -87,9 +87,9 @@ import itertools import sys import types -from tornado.concurrent import Future, TracebackFuture +from tornado.concurrent import Future, TracebackFuture, is_future, chain_future from tornado.ioloop import IOLoop -from tornado.stack_context import ExceptionStackContext, wrap +from tornado import stack_context class KeyReuseError(Exception): @@ -112,6 +112,10 @@ class ReturnValueIgnoredError(Exception): pass +class TimeoutError(Exception): + """Exception raised by ``with_timeout``.""" + + def engine(func): """Callback-oriented decorator for asynchronous generators. @@ -129,45 +133,20 @@ def engine(func): `~tornado.web.RequestHandler` :ref:`HTTP verb methods `, which use ``self.finish()`` in place of a callback argument. """ + func = _make_coroutine_wrapper(func, replace_callback=False) @functools.wraps(func) def wrapper(*args, **kwargs): - runner = None - - def handle_exception(typ, value, tb): - # if the function throws an exception before its first "yield" - # (or is not a generator at all), the Runner won't exist yet. - # However, in that case we haven't reached anything asynchronous - # yet, so we can just let the exception propagate. - if runner is not None: - return runner.handle_exception(typ, value, tb) - return False - with ExceptionStackContext(handle_exception) as deactivate: - try: - result = func(*args, **kwargs) - except (Return, StopIteration) as e: - result = getattr(e, 'value', None) - else: - if isinstance(result, types.GeneratorType): - def final_callback(value): - if value is not None: - raise ReturnValueIgnoredError( - "@gen.engine functions cannot return values: " - "%r" % (value,)) - assert value is None - deactivate() - runner = Runner(result, final_callback) - runner.run() - return - if result is not None: + future = func(*args, **kwargs) + def final_callback(future): + if future.result() is not None: raise ReturnValueIgnoredError( "@gen.engine functions cannot return values: %r" % - (result,)) - deactivate() - # no yield, so we're done + (future.result(),)) + future.add_done_callback(final_callback) return wrapper -def coroutine(func): +def coroutine(func, replace_callback=True): """Decorator for asynchronous generators. Any generator that yields objects from this module must be wrapped @@ -191,43 +170,56 @@ def coroutine(func): From the caller's perspective, ``@gen.coroutine`` is similar to the combination of ``@return_future`` and ``@gen.engine``. """ + return _make_coroutine_wrapper(func, replace_callback=True) + + +def _make_coroutine_wrapper(func, replace_callback): + """The inner workings of ``@gen.coroutine`` and ``@gen.engine``. + + The two decorators differ in their treatment of the ``callback`` + argument, so we cannot simply implement ``@engine`` in terms of + ``@coroutine``. + """ @functools.wraps(func) def wrapper(*args, **kwargs): - runner = None future = TracebackFuture() - if 'callback' in kwargs: + if replace_callback and 'callback' in kwargs: callback = kwargs.pop('callback') IOLoop.current().add_future( future, lambda future: callback(future.result())) - def handle_exception(typ, value, tb): - try: - if runner is not None and runner.handle_exception(typ, value, tb): - return True - except Exception: - typ, value, tb = sys.exc_info() - future.set_exc_info((typ, value, tb)) - return True - with ExceptionStackContext(handle_exception) as deactivate: - try: - result = func(*args, **kwargs) - except (Return, StopIteration) as e: - result = getattr(e, 'value', None) - except Exception: - deactivate() - future.set_exc_info(sys.exc_info()) + try: + result = func(*args, **kwargs) + except (Return, StopIteration) as e: + result = getattr(e, 'value', None) + except Exception: + future.set_exc_info(sys.exc_info()) + return future + else: + if isinstance(result, types.GeneratorType): + # Inline the first iteration of Runner.run. This lets us + # avoid the cost of creating a Runner when the coroutine + # never actually yields, which in turn allows us to + # use "optional" coroutines in critical path code without + # performance penalty for the synchronous case. + try: + orig_stack_contexts = stack_context._state.contexts + yielded = next(result) + if stack_context._state.contexts is not orig_stack_contexts: + yielded = TracebackFuture() + yielded.set_exception( + stack_context.StackContextInconsistentError( + 'stack_context inconsistency (probably caused ' + 'by yield within a "with StackContext" block)')) + except (StopIteration, Return) as e: + future.set_result(getattr(e, 'value', None)) + except Exception: + future.set_exc_info(sys.exc_info()) + else: + Runner(result, future, yielded) return future - else: - if isinstance(result, types.GeneratorType): - def final_callback(value): - deactivate() - future.set_result(value) - runner = Runner(result, final_callback) - runner.run() - return future - deactivate() - future.set_result(result) + future.set_result(result) return future return wrapper @@ -348,7 +340,7 @@ class WaitAll(YieldPoint): return [self.runner.pop_result(key) for key in self.keys] -class Task(YieldPoint): +def Task(func, *args, **kwargs): """Runs a single asynchronous operation. Takes a function (and optional additional arguments) and runs it with @@ -362,25 +354,25 @@ class Task(YieldPoint): func(args, callback=(yield gen.Callback(key))) result = yield gen.Wait(key) + + .. versionchanged:: 4.0 + ``gen.Task`` is now a function that returns a `.Future`, instead of + a subclass of `YieldPoint`. It still behaves the same way when + yielded. """ - def __init__(self, func, *args, **kwargs): - assert "callback" not in kwargs - self.args = args - self.kwargs = kwargs - self.func = func - - def start(self, runner): - self.runner = runner - self.key = object() - runner.register_callback(self.key) - self.kwargs["callback"] = runner.result_callback(self.key) - self.func(*self.args, **self.kwargs) - - def is_ready(self): - return self.runner.is_ready(self.key) - - def get_result(self): - return self.runner.pop_result(self.key) + future = Future() + def handle_exception(typ, value, tb): + if future.done(): + return False + future.set_exc_info((typ, value, tb)) + return True + def set_result(result): + if future.done(): + return + future.set_result(result) + with stack_context.ExceptionStackContext(handle_exception): + func(*args, callback=_argument_adapter(set_result), **kwargs) + return future class YieldFuture(YieldPoint): @@ -414,10 +406,14 @@ class YieldFuture(YieldPoint): class Multi(YieldPoint): """Runs multiple asynchronous operations in parallel. - Takes a list of ``Tasks`` or other ``YieldPoints`` and returns a list of + Takes a list of ``YieldPoints`` or ``Futures`` and returns a list of their responses. It is not necessary to call `Multi` explicitly, since the engine will do so automatically when the generator yields - a list of ``YieldPoints``. + a list of ``YieldPoints`` or a mixture of ``YieldPoints`` and ``Futures``. + + Instead of a list, the argument may also be a dictionary whose values are + Futures, in which case a parallel dictionary is returned mapping the same + keys to their results. """ def __init__(self, children): self.keys = None @@ -426,7 +422,7 @@ class Multi(YieldPoint): children = children.values() self.children = [] for i in children: - if isinstance(i, Future): + if is_future(i): i = YieldFuture(i) self.children.append(i) assert all(isinstance(i, YieldPoint) for i in self.children) @@ -450,18 +446,127 @@ class Multi(YieldPoint): return list(result) -class _NullYieldPoint(YieldPoint): - def start(self, runner): - pass +def multi_future(children): + """Wait for multiple asynchronous futures in parallel. - def is_ready(self): - return True + Takes a list of ``Futures`` (but *not* other ``YieldPoints``) and returns + a new Future that resolves when all the other Futures are done. + If all the ``Futures`` succeeded, the returned Future's result is a list + of their results. If any failed, the returned Future raises the exception + of the first one to fail. - def get_result(self): - return None + Instead of a list, the argument may also be a dictionary whose values are + Futures, in which case a parallel dictionary is returned mapping the same + keys to their results. + + It is not necessary to call `multi_future` explcitly, since the engine will + do so automatically when the generator yields a list of `Futures`. + This function is faster than the `Multi` `YieldPoint` because it does not + require the creation of a stack context. + + .. versionadded:: 4.0 + """ + if isinstance(children, dict): + keys = list(children.keys()) + children = children.values() + else: + keys = None + assert all(is_future(i) for i in children) + unfinished_children = set(children) + + future = Future() + if not children: + future.set_result({} if keys is not None else []) + def callback(f): + unfinished_children.remove(f) + if not unfinished_children: + try: + result_list = [i.result() for i in children] + except Exception: + future.set_exc_info(sys.exc_info()) + else: + if keys is not None: + future.set_result(dict(zip(keys, result_list))) + else: + future.set_result(result_list) + for f in children: + f.add_done_callback(callback) + return future -_null_yield_point = _NullYieldPoint() +def maybe_future(x): + """Converts ``x`` into a `.Future`. + + If ``x`` is already a `.Future`, it is simply returned; otherwise + it is wrapped in a new `.Future`. This is suitable for use as + ``result = yield gen.maybe_future(f())`` when you don't know whether + ``f()`` returns a `.Future` or not. + """ + if is_future(x): + return x + else: + fut = Future() + fut.set_result(x) + return fut + + +def with_timeout(timeout, future, io_loop=None): + """Wraps a `.Future` in a timeout. + + Raises `TimeoutError` if the input future does not complete before + ``timeout``, which may be specified in any form allowed by + `.IOLoop.add_timeout` (i.e. a `datetime.timedelta` or an absolute time + relative to `.IOLoop.time`) + + Currently only supports Futures, not other `YieldPoint` classes. + + .. versionadded:: 4.0 + """ + # TODO: allow yield points in addition to futures? + # Tricky to do with stack_context semantics. + # + # It's tempting to optimize this by cancelling the input future on timeout + # instead of creating a new one, but A) we can't know if we are the only + # one waiting on the input future, so cancelling it might disrupt other + # callers and B) concurrent futures can only be cancelled while they are + # in the queue, so cancellation cannot reliably bound our waiting time. + result = Future() + chain_future(future, result) + if io_loop is None: + io_loop = IOLoop.current() + timeout_handle = io_loop.add_timeout( + timeout, + lambda: result.set_exception(TimeoutError("Timeout"))) + if isinstance(future, Future): + # We know this future will resolve on the IOLoop, so we don't + # need the extra thread-safety of IOLoop.add_future (and we also + # don't care about StackContext here. + future.add_done_callback( + lambda future: io_loop.remove_timeout(timeout_handle)) + else: + # concurrent.futures.Futures may resolve on any thread, so we + # need to route them back to the IOLoop. + io_loop.add_future( + future, lambda future: io_loop.remove_timeout(timeout_handle)) + return result + + +_null_future = Future() +_null_future.set_result(None) + +moment = Future() +moment.__doc__ = \ + """A special object which may be yielded to allow the IOLoop to run for +one iteration. + +This is not needed in normal use but it can be helpful in long-running +coroutines that are likely to yield Futures that are ready instantly. + +Usage: ``yield gen.moment`` + +.. versionadded:: 4.0 +""" +moment.set_result(None) class Runner(object): @@ -469,35 +574,55 @@ class Runner(object): Maintains information about pending callbacks and their results. - ``final_callback`` is run after the generator exits. + The results of the generator are stored in ``result_future`` (a + `.TracebackFuture`) """ - def __init__(self, gen, final_callback): + def __init__(self, gen, result_future, first_yielded): self.gen = gen - self.final_callback = final_callback - self.yield_point = _null_yield_point - self.pending_callbacks = set() - self.results = {} + self.result_future = result_future + self.future = _null_future + self.yield_point = None + self.pending_callbacks = None + self.results = None self.running = False self.finished = False - self.exc_info = None self.had_exception = False + self.io_loop = IOLoop.current() + # For efficiency, we do not create a stack context until we + # reach a YieldPoint (stack contexts are required for the historical + # semantics of YieldPoints, but not for Futures). When we have + # done so, this field will be set and must be called at the end + # of the coroutine. + self.stack_context_deactivate = None + if self.handle_yield(first_yielded): + self.run() def register_callback(self, key): """Adds ``key`` to the list of callbacks.""" + if self.pending_callbacks is None: + # Lazily initialize the old-style YieldPoint data structures. + self.pending_callbacks = set() + self.results = {} if key in self.pending_callbacks: raise KeyReuseError("key %r is already pending" % (key,)) self.pending_callbacks.add(key) def is_ready(self, key): """Returns true if a result is available for ``key``.""" - if key not in self.pending_callbacks: + if self.pending_callbacks is None or key not in self.pending_callbacks: raise UnknownKeyError("key %r is not pending" % (key,)) return key in self.results def set_result(self, key, result): """Sets the result for ``key`` and attempts to resume the generator.""" self.results[key] = result - self.run() + if self.yield_point is not None and self.yield_point.is_ready(): + try: + self.future.set_result(self.yield_point.get_result()) + except: + self.future.set_exc_info(sys.exc_info()) + self.yield_point = None + self.run() def pop_result(self, key): """Returns the result for ``key`` and unregisters it.""" @@ -513,25 +638,27 @@ class Runner(object): try: self.running = True while True: - if self.exc_info is None: - try: - if not self.yield_point.is_ready(): - return - next = self.yield_point.get_result() - self.yield_point = None - except Exception: - self.exc_info = sys.exc_info() + future = self.future + if not future.done(): + return + self.future = None try: - if self.exc_info is not None: + orig_stack_contexts = stack_context._state.contexts + try: + value = future.result() + except Exception: self.had_exception = True - exc_info = self.exc_info - self.exc_info = None - yielded = self.gen.throw(*exc_info) + yielded = self.gen.throw(*sys.exc_info()) else: - yielded = self.gen.send(next) + yielded = self.gen.send(value) + if stack_context._state.contexts is not orig_stack_contexts: + self.gen.throw( + stack_context.StackContextInconsistentError( + 'stack_context inconsistency (probably caused ' + 'by yield within a "with StackContext" block)')) except (StopIteration, Return) as e: self.finished = True - self.yield_point = _null_yield_point + self.future = _null_future if self.pending_callbacks and not self.had_exception: # If we ran cleanly without waiting on all callbacks # raise an error (really more of a warning). If we @@ -540,46 +667,105 @@ class Runner(object): raise LeakedCallbackError( "finished without waiting for callbacks %r" % self.pending_callbacks) - self.final_callback(getattr(e, 'value', None)) - self.final_callback = None + self.result_future.set_result(getattr(e, 'value', None)) + self.result_future = None + self._deactivate_stack_context() return except Exception: self.finished = True - self.yield_point = _null_yield_point - raise - if isinstance(yielded, (list, dict)): - yielded = Multi(yielded) - elif isinstance(yielded, Future): - yielded = YieldFuture(yielded) - if isinstance(yielded, YieldPoint): - self.yield_point = yielded - try: - self.yield_point.start(self) - except Exception: - self.exc_info = sys.exc_info() - else: - self.exc_info = (BadYieldError( - "yielded unknown object %r" % (yielded,)),) + self.future = _null_future + self.result_future.set_exc_info(sys.exc_info()) + self.result_future = None + self._deactivate_stack_context() + return + if not self.handle_yield(yielded): + return finally: self.running = False - def result_callback(self, key): - def inner(*args, **kwargs): - if kwargs or len(args) > 1: - result = Arguments(args, kwargs) - elif args: - result = args[0] + def handle_yield(self, yielded): + if isinstance(yielded, list): + if all(is_future(f) for f in yielded): + yielded = multi_future(yielded) else: - result = None - self.set_result(key, result) - return wrap(inner) + yielded = Multi(yielded) + elif isinstance(yielded, dict): + if all(is_future(f) for f in yielded.values()): + yielded = multi_future(yielded) + else: + yielded = Multi(yielded) + + if isinstance(yielded, YieldPoint): + self.future = TracebackFuture() + def start_yield_point(): + try: + yielded.start(self) + if yielded.is_ready(): + self.future.set_result( + yielded.get_result()) + else: + self.yield_point = yielded + except Exception: + self.future = TracebackFuture() + self.future.set_exc_info(sys.exc_info()) + if self.stack_context_deactivate is None: + # Start a stack context if this is the first + # YieldPoint we've seen. + with stack_context.ExceptionStackContext( + self.handle_exception) as deactivate: + self.stack_context_deactivate = deactivate + def cb(): + start_yield_point() + self.run() + self.io_loop.add_callback(cb) + return False + else: + start_yield_point() + elif is_future(yielded): + self.future = yielded + if not self.future.done() or self.future is moment: + self.io_loop.add_future( + self.future, lambda f: self.run()) + return False + else: + self.future = TracebackFuture() + self.future.set_exception(BadYieldError( + "yielded unknown object %r" % (yielded,))) + return True + + def result_callback(self, key): + return stack_context.wrap(_argument_adapter( + functools.partial(self.set_result, key))) def handle_exception(self, typ, value, tb): if not self.running and not self.finished: - self.exc_info = (typ, value, tb) + self.future = TracebackFuture() + self.future.set_exc_info((typ, value, tb)) self.run() return True else: return False + def _deactivate_stack_context(self): + if self.stack_context_deactivate is not None: + self.stack_context_deactivate() + self.stack_context_deactivate = None + Arguments = collections.namedtuple('Arguments', ['args', 'kwargs']) + + +def _argument_adapter(callback): + """Returns a function that when invoked runs ``callback`` with one arg. + + If the function returned by this function is called with exactly + one argument, that argument is passed to ``callback``. Otherwise + the args tuple and kwargs dict are wrapped in an `Arguments` object. + """ + def wrapper(*args, **kwargs): + if kwargs or len(args) > 1: + callback(Arguments(args, kwargs)) + elif args: + callback(args[0]) + else: + callback(None) + return wrapper diff --git a/tornado/http1connection.py b/tornado/http1connection.py new file mode 100644 index 00000000..c43675a1 --- /dev/null +++ b/tornado/http1connection.py @@ -0,0 +1,650 @@ +#!/usr/bin/env python +# +# Copyright 2014 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Client and server implementations of HTTP/1.x. + +.. versionadded:: 4.0 +""" + +from __future__ import absolute_import, division, print_function, with_statement + +from tornado.concurrent import Future +from tornado.escape import native_str, utf8 +from tornado import gen +from tornado import httputil +from tornado import iostream +from tornado.log import gen_log, app_log +from tornado import stack_context +from tornado.util import GzipDecompressor + + +class _QuietException(Exception): + def __init__(self): + pass + +class _ExceptionLoggingContext(object): + """Used with the ``with`` statement when calling delegate methods to + log any exceptions with the given logger. Any exceptions caught are + converted to _QuietException + """ + def __init__(self, logger): + self.logger = logger + + def __enter__(self): + pass + + def __exit__(self, typ, value, tb): + if value is not None: + self.logger.error("Uncaught exception", exc_info=(typ, value, tb)) + raise _QuietException + +class HTTP1ConnectionParameters(object): + """Parameters for `.HTTP1Connection` and `.HTTP1ServerConnection`. + """ + def __init__(self, no_keep_alive=False, chunk_size=None, + max_header_size=None, header_timeout=None, max_body_size=None, + body_timeout=None, use_gzip=False): + """ + :arg bool no_keep_alive: If true, always close the connection after + one request. + :arg int chunk_size: how much data to read into memory at once + :arg int max_header_size: maximum amount of data for HTTP headers + :arg float header_timeout: how long to wait for all headers (seconds) + :arg int max_body_size: maximum amount of data for body + :arg float body_timeout: how long to wait while reading body (seconds) + :arg bool use_gzip: if true, decode incoming ``Content-Encoding: gzip`` + """ + self.no_keep_alive = no_keep_alive + self.chunk_size = chunk_size or 65536 + self.max_header_size = max_header_size or 65536 + self.header_timeout = header_timeout + self.max_body_size = max_body_size + self.body_timeout = body_timeout + self.use_gzip = use_gzip + + +class HTTP1Connection(httputil.HTTPConnection): + """Implements the HTTP/1.x protocol. + + This class can be on its own for clients, or via `HTTP1ServerConnection` + for servers. + """ + def __init__(self, stream, is_client, params=None, context=None): + """ + :arg stream: an `.IOStream` + :arg bool is_client: client or server + :arg params: a `.HTTP1ConnectionParameters` instance or ``None`` + :arg context: an opaque application-defined object that can be accessed + as ``connection.context``. + """ + self.is_client = is_client + self.stream = stream + if params is None: + params = HTTP1ConnectionParameters() + self.params = params + self.context = context + self.no_keep_alive = params.no_keep_alive + # The body limits can be altered by the delegate, so save them + # here instead of just referencing self.params later. + self._max_body_size = (self.params.max_body_size or + self.stream.max_buffer_size) + self._body_timeout = self.params.body_timeout + # _write_finished is set to True when finish() has been called, + # i.e. there will be no more data sent. Data may still be in the + # stream's write buffer. + self._write_finished = False + # True when we have read the entire incoming body. + self._read_finished = False + # _finish_future resolves when all data has been written and flushed + # to the IOStream. + self._finish_future = Future() + # If true, the connection should be closed after this request + # (after the response has been written in the server side, + # and after it has been read in the client) + self._disconnect_on_finish = False + self._clear_callbacks() + # Save the start lines after we read or write them; they + # affect later processing (e.g. 304 responses and HEAD methods + # have content-length but no bodies) + self._request_start_line = None + self._response_start_line = None + self._request_headers = None + # True if we are writing output with chunked encoding. + self._chunking_output = None + # While reading a body with a content-length, this is the + # amount left to read. + self._expected_content_remaining = None + # A Future for our outgoing writes, returned by IOStream.write. + self._pending_write = None + + def read_response(self, delegate): + """Read a single HTTP response. + + Typical client-mode usage is to write a request using `write_headers`, + `write`, and `finish`, and then call ``read_response``. + + :arg delegate: a `.HTTPMessageDelegate` + + Returns a `.Future` that resolves to None after the full response has + been read. + """ + if self.params.use_gzip: + delegate = _GzipMessageDelegate(delegate, self.params.chunk_size) + return self._read_message(delegate) + + @gen.coroutine + def _read_message(self, delegate): + need_delegate_close = False + try: + header_future = self.stream.read_until_regex( + b"\r?\n\r?\n", + max_bytes=self.params.max_header_size) + if self.params.header_timeout is None: + header_data = yield header_future + else: + try: + header_data = yield gen.with_timeout( + self.stream.io_loop.time() + self.params.header_timeout, + header_future, + io_loop=self.stream.io_loop) + except gen.TimeoutError: + self.close() + raise gen.Return(False) + start_line, headers = self._parse_headers(header_data) + if self.is_client: + start_line = httputil.parse_response_start_line(start_line) + self._response_start_line = start_line + else: + start_line = httputil.parse_request_start_line(start_line) + self._request_start_line = start_line + self._request_headers = headers + + self._disconnect_on_finish = not self._can_keep_alive( + start_line, headers) + need_delegate_close = True + with _ExceptionLoggingContext(app_log): + header_future = delegate.headers_received(start_line, headers) + if header_future is not None: + yield header_future + if self.stream is None: + # We've been detached. + need_delegate_close = False + raise gen.Return(False) + skip_body = False + if self.is_client: + if (self._request_start_line is not None and + self._request_start_line.method == 'HEAD'): + skip_body = True + code = start_line.code + if code == 304: + skip_body = True + if code >= 100 and code < 200: + # TODO: client delegates will get headers_received twice + # in the case of a 100-continue. Document or change? + yield self._read_message(delegate) + else: + if (headers.get("Expect") == "100-continue" and + not self._write_finished): + self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n") + if not skip_body: + body_future = self._read_body(headers, delegate) + if body_future is not None: + if self._body_timeout is None: + yield body_future + else: + try: + yield gen.with_timeout( + self.stream.io_loop.time() + self._body_timeout, + body_future, self.stream.io_loop) + except gen.TimeoutError: + gen_log.info("Timeout reading body from %s", + self.context) + self.stream.close() + raise gen.Return(False) + self._read_finished = True + if not self._write_finished or self.is_client: + need_delegate_close = False + with _ExceptionLoggingContext(app_log): + delegate.finish() + # If we're waiting for the application to produce an asynchronous + # response, and we're not detached, register a close callback + # on the stream (we didn't need one while we were reading) + if (not self._finish_future.done() and + self.stream is not None and + not self.stream.closed()): + self.stream.set_close_callback(self._on_connection_close) + yield self._finish_future + if self.is_client and self._disconnect_on_finish: + self.close() + if self.stream is None: + raise gen.Return(False) + except httputil.HTTPInputError as e: + gen_log.info("Malformed HTTP message from %s: %s", + self.context, e) + self.close() + raise gen.Return(False) + finally: + if need_delegate_close: + with _ExceptionLoggingContext(app_log): + delegate.on_connection_close() + self._clear_callbacks() + raise gen.Return(True) + + def _clear_callbacks(self): + """Clears the callback attributes. + + This allows the request handler to be garbage collected more + quickly in CPython by breaking up reference cycles. + """ + self._write_callback = None + self._write_future = None + self._close_callback = None + if self.stream is not None: + self.stream.set_close_callback(None) + + def set_close_callback(self, callback): + """Sets a callback that will be run when the connection is closed. + + .. deprecated:: 4.0 + Use `.HTTPMessageDelegate.on_connection_close` instead. + """ + self._close_callback = stack_context.wrap(callback) + + def _on_connection_close(self): + # Note that this callback is only registered on the IOStream + # when we have finished reading the request and are waiting for + # the application to produce its response. + if self._close_callback is not None: + callback = self._close_callback + self._close_callback = None + callback() + if not self._finish_future.done(): + self._finish_future.set_result(None) + self._clear_callbacks() + + def close(self): + if self.stream is not None: + self.stream.close() + self._clear_callbacks() + if not self._finish_future.done(): + self._finish_future.set_result(None) + + def detach(self): + """Take control of the underlying stream. + + Returns the underlying `.IOStream` object and stops all further + HTTP processing. May only be called during + `.HTTPMessageDelegate.headers_received`. Intended for implementing + protocols like websockets that tunnel over an HTTP handshake. + """ + self._clear_callbacks() + stream = self.stream + self.stream = None + return stream + + def set_body_timeout(self, timeout): + """Sets the body timeout for a single request. + + Overrides the value from `.HTTP1ConnectionParameters`. + """ + self._body_timeout = timeout + + def set_max_body_size(self, max_body_size): + """Sets the body size limit for a single request. + + Overrides the value from `.HTTP1ConnectionParameters`. + """ + self._max_body_size = max_body_size + + def write_headers(self, start_line, headers, chunk=None, callback=None): + """Implements `.HTTPConnection.write_headers`.""" + if self.is_client: + self._request_start_line = start_line + # Client requests with a non-empty body must have either a + # Content-Length or a Transfer-Encoding. + self._chunking_output = ( + start_line.method in ('POST', 'PUT', 'PATCH') and + 'Content-Length' not in headers and + 'Transfer-Encoding' not in headers) + else: + self._response_start_line = start_line + self._chunking_output = ( + # TODO: should this use + # self._request_start_line.version or + # start_line.version? + self._request_start_line.version == 'HTTP/1.1' and + # 304 responses have no body (not even a zero-length body), and so + # should not have either Content-Length or Transfer-Encoding. + # headers. + start_line.code != 304 and + # No need to chunk the output if a Content-Length is specified. + 'Content-Length' not in headers and + # Applications are discouraged from touching Transfer-Encoding, + # but if they do, leave it alone. + 'Transfer-Encoding' not in headers) + # If a 1.0 client asked for keep-alive, add the header. + if (self._request_start_line.version == 'HTTP/1.0' and + (self._request_headers.get('Connection', '').lower() + == 'keep-alive')): + headers['Connection'] = 'Keep-Alive' + if self._chunking_output: + headers['Transfer-Encoding'] = 'chunked' + if (not self.is_client and + (self._request_start_line.method == 'HEAD' or + start_line.code == 304)): + self._expected_content_remaining = 0 + elif 'Content-Length' in headers: + self._expected_content_remaining = int(headers['Content-Length']) + else: + self._expected_content_remaining = None + lines = [utf8("%s %s %s" % start_line)] + lines.extend([utf8(n) + b": " + utf8(v) for n, v in headers.get_all()]) + for line in lines: + if b'\n' in line: + raise ValueError('Newline in header: ' + repr(line)) + future = None + if self.stream.closed(): + future = self._write_future = Future() + future.set_exception(iostream.StreamClosedError()) + else: + if callback is not None: + self._write_callback = stack_context.wrap(callback) + else: + future = self._write_future = Future() + data = b"\r\n".join(lines) + b"\r\n\r\n" + if chunk: + data += self._format_chunk(chunk) + self._pending_write = self.stream.write(data) + self._pending_write.add_done_callback(self._on_write_complete) + return future + + def _format_chunk(self, chunk): + if self._expected_content_remaining is not None: + self._expected_content_remaining -= len(chunk) + if self._expected_content_remaining < 0: + # Close the stream now to stop further framing errors. + self.stream.close() + raise httputil.HTTPOutputError( + "Tried to write more data than Content-Length") + if self._chunking_output and chunk: + # Don't write out empty chunks because that means END-OF-STREAM + # with chunked encoding + return utf8("%x" % len(chunk)) + b"\r\n" + chunk + b"\r\n" + else: + return chunk + + def write(self, chunk, callback=None): + """Implements `.HTTPConnection.write`. + + For backwards compatibility is is allowed but deprecated to + skip `write_headers` and instead call `write()` with a + pre-encoded header block. + """ + future = None + if self.stream.closed(): + future = self._write_future = Future() + self._write_future.set_exception(iostream.StreamClosedError()) + else: + if callback is not None: + self._write_callback = stack_context.wrap(callback) + else: + future = self._write_future = Future() + self._pending_write = self.stream.write(self._format_chunk(chunk)) + self._pending_write.add_done_callback(self._on_write_complete) + return future + + def finish(self): + """Implements `.HTTPConnection.finish`.""" + if (self._expected_content_remaining is not None and + self._expected_content_remaining != 0 and + not self.stream.closed()): + self.stream.close() + raise httputil.HTTPOutputError( + "Tried to write %d bytes less than Content-Length" % + self._expected_content_remaining) + if self._chunking_output: + if not self.stream.closed(): + self._pending_write = self.stream.write(b"0\r\n\r\n") + self._pending_write.add_done_callback(self._on_write_complete) + self._write_finished = True + # If the app finished the request while we're still reading, + # divert any remaining data away from the delegate and + # close the connection when we're done sending our response. + # Closing the connection is the only way to avoid reading the + # whole input body. + if not self._read_finished: + self._disconnect_on_finish = True + # No more data is coming, so instruct TCP to send any remaining + # data immediately instead of waiting for a full packet or ack. + self.stream.set_nodelay(True) + if self._pending_write is None: + self._finish_request(None) + else: + self._pending_write.add_done_callback(self._finish_request) + + def _on_write_complete(self, future): + if self._write_callback is not None: + callback = self._write_callback + self._write_callback = None + self.stream.io_loop.add_callback(callback) + if self._write_future is not None: + future = self._write_future + self._write_future = None + future.set_result(None) + + def _can_keep_alive(self, start_line, headers): + if self.params.no_keep_alive: + return False + connection_header = headers.get("Connection") + if connection_header is not None: + connection_header = connection_header.lower() + if start_line.version == "HTTP/1.1": + return connection_header != "close" + elif ("Content-Length" in headers + or start_line.method in ("HEAD", "GET")): + return connection_header == "keep-alive" + return False + + def _finish_request(self, future): + self._clear_callbacks() + if not self.is_client and self._disconnect_on_finish: + self.close() + return + # Turn Nagle's algorithm back on, leaving the stream in its + # default state for the next request. + self.stream.set_nodelay(False) + if not self._finish_future.done(): + self._finish_future.set_result(None) + + def _parse_headers(self, data): + data = native_str(data.decode('latin1')) + eol = data.find("\r\n") + start_line = data[:eol] + try: + headers = httputil.HTTPHeaders.parse(data[eol:]) + except ValueError: + # probably form split() if there was no ':' in the line + raise httputil.HTTPInputError("Malformed HTTP headers: %r" % + data[eol:100]) + return start_line, headers + + def _read_body(self, headers, delegate): + content_length = headers.get("Content-Length") + if content_length: + content_length = int(content_length) + if content_length > self._max_body_size: + raise httputil.HTTPInputError("Content-Length too long") + return self._read_fixed_body(content_length, delegate) + if headers.get("Transfer-Encoding") == "chunked": + return self._read_chunked_body(delegate) + if self.is_client: + return self._read_body_until_close(delegate) + return None + + @gen.coroutine + def _read_fixed_body(self, content_length, delegate): + while content_length > 0: + body = yield self.stream.read_bytes( + min(self.params.chunk_size, content_length), partial=True) + content_length -= len(body) + if not self._write_finished or self.is_client: + with _ExceptionLoggingContext(app_log): + yield gen.maybe_future(delegate.data_received(body)) + + @gen.coroutine + def _read_chunked_body(self, delegate): + # TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1 + total_size = 0 + while True: + chunk_len = yield self.stream.read_until(b"\r\n", max_bytes=64) + chunk_len = int(chunk_len.strip(), 16) + if chunk_len == 0: + return + total_size += chunk_len + if total_size > self._max_body_size: + raise httputil.HTTPInputError("chunked body too large") + bytes_to_read = chunk_len + while bytes_to_read: + chunk = yield self.stream.read_bytes( + min(bytes_to_read, self.params.chunk_size), partial=True) + bytes_to_read -= len(chunk) + if not self._write_finished or self.is_client: + with _ExceptionLoggingContext(app_log): + yield gen.maybe_future(delegate.data_received(chunk)) + # chunk ends with \r\n + crlf = yield self.stream.read_bytes(2) + assert crlf == b"\r\n" + + @gen.coroutine + def _read_body_until_close(self, delegate): + body = yield self.stream.read_until_close() + if not self._write_finished or self.is_client: + with _ExceptionLoggingContext(app_log): + delegate.data_received(body) + + +class _GzipMessageDelegate(httputil.HTTPMessageDelegate): + """Wraps an `HTTPMessageDelegate` to decode ``Content-Encoding: gzip``. + """ + def __init__(self, delegate, chunk_size): + self._delegate = delegate + self._chunk_size = chunk_size + self._decompressor = None + + def headers_received(self, start_line, headers): + if headers.get("Content-Encoding") == "gzip": + self._decompressor = GzipDecompressor() + # Downstream delegates will only see uncompressed data, + # so rename the content-encoding header. + # (but note that curl_httpclient doesn't do this). + headers.add("X-Consumed-Content-Encoding", + headers["Content-Encoding"]) + del headers["Content-Encoding"] + return self._delegate.headers_received(start_line, headers) + + @gen.coroutine + def data_received(self, chunk): + if self._decompressor: + compressed_data = chunk + while compressed_data: + decompressed = self._decompressor.decompress( + compressed_data, self._chunk_size) + if decompressed: + yield gen.maybe_future( + self._delegate.data_received(decompressed)) + compressed_data = self._decompressor.unconsumed_tail + else: + yield gen.maybe_future(self._delegate.data_received(chunk)) + + def finish(self): + if self._decompressor is not None: + tail = self._decompressor.flush() + if tail: + # I believe the tail will always be empty (i.e. + # decompress will return all it can). The purpose + # of the flush call is to detect errors such + # as truncated input. But in case it ever returns + # anything, treat it as an extra chunk + self._delegate.data_received(tail) + return self._delegate.finish() + + +class HTTP1ServerConnection(object): + """An HTTP/1.x server.""" + def __init__(self, stream, params=None, context=None): + """ + :arg stream: an `.IOStream` + :arg params: a `.HTTP1ConnectionParameters` or None + :arg context: an opaque application-defined object that is accessible + as ``connection.context`` + """ + self.stream = stream + if params is None: + params = HTTP1ConnectionParameters() + self.params = params + self.context = context + self._serving_future = None + + @gen.coroutine + def close(self): + """Closes the connection. + + Returns a `.Future` that resolves after the serving loop has exited. + """ + self.stream.close() + # Block until the serving loop is done, but ignore any exceptions + # (start_serving is already responsible for logging them). + try: + yield self._serving_future + except Exception: + pass + + def start_serving(self, delegate): + """Starts serving requests on this connection. + + :arg delegate: a `.HTTPServerConnectionDelegate` + """ + assert isinstance(delegate, httputil.HTTPServerConnectionDelegate) + self._serving_future = self._server_request_loop(delegate) + # Register the future on the IOLoop so its errors get logged. + self.stream.io_loop.add_future(self._serving_future, + lambda f: f.result()) + + @gen.coroutine + def _server_request_loop(self, delegate): + try: + while True: + conn = HTTP1Connection(self.stream, False, + self.params, self.context) + request_delegate = delegate.start_request(self, conn) + try: + ret = yield conn.read_response(request_delegate) + except (iostream.StreamClosedError, + iostream.UnsatisfiableReadError): + return + except _QuietException: + # This exception was already logged. + conn.close() + return + except Exception: + gen_log.error("Uncaught exception", exc_info=True) + conn.close() + return + if not ret: + return + yield gen.moment + finally: + delegate.on_close(self) diff --git a/tornado/httpclient.py b/tornado/httpclient.py index 9b42d401..48731c15 100644 --- a/tornado/httpclient.py +++ b/tornado/httpclient.py @@ -25,6 +25,11 @@ to switch to ``curl_httpclient`` for reasons such as the following: Note that if you are using ``curl_httpclient``, it is highly recommended that you use a recent version of ``libcurl`` and ``pycurl``. Currently the minimum supported version is 7.18.2, and the recommended version is 7.21.1 or newer. +It is highly recommended that your ``libcurl`` installation is built with +asynchronous DNS resolver (threaded or c-ares), otherwise you may encounter +various problems with request timeouts (for more information, see +http://curl.haxx.se/libcurl/c/curl_easy_setopt.html#CURLOPTCONNECTTIMEOUTMS +and comments in curl_httpclient.py). """ from __future__ import absolute_import, division, print_function, with_statement @@ -34,7 +39,7 @@ import time import weakref from tornado.concurrent import TracebackFuture -from tornado.escape import utf8 +from tornado.escape import utf8, native_str from tornado import httputil, stack_context from tornado.ioloop import IOLoop from tornado.util import Configurable @@ -105,10 +110,21 @@ class AsyncHTTPClient(Configurable): actually creates an instance of an implementation-specific subclass, and instances are reused as a kind of pseudo-singleton (one per `.IOLoop`). The keyword argument ``force_instance=True`` - can be used to suppress this singleton behavior. Constructor - arguments other than ``io_loop`` and ``force_instance`` are - deprecated. The implementation subclass as well as arguments to - its constructor can be set with the static method `configure()` + can be used to suppress this singleton behavior. Unless + ``force_instance=True`` is used, no arguments other than + ``io_loop`` should be passed to the `AsyncHTTPClient` constructor. + The implementation subclass as well as arguments to its + constructor can be set with the static method `configure()` + + All `AsyncHTTPClient` implementations support a ``defaults`` + keyword argument, which can be used to set default values for + `HTTPRequest` attributes. For example:: + + AsyncHTTPClient.configure( + None, defaults=dict(user_agent="MyUserAgent")) + # or with force_instance: + client = AsyncHTTPClient(force_instance=True, + defaults=dict(user_agent="MyUserAgent")) """ @classmethod def configurable_base(cls): @@ -141,6 +157,7 @@ class AsyncHTTPClient(Configurable): self.defaults = dict(HTTPRequest._DEFAULTS) if defaults is not None: self.defaults.update(defaults) + self._closed = False def close(self): """Destroys this HTTP client, freeing any file descriptors used. @@ -155,6 +172,7 @@ class AsyncHTTPClient(Configurable): ``close()``. """ + self._closed = True if self._async_clients().get(self.io_loop) is self: del self._async_clients()[self.io_loop] @@ -166,7 +184,7 @@ class AsyncHTTPClient(Configurable): kwargs: ``HTTPRequest(request, **kwargs)`` This method returns a `.Future` whose result is an - `HTTPResponse`. The ``Future`` wil raise an `HTTPError` if + `HTTPResponse`. The ``Future`` will raise an `HTTPError` if the request returned a non-200 response code. If a ``callback`` is given, it will be invoked with the `HTTPResponse`. @@ -174,6 +192,8 @@ class AsyncHTTPClient(Configurable): Instead, you must check the response's ``error`` attribute or call its `~HTTPResponse.rethrow` method. """ + if self._closed: + raise RuntimeError("fetch() called on closed AsyncHTTPClient") if not isinstance(request, HTTPRequest): request = HTTPRequest(url=request, **kwargs) # We may modify this (to add Host, Accept-Encoding, etc), @@ -259,14 +279,27 @@ class HTTPRequest(object): proxy_password=None, allow_nonstandard_methods=None, validate_cert=None, ca_certs=None, allow_ipv6=None, - client_key=None, client_cert=None): + client_key=None, client_cert=None, body_producer=None, + expect_100_continue=False): r"""All parameters except ``url`` are optional. :arg string url: URL to fetch :arg string method: HTTP method, e.g. "GET" or "POST" :arg headers: Additional HTTP headers to pass on the request - :arg body: HTTP body to pass on the request :type headers: `~tornado.httputil.HTTPHeaders` or `dict` + :arg body: HTTP request body as a string (byte or unicode; if unicode + the utf-8 encoding will be used) + :arg body_producer: Callable used for lazy/asynchronous request bodies. + It is called with one argument, a ``write`` function, and should + return a `.Future`. It should call the write function with new + data as it becomes available. The write function returns a + `.Future` which can be used for flow control. + Only one of ``body`` and ``body_producer`` may + be specified. ``body_producer`` is not supported on + ``curl_httpclient``. When using ``body_producer`` it is recommended + to pass a ``Content-Length`` in the headers as otherwise chunked + encoding will be used, and many servers do not support chunked + encoding on requests. New in Tornado 4.0 :arg string auth_username: Username for HTTP authentication :arg string auth_password: Password for HTTP authentication :arg string auth_mode: Authentication mode; default is "basic". @@ -319,6 +352,11 @@ class HTTPRequest(object): note below when used with ``curl_httpclient``. :arg string client_cert: Filename for client SSL certificate, if any. See note below when used with ``curl_httpclient``. + :arg bool expect_100_continue: If true, send the + ``Expect: 100-continue`` header and wait for a continue response + before sending the request body. Only supported with + simple_httpclient. + .. note:: @@ -334,6 +372,9 @@ class HTTPRequest(object): .. versionadded:: 3.1 The ``auth_mode`` argument. + + .. versionadded:: 4.0 + The ``body_producer`` and ``expect_100_continue`` arguments. """ # Note that some of these attributes go through property setters # defined below. @@ -348,6 +389,7 @@ class HTTPRequest(object): self.url = url self.method = method self.body = body + self.body_producer = body_producer self.auth_username = auth_username self.auth_password = auth_password self.auth_mode = auth_mode @@ -367,6 +409,7 @@ class HTTPRequest(object): self.allow_ipv6 = allow_ipv6 self.client_key = client_key self.client_cert = client_cert + self.expect_100_continue = expect_100_continue self.start_time = time.time() @property @@ -388,6 +431,14 @@ class HTTPRequest(object): def body(self, value): self._body = utf8(value) + @property + def body_producer(self): + return self._body_producer + + @body_producer.setter + def body_producer(self, value): + self._body_producer = stack_context.wrap(value) + @property def streaming_callback(self): return self._streaming_callback @@ -423,8 +474,6 @@ class HTTPResponse(object): * code: numeric HTTP status code, e.g. 200 or 404 * reason: human-readable reason phrase describing the status code - (with curl_httpclient, this is a default value rather than the - server's actual response) * headers: `tornado.httputil.HTTPHeaders` object @@ -466,7 +515,8 @@ class HTTPResponse(object): self.effective_url = effective_url if error is None: if self.code < 200 or self.code >= 300: - self.error = HTTPError(self.code, response=self) + self.error = HTTPError(self.code, message=self.reason, + response=self) else: self.error = None else: @@ -556,7 +606,7 @@ def main(): if options.print_headers: print(response.headers) if options.print_body: - print(response.body) + print(native_str(response.body)) client.close() if __name__ == "__main__": diff --git a/tornado/httpserver.py b/tornado/httpserver.py index 34e7b768..277de588 100644 --- a/tornado/httpserver.py +++ b/tornado/httpserver.py @@ -20,70 +20,55 @@ Typical applications have little direct interaction with the `HTTPServer` class except to start a server at the beginning of the process (and even that is often done indirectly via `tornado.web.Application.listen`). -This module also defines the `HTTPRequest` class which is exposed via -`tornado.web.RequestHandler.request`. +.. versionchanged:: 4.0 + + The ``HTTPRequest`` class that used to live in this module has been moved + to `tornado.httputil.HTTPServerRequest`. The old name remains as an alias. """ from __future__ import absolute_import, division, print_function, with_statement import socket -import ssl -import time -import copy -from tornado.escape import native_str, parse_qs_bytes +from tornado.escape import native_str +from tornado.http1connection import HTTP1ServerConnection, HTTP1ConnectionParameters +from tornado import gen from tornado import httputil from tornado import iostream -from tornado.log import gen_log from tornado import netutil from tornado.tcpserver import TCPServer -from tornado import stack_context -from tornado.util import bytes_type - -try: - import Cookie # py2 -except ImportError: - import http.cookies as Cookie # py3 -class HTTPServer(TCPServer): +class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate): r"""A non-blocking, single-threaded HTTP server. - A server is defined by a request callback that takes an HTTPRequest - instance as an argument and writes a valid HTTP response with - `HTTPRequest.write`. `HTTPRequest.finish` finishes the request (but does - not necessarily close the connection in the case of HTTP/1.1 keep-alive - requests). A simple example server that echoes back the URI you - requested:: + A server is defined by either a request callback that takes a + `.HTTPServerRequest` as an argument or a `.HTTPServerConnectionDelegate` + instance. + + A simple example server that echoes back the URI you requested:: import tornado.httpserver import tornado.ioloop def handle_request(request): message = "You requested %s\n" % request.uri - request.write("HTTP/1.1 200 OK\r\nContent-Length: %d\r\n\r\n%s" % ( - len(message), message)) - request.finish() + request.connection.write_headers( + httputil.ResponseStartLine('HTTP/1.1', 200, 'OK'), + {"Content-Length": str(len(message))}) + request.connection.write(message) + request.connection.finish() http_server = tornado.httpserver.HTTPServer(handle_request) http_server.listen(8888) tornado.ioloop.IOLoop.instance().start() - `HTTPServer` is a very basic connection handler. It parses the request - headers and body, but the request callback is responsible for producing - the response exactly as it will appear on the wire. This affords - maximum flexibility for applications to implement whatever parts - of HTTP responses are required. + Applications should use the methods of `.HTTPConnection` to write + their response. `HTTPServer` supports keep-alive connections by default (automatically for HTTP/1.1, or for HTTP/1.0 when the client - requests ``Connection: keep-alive``). This means that the request - callback must generate a properly-framed response, using either - the ``Content-Length`` header or ``Transfer-Encoding: chunked``. - Applications that are unable to frame their responses properly - should instead return a ``Connection: close`` header in each - response and pass ``no_keep_alive=True`` to the `HTTPServer` - constructor. + requests ``Connection: keep-alive``). If ``xheaders`` is ``True``, we support the ``X-Real-Ip``/``X-Forwarded-For`` and @@ -143,407 +128,169 @@ class HTTPServer(TCPServer): servers if you want to create your listening sockets in some way other than `tornado.netutil.bind_sockets`. + .. versionchanged:: 4.0 + Added ``gzip``, ``chunk_size``, ``max_header_size``, + ``idle_connection_timeout``, ``body_timeout``, ``max_body_size`` + arguments. Added support for `.HTTPServerConnectionDelegate` + instances as ``request_callback``. """ def __init__(self, request_callback, no_keep_alive=False, io_loop=None, - xheaders=False, ssl_options=None, protocol=None, **kwargs): + xheaders=False, ssl_options=None, protocol=None, gzip=False, + chunk_size=None, max_header_size=None, + idle_connection_timeout=None, body_timeout=None, + max_body_size=None, max_buffer_size=None): self.request_callback = request_callback self.no_keep_alive = no_keep_alive self.xheaders = xheaders self.protocol = protocol + self.conn_params = HTTP1ConnectionParameters( + use_gzip=gzip, + chunk_size=chunk_size, + max_header_size=max_header_size, + header_timeout=idle_connection_timeout or 3600, + max_body_size=max_body_size, + body_timeout=body_timeout) TCPServer.__init__(self, io_loop=io_loop, ssl_options=ssl_options, - **kwargs) + max_buffer_size=max_buffer_size, + read_chunk_size=chunk_size) + self._connections = set() + + @gen.coroutine + def close_all_connections(self): + while self._connections: + # Peek at an arbitrary element of the set + conn = next(iter(self._connections)) + yield conn.close() def handle_stream(self, stream, address): - HTTPConnection(stream, address, self.request_callback, - self.no_keep_alive, self.xheaders, self.protocol) + context = _HTTPRequestContext(stream, address, + self.protocol) + conn = HTTP1ServerConnection( + stream, self.conn_params, context) + self._connections.add(conn) + conn.start_serving(self) + + def start_request(self, server_conn, request_conn): + return _ServerRequestAdapter(self, request_conn) + + def on_close(self, server_conn): + self._connections.remove(server_conn) -class _BadRequestException(Exception): - """Exception class for malformed HTTP requests.""" - pass - - -class HTTPConnection(object): - """Handles a connection to an HTTP client, executing HTTP requests. - - We parse HTTP headers and bodies, and execute the request callback - until the HTTP conection is closed. - """ - def __init__(self, stream, address, request_callback, no_keep_alive=False, - xheaders=False, protocol=None): - self.stream = stream +class _HTTPRequestContext(object): + def __init__(self, stream, address, protocol): self.address = address + self.protocol = protocol # Save the socket's address family now so we know how to # interpret self.address even after the stream is closed # and its socket attribute replaced with None. - self.address_family = stream.socket.family - self.request_callback = request_callback - self.no_keep_alive = no_keep_alive - self.xheaders = xheaders - self.protocol = protocol - self._clear_request_state() - # Save stack context here, outside of any request. This keeps - # contexts from one request from leaking into the next. - self._header_callback = stack_context.wrap(self._on_headers) - self.stream.set_close_callback(self._on_connection_close) - self.stream.read_until(b"\r\n\r\n", self._header_callback) - - def _clear_request_state(self): - """Clears the per-request state. - - This is run in between requests to allow the previous handler - to be garbage collected (and prevent spurious close callbacks), - and when the connection is closed (to break up cycles and - facilitate garbage collection in cpython). - """ - self._request = None - self._request_finished = False - self._write_callback = None - self._close_callback = None - - def set_close_callback(self, callback): - """Sets a callback that will be run when the connection is closed. - - Use this instead of accessing - `HTTPConnection.stream.set_close_callback - <.BaseIOStream.set_close_callback>` directly (which was the - recommended approach prior to Tornado 3.0). - """ - self._close_callback = stack_context.wrap(callback) - - def _on_connection_close(self): - if self._close_callback is not None: - callback = self._close_callback - self._close_callback = None - callback() - # Delete any unfinished callbacks to break up reference cycles. - self._header_callback = None - self._clear_request_state() - - def close(self): - self.stream.close() - # Remove this reference to self, which would otherwise cause a - # cycle and delay garbage collection of this connection. - self._header_callback = None - self._clear_request_state() - - def write(self, chunk, callback=None): - """Writes a chunk of output to the stream.""" - if not self.stream.closed(): - self._write_callback = stack_context.wrap(callback) - self.stream.write(chunk, self._on_write_complete) - - def finish(self): - """Finishes the request.""" - self._request_finished = True - # No more data is coming, so instruct TCP to send any remaining - # data immediately instead of waiting for a full packet or ack. - self.stream.set_nodelay(True) - if not self.stream.writing(): - self._finish_request() - - def _on_write_complete(self): - if self._write_callback is not None: - callback = self._write_callback - self._write_callback = None - callback() - # _on_write_complete is enqueued on the IOLoop whenever the - # IOStream's write buffer becomes empty, but it's possible for - # another callback that runs on the IOLoop before it to - # simultaneously write more data and finish the request. If - # there is still data in the IOStream, a future - # _on_write_complete will be responsible for calling - # _finish_request. - if self._request_finished and not self.stream.writing(): - self._finish_request() - - def _finish_request(self): - if self.no_keep_alive or self._request is None: - disconnect = True + if stream.socket is not None: + self.address_family = stream.socket.family else: - connection_header = self._request.headers.get("Connection") - if connection_header is not None: - connection_header = connection_header.lower() - if self._request.supports_http_1_1(): - disconnect = connection_header == "close" - elif ("Content-Length" in self._request.headers - or self._request.method in ("HEAD", "GET")): - disconnect = connection_header != "keep-alive" - else: - disconnect = True - self._clear_request_state() - if disconnect: - self.close() - return - try: - # Use a try/except instead of checking stream.closed() - # directly, because in some cases the stream doesn't discover - # that it's closed until you try to read from it. - self.stream.read_until(b"\r\n\r\n", self._header_callback) - - # Turn Nagle's algorithm back on, leaving the stream in its - # default state for the next request. - self.stream.set_nodelay(False) - except iostream.StreamClosedError: - self.close() - - def _on_headers(self, data): - try: - data = native_str(data.decode('latin1')) - eol = data.find("\r\n") - start_line = data[:eol] - try: - method, uri, version = start_line.split(" ") - except ValueError: - raise _BadRequestException("Malformed HTTP request line") - if not version.startswith("HTTP/"): - raise _BadRequestException("Malformed HTTP version in HTTP Request-Line") - try: - headers = httputil.HTTPHeaders.parse(data[eol:]) - except ValueError: - # Probably from split() if there was no ':' in the line - raise _BadRequestException("Malformed HTTP headers") - - # HTTPRequest wants an IP, not a full socket address - if self.address_family in (socket.AF_INET, socket.AF_INET6): - remote_ip = self.address[0] - else: - # Unix (or other) socket; fake the remote address - remote_ip = '0.0.0.0' - - self._request = HTTPRequest( - connection=self, method=method, uri=uri, version=version, - headers=headers, remote_ip=remote_ip, protocol=self.protocol) - - content_length = headers.get("Content-Length") - if content_length: - content_length = int(content_length) - if content_length > self.stream.max_buffer_size: - raise _BadRequestException("Content-Length too long") - if headers.get("Expect") == "100-continue": - self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n") - self.stream.read_bytes(content_length, self._on_request_body) - return - - self.request_callback(self._request) - except _BadRequestException as e: - gen_log.info("Malformed HTTP request from %r: %s", - self.address, e) - self.close() - return - - def _on_request_body(self, data): - self._request.body = data - if self._request.method in ("POST", "PATCH", "PUT"): - httputil.parse_body_arguments( - self._request.headers.get("Content-Type", ""), data, - self._request.body_arguments, self._request.files) - - for k, v in self._request.body_arguments.items(): - self._request.arguments.setdefault(k, []).extend(v) - self.request_callback(self._request) - - -class HTTPRequest(object): - """A single HTTP request. - - All attributes are type `str` unless otherwise noted. - - .. attribute:: method - - HTTP request method, e.g. "GET" or "POST" - - .. attribute:: uri - - The requested uri. - - .. attribute:: path - - The path portion of `uri` - - .. attribute:: query - - The query portion of `uri` - - .. attribute:: version - - HTTP version specified in request, e.g. "HTTP/1.1" - - .. attribute:: headers - - `.HTTPHeaders` dictionary-like object for request headers. Acts like - a case-insensitive dictionary with additional methods for repeated - headers. - - .. attribute:: body - - Request body, if present, as a byte string. - - .. attribute:: remote_ip - - Client's IP address as a string. If ``HTTPServer.xheaders`` is set, - will pass along the real IP address provided by a load balancer - in the ``X-Real-Ip`` or ``X-Forwarded-For`` header. - - .. versionchanged:: 3.1 - The list format of ``X-Forwarded-For`` is now supported. - - .. attribute:: protocol - - The protocol used, either "http" or "https". If ``HTTPServer.xheaders`` - is set, will pass along the protocol used by a load balancer if - reported via an ``X-Scheme`` header. - - .. attribute:: host - - The requested hostname, usually taken from the ``Host`` header. - - .. attribute:: arguments - - GET/POST arguments are available in the arguments property, which - maps arguments names to lists of values (to support multiple values - for individual names). Names are of type `str`, while arguments - are byte strings. Note that this is different from - `.RequestHandler.get_argument`, which returns argument values as - unicode strings. - - .. attribute:: query_arguments - - Same format as ``arguments``, but contains only arguments extracted - from the query string. - - .. versionadded:: 3.2 - - .. attribute:: body_arguments - - Same format as ``arguments``, but contains only arguments extracted - from the request body. - - .. versionadded:: 3.2 - - .. attribute:: files - - File uploads are available in the files property, which maps file - names to lists of `.HTTPFile`. - - .. attribute:: connection - - An HTTP request is attached to a single HTTP connection, which can - be accessed through the "connection" attribute. Since connections - are typically kept open in HTTP/1.1, multiple requests can be handled - sequentially on a single connection. - """ - def __init__(self, method, uri, version="HTTP/1.0", headers=None, - body=None, remote_ip=None, protocol=None, host=None, - files=None, connection=None): - self.method = method - self.uri = uri - self.version = version - self.headers = headers or httputil.HTTPHeaders() - self.body = body or "" - - # set remote IP and protocol - self.remote_ip = remote_ip + self.address_family = None + # In HTTPServerRequest we want an IP, not a full socket address. + if (self.address_family in (socket.AF_INET, socket.AF_INET6) and + address is not None): + self.remote_ip = address[0] + else: + # Unix (or other) socket; fake the remote address. + self.remote_ip = '0.0.0.0' if protocol: self.protocol = protocol - elif connection and isinstance(connection.stream, - iostream.SSLIOStream): + elif isinstance(stream, iostream.SSLIOStream): self.protocol = "https" else: self.protocol = "http" + self._orig_remote_ip = self.remote_ip + self._orig_protocol = self.protocol - # xheaders can override the defaults - if connection and connection.xheaders: - # Squid uses X-Forwarded-For, others use X-Real-Ip - ip = self.headers.get("X-Forwarded-For", self.remote_ip) - ip = ip.split(',')[-1].strip() - ip = self.headers.get( - "X-Real-Ip", ip) - if netutil.is_valid_ip(ip): - self.remote_ip = ip - # AWS uses X-Forwarded-Proto - proto = self.headers.get( - "X-Scheme", self.headers.get("X-Forwarded-Proto", self.protocol)) - if proto in ("http", "https"): - self.protocol = proto + def __str__(self): + if self.address_family in (socket.AF_INET, socket.AF_INET6): + return self.remote_ip + elif isinstance(self.address, bytes): + # Python 3 with the -bb option warns about str(bytes), + # so convert it explicitly. + # Unix socket addresses are str on mac but bytes on linux. + return native_str(self.address) + else: + return str(self.address) - self.host = host or self.headers.get("Host") or "127.0.0.1" - self.files = files or {} + def _apply_xheaders(self, headers): + """Rewrite the ``remote_ip`` and ``protocol`` fields.""" + # Squid uses X-Forwarded-For, others use X-Real-Ip + ip = headers.get("X-Forwarded-For", self.remote_ip) + ip = ip.split(',')[-1].strip() + ip = headers.get("X-Real-Ip", ip) + if netutil.is_valid_ip(ip): + self.remote_ip = ip + # AWS uses X-Forwarded-Proto + proto_header = headers.get( + "X-Scheme", headers.get("X-Forwarded-Proto", + self.protocol)) + if proto_header in ("http", "https"): + self.protocol = proto_header + + def _unapply_xheaders(self): + """Undo changes from `_apply_xheaders`. + + Xheaders are per-request so they should not leak to the next + request on the same connection. + """ + self.remote_ip = self._orig_remote_ip + self.protocol = self._orig_protocol + + +class _ServerRequestAdapter(httputil.HTTPMessageDelegate): + """Adapts the `HTTPMessageDelegate` interface to the interface expected + by our clients. + """ + def __init__(self, server, connection): + self.server = server self.connection = connection - self._start_time = time.time() - self._finish_time = None + self.request = None + if isinstance(server.request_callback, + httputil.HTTPServerConnectionDelegate): + self.delegate = server.request_callback.start_request(connection) + self._chunks = None + else: + self.delegate = None + self._chunks = [] - self.path, sep, self.query = uri.partition('?') - self.arguments = parse_qs_bytes(self.query, keep_blank_values=True) - self.query_arguments = copy.deepcopy(self.arguments) - self.body_arguments = {} + def headers_received(self, start_line, headers): + if self.server.xheaders: + self.connection.context._apply_xheaders(headers) + if self.delegate is None: + self.request = httputil.HTTPServerRequest( + connection=self.connection, start_line=start_line, + headers=headers) + else: + return self.delegate.headers_received(start_line, headers) - def supports_http_1_1(self): - """Returns True if this request supports HTTP/1.1 semantics""" - return self.version == "HTTP/1.1" - - @property - def cookies(self): - """A dictionary of Cookie.Morsel objects.""" - if not hasattr(self, "_cookies"): - self._cookies = Cookie.SimpleCookie() - if "Cookie" in self.headers: - try: - self._cookies.load( - native_str(self.headers["Cookie"])) - except Exception: - self._cookies = {} - return self._cookies - - def write(self, chunk, callback=None): - """Writes the given chunk to the response stream.""" - assert isinstance(chunk, bytes_type) - self.connection.write(chunk, callback=callback) + def data_received(self, chunk): + if self.delegate is None: + self._chunks.append(chunk) + else: + return self.delegate.data_received(chunk) def finish(self): - """Finishes this HTTP request on the open connection.""" - self.connection.finish() - self._finish_time = time.time() - - def full_url(self): - """Reconstructs the full URL for this request.""" - return self.protocol + "://" + self.host + self.uri - - def request_time(self): - """Returns the amount of time it took for this request to execute.""" - if self._finish_time is None: - return time.time() - self._start_time + if self.delegate is None: + self.request.body = b''.join(self._chunks) + self.request._parse_body() + self.server.request_callback(self.request) else: - return self._finish_time - self._start_time + self.delegate.finish() + self._cleanup() - def get_ssl_certificate(self, binary_form=False): - """Returns the client's SSL certificate, if any. + def on_connection_close(self): + if self.delegate is None: + self._chunks = None + else: + self.delegate.on_connection_close() + self._cleanup() - To use client certificates, the HTTPServer must have been constructed - with cert_reqs set in ssl_options, e.g.:: + def _cleanup(self): + if self.server.xheaders: + self.connection.context._unapply_xheaders() - server = HTTPServer(app, - ssl_options=dict( - certfile="foo.crt", - keyfile="foo.key", - cert_reqs=ssl.CERT_REQUIRED, - ca_certs="cacert.crt")) - By default, the return value is a dictionary (or None, if no - client certificate is present). If ``binary_form`` is true, a - DER-encoded form of the certificate is returned instead. See - SSLSocket.getpeercert() in the standard library for more - details. - http://docs.python.org/library/ssl.html#sslsocket-objects - """ - try: - return self.connection.stream.socket.getpeercert( - binary_form=binary_form) - except ssl.SSLError: - return None - - def __repr__(self): - attrs = ("protocol", "host", "method", "uri", "version", "remote_ip") - args = ", ".join(["%s=%r" % (n, getattr(self, n)) for n in attrs]) - return "%s(%s, headers=%s)" % ( - self.__class__.__name__, args, dict(self.headers)) +HTTPRequest = httputil.HTTPServerRequest diff --git a/tornado/httputil.py b/tornado/httputil.py index 2575bc56..a6748972 100644 --- a/tornado/httputil.py +++ b/tornado/httputil.py @@ -14,20 +14,31 @@ # License for the specific language governing permissions and limitations # under the License. -"""HTTP utility code shared by clients and servers.""" +"""HTTP utility code shared by clients and servers. + +This module also defines the `HTTPServerRequest` class which is exposed +via `tornado.web.RequestHandler.request`. +""" from __future__ import absolute_import, division, print_function, with_statement import calendar import collections +import copy import datetime import email.utils import numbers +import re import time from tornado.escape import native_str, parse_qs_bytes, utf8 from tornado.log import gen_log -from tornado.util import ObjectDict +from tornado.util import ObjectDict, bytes_type + +try: + import Cookie # py2 +except ImportError: + import http.cookies as Cookie # py3 try: from httplib import responses # py2 @@ -43,6 +54,13 @@ try: except ImportError: from urllib.parse import urlencode # py3 +try: + from ssl import SSLError +except ImportError: + # ssl is unavailable on app engine. + class SSLError(Exception): + pass + class _NormalizedHeaderCache(dict): """Dynamic cached mapping of header names to Http-Header-Case. @@ -212,6 +230,337 @@ class HTTPHeaders(dict): return HTTPHeaders(self) +class HTTPServerRequest(object): + """A single HTTP request. + + All attributes are type `str` unless otherwise noted. + + .. attribute:: method + + HTTP request method, e.g. "GET" or "POST" + + .. attribute:: uri + + The requested uri. + + .. attribute:: path + + The path portion of `uri` + + .. attribute:: query + + The query portion of `uri` + + .. attribute:: version + + HTTP version specified in request, e.g. "HTTP/1.1" + + .. attribute:: headers + + `.HTTPHeaders` dictionary-like object for request headers. Acts like + a case-insensitive dictionary with additional methods for repeated + headers. + + .. attribute:: body + + Request body, if present, as a byte string. + + .. attribute:: remote_ip + + Client's IP address as a string. If ``HTTPServer.xheaders`` is set, + will pass along the real IP address provided by a load balancer + in the ``X-Real-Ip`` or ``X-Forwarded-For`` header. + + .. versionchanged:: 3.1 + The list format of ``X-Forwarded-For`` is now supported. + + .. attribute:: protocol + + The protocol used, either "http" or "https". If ``HTTPServer.xheaders`` + is set, will pass along the protocol used by a load balancer if + reported via an ``X-Scheme`` header. + + .. attribute:: host + + The requested hostname, usually taken from the ``Host`` header. + + .. attribute:: arguments + + GET/POST arguments are available in the arguments property, which + maps arguments names to lists of values (to support multiple values + for individual names). Names are of type `str`, while arguments + are byte strings. Note that this is different from + `.RequestHandler.get_argument`, which returns argument values as + unicode strings. + + .. attribute:: query_arguments + + Same format as ``arguments``, but contains only arguments extracted + from the query string. + + .. versionadded:: 3.2 + + .. attribute:: body_arguments + + Same format as ``arguments``, but contains only arguments extracted + from the request body. + + .. versionadded:: 3.2 + + .. attribute:: files + + File uploads are available in the files property, which maps file + names to lists of `.HTTPFile`. + + .. attribute:: connection + + An HTTP request is attached to a single HTTP connection, which can + be accessed through the "connection" attribute. Since connections + are typically kept open in HTTP/1.1, multiple requests can be handled + sequentially on a single connection. + + .. versionchanged:: 4.0 + Moved from ``tornado.httpserver.HTTPRequest``. + """ + def __init__(self, method=None, uri=None, version="HTTP/1.0", headers=None, + body=None, host=None, files=None, connection=None, + start_line=None): + if start_line is not None: + method, uri, version = start_line + self.method = method + self.uri = uri + self.version = version + self.headers = headers or HTTPHeaders() + self.body = body or "" + + # set remote IP and protocol + context = getattr(connection, 'context', None) + self.remote_ip = getattr(context, 'remote_ip') + self.protocol = getattr(context, 'protocol', "http") + + self.host = host or self.headers.get("Host") or "127.0.0.1" + self.files = files or {} + self.connection = connection + self._start_time = time.time() + self._finish_time = None + + self.path, sep, self.query = uri.partition('?') + self.arguments = parse_qs_bytes(self.query, keep_blank_values=True) + self.query_arguments = copy.deepcopy(self.arguments) + self.body_arguments = {} + + def supports_http_1_1(self): + """Returns True if this request supports HTTP/1.1 semantics. + + .. deprecated:: 4.0 + Applications are less likely to need this information with the + introduction of `.HTTPConnection`. If you still need it, access + the ``version`` attribute directly. + """ + return self.version == "HTTP/1.1" + + @property + def cookies(self): + """A dictionary of Cookie.Morsel objects.""" + if not hasattr(self, "_cookies"): + self._cookies = Cookie.SimpleCookie() + if "Cookie" in self.headers: + try: + self._cookies.load( + native_str(self.headers["Cookie"])) + except Exception: + self._cookies = {} + return self._cookies + + def write(self, chunk, callback=None): + """Writes the given chunk to the response stream. + + .. deprecated:: 4.0 + Use ``request.connection`` and the `.HTTPConnection` methods + to write the response. + """ + assert isinstance(chunk, bytes_type) + self.connection.write(chunk, callback=callback) + + def finish(self): + """Finishes this HTTP request on the open connection. + + .. deprecated:: 4.0 + Use ``request.connection`` and the `.HTTPConnection` methods + to write the response. + """ + self.connection.finish() + self._finish_time = time.time() + + def full_url(self): + """Reconstructs the full URL for this request.""" + return self.protocol + "://" + self.host + self.uri + + def request_time(self): + """Returns the amount of time it took for this request to execute.""" + if self._finish_time is None: + return time.time() - self._start_time + else: + return self._finish_time - self._start_time + + def get_ssl_certificate(self, binary_form=False): + """Returns the client's SSL certificate, if any. + + To use client certificates, the HTTPServer must have been constructed + with cert_reqs set in ssl_options, e.g.:: + + server = HTTPServer(app, + ssl_options=dict( + certfile="foo.crt", + keyfile="foo.key", + cert_reqs=ssl.CERT_REQUIRED, + ca_certs="cacert.crt")) + + By default, the return value is a dictionary (or None, if no + client certificate is present). If ``binary_form`` is true, a + DER-encoded form of the certificate is returned instead. See + SSLSocket.getpeercert() in the standard library for more + details. + http://docs.python.org/library/ssl.html#sslsocket-objects + """ + try: + return self.connection.stream.socket.getpeercert( + binary_form=binary_form) + except SSLError: + return None + + def _parse_body(self): + parse_body_arguments( + self.headers.get("Content-Type", ""), self.body, + self.body_arguments, self.files, + self.headers) + + for k, v in self.body_arguments.items(): + self.arguments.setdefault(k, []).extend(v) + + def __repr__(self): + attrs = ("protocol", "host", "method", "uri", "version", "remote_ip") + args = ", ".join(["%s=%r" % (n, getattr(self, n)) for n in attrs]) + return "%s(%s, headers=%s)" % ( + self.__class__.__name__, args, dict(self.headers)) + + +class HTTPInputError(Exception): + """Exception class for malformed HTTP requests or responses + from remote sources. + + .. versionadded:: 4.0 + """ + pass + + +class HTTPOutputError(Exception): + """Exception class for errors in HTTP output. + + .. versionadded:: 4.0 + """ + pass + + +class HTTPServerConnectionDelegate(object): + """Implement this interface to handle requests from `.HTTPServer`. + + .. versionadded:: 4.0 + """ + def start_request(self, server_conn, request_conn): + """This method is called by the server when a new request has started. + + :arg server_conn: is an opaque object representing the long-lived + (e.g. tcp-level) connection. + :arg request_conn: is a `.HTTPConnection` object for a single + request/response exchange. + + This method should return a `.HTTPMessageDelegate`. + """ + raise NotImplementedError() + + def on_close(self, server_conn): + """This method is called when a connection has been closed. + + :arg server_conn: is a server connection that has previously been + passed to ``start_request``. + """ + pass + + +class HTTPMessageDelegate(object): + """Implement this interface to handle an HTTP request or response. + + .. versionadded:: 4.0 + """ + def headers_received(self, start_line, headers): + """Called when the HTTP headers have been received and parsed. + + :arg start_line: a `.RequestStartLine` or `.ResponseStartLine` + depending on whether this is a client or server message. + :arg headers: a `.HTTPHeaders` instance. + + Some `.HTTPConnection` methods can only be called during + ``headers_received``. + + May return a `.Future`; if it does the body will not be read + until it is done. + """ + pass + + def data_received(self, chunk): + """Called when a chunk of data has been received. + + May return a `.Future` for flow control. + """ + pass + + def finish(self): + """Called after the last chunk of data has been received.""" + pass + + def on_connection_close(self): + """Called if the connection is closed without finishing the request. + + If ``headers_received`` is called, either ``finish`` or + ``on_connection_close`` will be called, but not both. + """ + pass + + +class HTTPConnection(object): + """Applications use this interface to write their responses. + + .. versionadded:: 4.0 + """ + def write_headers(self, start_line, headers, chunk=None, callback=None): + """Write an HTTP header block. + + :arg start_line: a `.RequestStartLine` or `.ResponseStartLine`. + :arg headers: a `.HTTPHeaders` instance. + :arg chunk: the first (optional) chunk of data. This is an optimization + so that small responses can be written in the same call as their + headers. + :arg callback: a callback to be run when the write is complete. + + Returns a `.Future` if no callback is given. + """ + raise NotImplementedError() + + def write(self, chunk, callback=None): + """Writes a chunk of body data. + + The callback will be run when the write is complete. If no callback + is given, returns a Future. + """ + raise NotImplementedError() + + def finish(self): + """Indicates that the last body data has been written. + """ + raise NotImplementedError() + + def url_concat(url, args): """Concatenate url and argument dictionary regardless of whether url has existing query parameters. @@ -310,7 +659,7 @@ def _int_or_none(val): return int(val) -def parse_body_arguments(content_type, body, arguments, files): +def parse_body_arguments(content_type, body, arguments, files, headers=None): """Parses a form request body. Supports ``application/x-www-form-urlencoded`` and @@ -319,6 +668,10 @@ def parse_body_arguments(content_type, body, arguments, files): and ``files`` parameters are dictionaries that will be updated with the parsed contents. """ + if headers and 'Content-Encoding' in headers: + gen_log.warning("Unsupported Content-Encoding: %s", + headers['Content-Encoding']) + return if content_type.startswith("application/x-www-form-urlencoded"): try: uri_arguments = parse_qs_bytes(native_str(body), keep_blank_values=True) @@ -405,6 +758,48 @@ def format_timestamp(ts): raise TypeError("unknown timestamp type: %r" % ts) return email.utils.formatdate(ts, usegmt=True) + +RequestStartLine = collections.namedtuple( + 'RequestStartLine', ['method', 'path', 'version']) + + +def parse_request_start_line(line): + """Returns a (method, path, version) tuple for an HTTP 1.x request line. + + The response is a `collections.namedtuple`. + + >>> parse_request_start_line("GET /foo HTTP/1.1") + RequestStartLine(method='GET', path='/foo', version='HTTP/1.1') + """ + try: + method, path, version = line.split(" ") + except ValueError: + raise HTTPInputError("Malformed HTTP request line") + if not version.startswith("HTTP/"): + raise HTTPInputError( + "Malformed HTTP version in HTTP Request-Line: %r" % version) + return RequestStartLine(method, path, version) + + +ResponseStartLine = collections.namedtuple( + 'ResponseStartLine', ['version', 'code', 'reason']) + + +def parse_response_start_line(line): + """Returns a (version, code, reason) tuple for an HTTP 1.x response line. + + The response is a `collections.namedtuple`. + + >>> parse_response_start_line("HTTP/1.1 200 OK") + ResponseStartLine(version='HTTP/1.1', code=200, reason='OK') + """ + line = native_str(line) + match = re.match("(HTTP/1.[01]) ([0-9]+) ([^\r]*)", line) + if not match: + raise HTTPInputError("Error parsing response start line") + return ResponseStartLine(match.group(1), int(match.group(2)), + match.group(3)) + # _parseparam and _parse_header are copied and modified from python2.7's cgi.py # The original 2.7 version of this code did not correctly support some # combinations of semicolons and double quotes. diff --git a/tornado/ioloop.py b/tornado/ioloop.py index e7b84dd7..3477684c 100644 --- a/tornado/ioloop.py +++ b/tornado/ioloop.py @@ -32,6 +32,7 @@ import datetime import errno import functools import heapq +import itertools import logging import numbers import os @@ -41,10 +42,11 @@ import threading import time import traceback -from tornado.concurrent import Future, TracebackFuture +from tornado.concurrent import TracebackFuture, is_future from tornado.log import app_log, gen_log from tornado import stack_context from tornado.util import Configurable +from tornado.util import errno_from_exception try: import signal @@ -156,6 +158,15 @@ class IOLoop(Configurable): assert not IOLoop.initialized() IOLoop._instance = self + @staticmethod + def clear_instance(): + """Clear the global `IOLoop` instance. + + .. versionadded:: 4.0 + """ + if hasattr(IOLoop, "_instance"): + del IOLoop._instance + @staticmethod def current(): """Returns the current thread's `IOLoop`. @@ -244,21 +255,40 @@ class IOLoop(Configurable): raise NotImplementedError() def add_handler(self, fd, handler, events): - """Registers the given handler to receive the given events for fd. + """Registers the given handler to receive the given events for ``fd``. + + The ``fd`` argument may either be an integer file descriptor or + a file-like object with a ``fileno()`` method (and optionally a + ``close()`` method, which may be called when the `IOLoop` is shut + down). The ``events`` argument is a bitwise or of the constants ``IOLoop.READ``, ``IOLoop.WRITE``, and ``IOLoop.ERROR``. When an event occurs, ``handler(fd, events)`` will be run. + + .. versionchanged:: 4.0 + Added the ability to pass file-like objects in addition to + raw file descriptors. """ raise NotImplementedError() def update_handler(self, fd, events): - """Changes the events we listen for fd.""" + """Changes the events we listen for ``fd``. + + .. versionchanged:: 4.0 + Added the ability to pass file-like objects in addition to + raw file descriptors. + """ raise NotImplementedError() def remove_handler(self, fd): - """Stop listening for events on fd.""" + """Stop listening for events on ``fd``. + + .. versionchanged:: 4.0 + Added the ability to pass file-like objects in addition to + raw file descriptors. + """ raise NotImplementedError() def set_blocking_signal_threshold(self, seconds, action): @@ -372,7 +402,7 @@ class IOLoop(Configurable): future_cell[0] = TracebackFuture() future_cell[0].set_exc_info(sys.exc_info()) else: - if isinstance(result, Future): + if is_future(result): future_cell[0] = result else: future_cell[0] = TracebackFuture() @@ -456,6 +486,19 @@ class IOLoop(Configurable): """ raise NotImplementedError() + def spawn_callback(self, callback, *args, **kwargs): + """Calls the given callback on the next IOLoop iteration. + + Unlike all other callback-related methods on IOLoop, + ``spawn_callback`` does not associate the callback with its caller's + ``stack_context``, so it is suitable for fire-and-forget callbacks + that should not interfere with the caller. + + .. versionadded:: 4.0 + """ + with stack_context.NullContext(): + self.add_callback(callback, *args, **kwargs) + def add_future(self, future, callback): """Schedules a callback on the ``IOLoop`` when the given `.Future` is finished. @@ -463,7 +506,7 @@ class IOLoop(Configurable): The callback is invoked with one argument, the `.Future`. """ - assert isinstance(future, Future) + assert is_future(future) callback = stack_context.wrap(callback) future.add_done_callback( lambda future: self.add_callback(callback, future)) @@ -474,7 +517,13 @@ class IOLoop(Configurable): For use in subclasses. """ try: - callback() + ret = callback() + if ret is not None and is_future(ret): + # Functions that return Futures typically swallow all + # exceptions and store them in the Future. If a Future + # makes it out to the IOLoop, ensure its exception (if any) + # gets logged too. + self.add_future(ret, lambda f: f.result()) except Exception: self.handle_callback_exception(callback) @@ -490,6 +539,47 @@ class IOLoop(Configurable): """ app_log.error("Exception in callback %r", callback, exc_info=True) + def split_fd(self, fd): + """Returns an (fd, obj) pair from an ``fd`` parameter. + + We accept both raw file descriptors and file-like objects as + input to `add_handler` and related methods. When a file-like + object is passed, we must retain the object itself so we can + close it correctly when the `IOLoop` shuts down, but the + poller interfaces favor file descriptors (they will accept + file-like objects and call ``fileno()`` for you, but they + always return the descriptor itself). + + This method is provided for use by `IOLoop` subclasses and should + not generally be used by application code. + + .. versionadded:: 4.0 + """ + try: + return fd.fileno(), fd + except AttributeError: + return fd, fd + + def close_fd(self, fd): + """Utility method to close an ``fd``. + + If ``fd`` is a file-like object, we close it directly; otherwise + we use `os.close`. + + This method is provided for use by `IOLoop` subclasses (in + implementations of ``IOLoop.close(all_fds=True)`` and should + not generally be used by application code. + + .. versionadded:: 4.0 + """ + try: + try: + fd.close() + except AttributeError: + os.close(fd) + except OSError: + pass + class PollIOLoop(IOLoop): """Base class for IOLoops built around a select-like function. @@ -515,7 +605,8 @@ class PollIOLoop(IOLoop): self._closing = False self._thread_ident = None self._blocking_signal_threshold = None - + self._timeout_counter = itertools.count() + # Create a pipe that we send bogus data to when we want to wake # the I/O loop when it is idle self._waker = Waker() @@ -528,26 +619,24 @@ class PollIOLoop(IOLoop): self._closing = True self.remove_handler(self._waker.fileno()) if all_fds: - for fd in self._handlers.keys(): - try: - close_method = getattr(fd, 'close', None) - if close_method is not None: - close_method() - else: - os.close(fd) - except Exception: - gen_log.debug("error closing fd %s", fd, exc_info=True) + for fd, handler in self._handlers.values(): + self.close_fd(fd) self._waker.close() self._impl.close() + self._callbacks = None + self._timeouts = None def add_handler(self, fd, handler, events): - self._handlers[fd] = stack_context.wrap(handler) + fd, obj = self.split_fd(fd) + self._handlers[fd] = (obj, stack_context.wrap(handler)) self._impl.register(fd, events | self.ERROR) def update_handler(self, fd, events): + fd, obj = self.split_fd(fd) self._impl.modify(fd, events | self.ERROR) def remove_handler(self, fd): + fd, obj = self.split_fd(fd) self._handlers.pop(fd, None) self._events.pop(fd, None) try: @@ -566,6 +655,8 @@ class PollIOLoop(IOLoop): action if action is not None else signal.SIG_DFL) def start(self): + if self._running: + raise RuntimeError("IOLoop is already running") self._setup_logging() if self._stopped: self._stopped = False @@ -608,19 +699,16 @@ class PollIOLoop(IOLoop): try: while True: - poll_timeout = _POLL_TIMEOUT - # Prevent IO event starvation by delaying new callbacks # to the next iteration of the event loop. with self._callback_lock: callbacks = self._callbacks self._callbacks = [] - for callback in callbacks: - self._run_callback(callback) - # Closures may be holding on to a lot of memory, so allow - # them to be freed before we go into our poll wait. - callbacks = callback = None + # Add any timeouts that have come due to the callback list. + # Do not run anything until we have determined which ones + # are ready, so timeouts that call add_timeout cannot + # schedule anything in this iteration. if self._timeouts: now = self.time() while self._timeouts: @@ -630,11 +718,9 @@ class PollIOLoop(IOLoop): self._cancellations -= 1 elif self._timeouts[0].deadline <= now: timeout = heapq.heappop(self._timeouts) - self._run_callback(timeout.callback) + callbacks.append(timeout.callback) del timeout else: - seconds = self._timeouts[0].deadline - now - poll_timeout = min(seconds, poll_timeout) break if (self._cancellations > 512 and self._cancellations > (len(self._timeouts) >> 1)): @@ -645,10 +731,25 @@ class PollIOLoop(IOLoop): if x.callback is not None] heapq.heapify(self._timeouts) + for callback in callbacks: + self._run_callback(callback) + # Closures may be holding on to a lot of memory, so allow + # them to be freed before we go into our poll wait. + callbacks = callback = None + if self._callbacks: # If any callbacks or timeouts called add_callback, # we don't want to wait in poll() before we run them. poll_timeout = 0.0 + elif self._timeouts: + # If there are any timeouts, schedule the first one. + # Use self.time() instead of 'now' to account for time + # spent running callbacks. + poll_timeout = self._timeouts[0].deadline - self.time() + poll_timeout = max(0, min(poll_timeout, _POLL_TIMEOUT)) + else: + # No timeouts and no callbacks, so use the default. + poll_timeout = _POLL_TIMEOUT if not self._running: break @@ -666,9 +767,7 @@ class PollIOLoop(IOLoop): # two ways EINTR might be signaled: # * e.errno == errno.EINTR # * e.args is like (errno.EINTR, 'Interrupted system call') - if (getattr(e, 'errno', None) == errno.EINTR or - (isinstance(getattr(e, 'args', None), tuple) and - len(e.args) == 2 and e.args[0] == errno.EINTR)): + if errno_from_exception(e) == errno.EINTR: continue else: raise @@ -685,15 +784,17 @@ class PollIOLoop(IOLoop): while self._events: fd, events = self._events.popitem() try: - self._handlers[fd](fd, events) + fd_obj, handler_func = self._handlers[fd] + handler_func(fd_obj, events) except (OSError, IOError) as e: - if e.args[0] == errno.EPIPE: + if errno_from_exception(e) == errno.EPIPE: # Happens when the client closes the connection pass else: self.handle_callback_exception(self._handlers.get(fd)) except Exception: self.handle_callback_exception(self._handlers.get(fd)) + fd_obj = handler_func = None finally: # reset the stopped flag so another start/stop pair can be issued @@ -765,16 +866,21 @@ class _Timeout(object): """An IOLoop timeout, a UNIX timestamp and a callback""" # Reduce memory overhead when there are lots of pending callbacks - __slots__ = ['deadline', 'callback'] + __slots__ = ['deadline', 'callback', 'tiebreaker'] def __init__(self, deadline, callback, io_loop): if isinstance(deadline, numbers.Real): self.deadline = deadline elif isinstance(deadline, datetime.timedelta): - self.deadline = io_loop.time() + _Timeout.timedelta_to_seconds(deadline) + now = io_loop.time() + try: + self.deadline = now + deadline.total_seconds() + except AttributeError: # py2.6 + self.deadline = now + _Timeout.timedelta_to_seconds(deadline) else: raise TypeError("Unsupported deadline %r" % deadline) self.callback = callback + self.tiebreaker = next(io_loop._timeout_counter) @staticmethod def timedelta_to_seconds(td): @@ -786,12 +892,12 @@ class _Timeout(object): # in python2.5, and __lt__ in 2.6+ (sort() and most other comparisons # use __lt__). def __lt__(self, other): - return ((self.deadline, id(self)) < - (other.deadline, id(other))) + return ((self.deadline, self.tiebreaker) < + (other.deadline, other.tiebreaker)) def __le__(self, other): - return ((self.deadline, id(self)) <= - (other.deadline, id(other))) + return ((self.deadline, self.tiebreaker) <= + (other.deadline, other.tiebreaker)) class PeriodicCallback(object): diff --git a/tornado/iostream.py b/tornado/iostream.py index 5d4d08ac..8b614258 100644 --- a/tornado/iostream.py +++ b/tornado/iostream.py @@ -31,21 +31,27 @@ import errno import numbers import os import socket -import ssl import sys import re +from tornado.concurrent import TracebackFuture from tornado import ioloop from tornado.log import gen_log, app_log from tornado.netutil import ssl_wrap_socket, ssl_match_hostname, SSLCertificateError from tornado import stack_context -from tornado.util import bytes_type +from tornado.util import bytes_type, errno_from_exception try: from tornado.platform.posix import _set_nonblocking except ImportError: _set_nonblocking = None +try: + import ssl +except ImportError: + # ssl is not available on Google App Engine + ssl = None + # These errnos indicate that a non-blocking operation must be retried # at a later time. On most platforms they're the same value, but on # some they differ. @@ -53,7 +59,8 @@ _ERRNO_WOULDBLOCK = (errno.EWOULDBLOCK, errno.EAGAIN) # These errnos indicate that a connection has been abruptly terminated. # They should be caught and handled less noisily than other errors. -_ERRNO_CONNRESET = (errno.ECONNRESET, errno.ECONNABORTED, errno.EPIPE) +_ERRNO_CONNRESET = (errno.ECONNRESET, errno.ECONNABORTED, errno.EPIPE, + errno.ETIMEDOUT) class StreamClosedError(IOError): @@ -66,12 +73,31 @@ class StreamClosedError(IOError): pass +class UnsatisfiableReadError(Exception): + """Exception raised when a read cannot be satisfied. + + Raised by ``read_until`` and ``read_until_regex`` with a ``max_bytes`` + argument. + """ + pass + + +class StreamBufferFullError(Exception): + """Exception raised by `IOStream` methods when the buffer is full. + """ + + class BaseIOStream(object): """A utility class to write to and read from a non-blocking file or socket. We support a non-blocking ``write()`` and a family of ``read_*()`` methods. - All of the methods take callbacks (since writing and reading are - non-blocking and asynchronous). + All of the methods take an optional ``callback`` argument and return a + `.Future` only if no callback is given. When the operation completes, + the callback will be run or the `.Future` will resolve with the data + read (or ``None`` for ``write()``). All outstanding ``Futures`` will + resolve with a `StreamClosedError` when the stream is closed; users + of the callback interface will be notified via + `.BaseIOStream.set_close_callback` instead. When a stream is closed due to an error, the IOStream's ``error`` attribute contains the exception object. @@ -80,24 +106,48 @@ class BaseIOStream(object): `read_from_fd`, and optionally `get_fd_error`. """ def __init__(self, io_loop=None, max_buffer_size=None, - read_chunk_size=4096): + read_chunk_size=None, max_write_buffer_size=None): + """`BaseIOStream` constructor. + + :arg io_loop: The `.IOLoop` to use; defaults to `.IOLoop.current`. + :arg max_buffer_size: Maximum amount of incoming data to buffer; + defaults to 100MB. + :arg read_chunk_size: Amount of data to read at one time from the + underlying transport; defaults to 64KB. + :arg max_write_buffer_size: Amount of outgoing data to buffer; + defaults to unlimited. + + .. versionchanged:: 4.0 + Add the ``max_write_buffer_size`` parameter. Changed default + ``read_chunk_size`` to 64KB. + """ self.io_loop = io_loop or ioloop.IOLoop.current() self.max_buffer_size = max_buffer_size or 104857600 - self.read_chunk_size = read_chunk_size + # A chunk size that is too close to max_buffer_size can cause + # spurious failures. + self.read_chunk_size = min(read_chunk_size or 65536, + self.max_buffer_size // 2) + self.max_write_buffer_size = max_write_buffer_size self.error = None self._read_buffer = collections.deque() self._write_buffer = collections.deque() self._read_buffer_size = 0 + self._write_buffer_size = 0 self._write_buffer_frozen = False self._read_delimiter = None self._read_regex = None + self._read_max_bytes = None self._read_bytes = None + self._read_partial = False self._read_until_close = False self._read_callback = None + self._read_future = None self._streaming_callback = None self._write_callback = None + self._write_future = None self._close_callback = None self._connect_callback = None + self._connect_future = None self._connecting = False self._state = None self._pending_callbacks = 0 @@ -142,98 +192,162 @@ class BaseIOStream(object): """ return None - def read_until_regex(self, regex, callback): - """Run ``callback`` when we read the given regex pattern. + def read_until_regex(self, regex, callback=None, max_bytes=None): + """Asynchronously read until we have matched the given regex. - The callback will get the data read (including the data that - matched the regex and anything that came before it) as an argument. + The result includes the data that matches the regex and anything + that came before it. If a callback is given, it will be run + with the data as an argument; if not, this method returns a + `.Future`. + + If ``max_bytes`` is not None, the connection will be closed + if more than ``max_bytes`` bytes have been read and the regex is + not satisfied. + + .. versionchanged:: 4.0 + Added the ``max_bytes`` argument. The ``callback`` argument is + now optional and a `.Future` will be returned if it is omitted. """ - self._set_read_callback(callback) + future = self._set_read_callback(callback) self._read_regex = re.compile(regex) - self._try_inline_read() + self._read_max_bytes = max_bytes + try: + self._try_inline_read() + except UnsatisfiableReadError as e: + # Handle this the same way as in _handle_events. + gen_log.info("Unsatisfiable read, closing connection: %s" % e) + self.close(exc_info=True) + return future + return future - def read_until(self, delimiter, callback): - """Run ``callback`` when we read the given delimiter. + def read_until(self, delimiter, callback=None, max_bytes=None): + """Asynchronously read until we have found the given delimiter. - The callback will get the data read (including the delimiter) - as an argument. + The result includes all the data read including the delimiter. + If a callback is given, it will be run with the data as an argument; + if not, this method returns a `.Future`. + + If ``max_bytes`` is not None, the connection will be closed + if more than ``max_bytes`` bytes have been read and the delimiter + is not found. + + .. versionchanged:: 4.0 + Added the ``max_bytes`` argument. The ``callback`` argument is + now optional and a `.Future` will be returned if it is omitted. """ - self._set_read_callback(callback) + future = self._set_read_callback(callback) self._read_delimiter = delimiter - self._try_inline_read() + self._read_max_bytes = max_bytes + try: + self._try_inline_read() + except UnsatisfiableReadError as e: + # Handle this the same way as in _handle_events. + gen_log.info("Unsatisfiable read, closing connection: %s" % e) + self.close(exc_info=True) + return future + return future - def read_bytes(self, num_bytes, callback, streaming_callback=None): - """Run callback when we read the given number of bytes. + def read_bytes(self, num_bytes, callback=None, streaming_callback=None, + partial=False): + """Asynchronously read a number of bytes. If a ``streaming_callback`` is given, it will be called with chunks - of data as they become available, and the argument to the final - ``callback`` will be empty. Otherwise, the ``callback`` gets - the data as an argument. + of data as they become available, and the final result will be empty. + Otherwise, the result is all the data that was read. + If a callback is given, it will be run with the data as an argument; + if not, this method returns a `.Future`. + + If ``partial`` is true, the callback is run as soon as we have + any bytes to return (but never more than ``num_bytes``) + + .. versionchanged:: 4.0 + Added the ``partial`` argument. The callback argument is now + optional and a `.Future` will be returned if it is omitted. """ - self._set_read_callback(callback) + future = self._set_read_callback(callback) assert isinstance(num_bytes, numbers.Integral) self._read_bytes = num_bytes + self._read_partial = partial self._streaming_callback = stack_context.wrap(streaming_callback) self._try_inline_read() + return future - def read_until_close(self, callback, streaming_callback=None): - """Reads all data from the socket until it is closed. + def read_until_close(self, callback=None, streaming_callback=None): + """Asynchronously reads all data from the socket until it is closed. If a ``streaming_callback`` is given, it will be called with chunks - of data as they become available, and the argument to the final - ``callback`` will be empty. Otherwise, the ``callback`` gets the - data as an argument. + of data as they become available, and the final result will be empty. + Otherwise, the result is all the data that was read. + If a callback is given, it will be run with the data as an argument; + if not, this method returns a `.Future`. - Subject to ``max_buffer_size`` limit from `IOStream` constructor if - a ``streaming_callback`` is not used. + .. versionchanged:: 4.0 + The callback argument is now optional and a `.Future` will + be returned if it is omitted. """ - self._set_read_callback(callback) + future = self._set_read_callback(callback) self._streaming_callback = stack_context.wrap(streaming_callback) if self.closed(): if self._streaming_callback is not None: - self._run_callback(self._streaming_callback, - self._consume(self._read_buffer_size)) - self._run_callback(self._read_callback, - self._consume(self._read_buffer_size)) - self._streaming_callback = None - self._read_callback = None - return + self._run_read_callback(self._read_buffer_size, True) + self._run_read_callback(self._read_buffer_size, False) + return future self._read_until_close = True - self._streaming_callback = stack_context.wrap(streaming_callback) self._try_inline_read() + return future def write(self, data, callback=None): - """Write the given data to this stream. + """Asynchronously write the given data to this stream. If ``callback`` is given, we call it when all of the buffered write data has been successfully written to the stream. If there was previously buffered write data and an old write callback, that callback is simply overwritten with this new callback. + + If no ``callback`` is given, this method returns a `.Future` that + resolves (with a result of ``None``) when the write has been + completed. If `write` is called again before that `.Future` has + resolved, the previous future will be orphaned and will never resolve. + + .. versionchanged:: 4.0 + Now returns a `.Future` if no callback is given. """ assert isinstance(data, bytes_type) self._check_closed() # We use bool(_write_buffer) as a proxy for write_buffer_size>0, # so never put empty strings in the buffer. if data: + if (self.max_write_buffer_size is not None and + self._write_buffer_size + len(data) > self.max_write_buffer_size): + raise StreamBufferFullError("Reached maximum read buffer size") # Break up large contiguous strings before inserting them in the # write buffer, so we don't have to recopy the entire thing # as we slice off pieces to send to the socket. WRITE_BUFFER_CHUNK_SIZE = 128 * 1024 - if len(data) > WRITE_BUFFER_CHUNK_SIZE: - for i in range(0, len(data), WRITE_BUFFER_CHUNK_SIZE): - self._write_buffer.append(data[i:i + WRITE_BUFFER_CHUNK_SIZE]) - else: - self._write_buffer.append(data) - self._write_callback = stack_context.wrap(callback) + for i in range(0, len(data), WRITE_BUFFER_CHUNK_SIZE): + self._write_buffer.append(data[i:i + WRITE_BUFFER_CHUNK_SIZE]) + self._write_buffer_size += len(data) + if callback is not None: + self._write_callback = stack_context.wrap(callback) + future = None + else: + future = self._write_future = TracebackFuture() if not self._connecting: self._handle_write() if self._write_buffer: self._add_io_state(self.io_loop.WRITE) self._maybe_add_error_listener() + return future def set_close_callback(self, callback): - """Call the given callback when the stream is closed.""" + """Call the given callback when the stream is closed. + + This is not necessary for applications that use the `.Future` + interface; all outstanding ``Futures`` will resolve with a + `StreamClosedError` when the stream is closed. + """ self._close_callback = stack_context.wrap(callback) + self._maybe_add_error_listener() def close(self, exc_info=False): """Close this stream. @@ -251,13 +365,9 @@ class BaseIOStream(object): if self._read_until_close: if (self._streaming_callback is not None and self._read_buffer_size): - self._run_callback(self._streaming_callback, - self._consume(self._read_buffer_size)) - callback = self._read_callback - self._read_callback = None + self._run_read_callback(self._read_buffer_size, True) self._read_until_close = False - self._run_callback(callback, - self._consume(self._read_buffer_size)) + self._run_read_callback(self._read_buffer_size, False) if self._state is not None: self.io_loop.remove_handler(self.fileno()) self._state = None @@ -269,6 +379,25 @@ class BaseIOStream(object): # If there are pending callbacks, don't run the close callback # until they're done (see _maybe_add_error_handler) if self.closed() and self._pending_callbacks == 0: + futures = [] + if self._read_future is not None: + futures.append(self._read_future) + self._read_future = None + if self._write_future is not None: + futures.append(self._write_future) + self._write_future = None + if self._connect_future is not None: + futures.append(self._connect_future) + self._connect_future = None + for future in futures: + if (isinstance(self.error, (socket.error, IOError)) and + errno_from_exception(self.error) in _ERRNO_CONNRESET): + # Treat connection resets as closed connections so + # clients only have to catch one kind of exception + # to avoid logging. + future.set_exception(StreamClosedError()) + else: + future.set_exception(self.error or StreamClosedError()) if self._close_callback is not None: cb = self._close_callback self._close_callback = None @@ -282,7 +411,7 @@ class BaseIOStream(object): def reading(self): """Returns true if we are currently reading from the stream.""" - return self._read_callback is not None + return self._read_callback is not None or self._read_future is not None def writing(self): """Returns true if we are currently writing to the stream.""" @@ -309,16 +438,22 @@ class BaseIOStream(object): def _handle_events(self, fd, events): if self.closed(): - gen_log.warning("Got events for closed stream %d", fd) + gen_log.warning("Got events for closed stream %s", fd) return try: + if self._connecting: + # Most IOLoops will report a write failed connect + # with the WRITE event, but SelectIOLoop reports a + # READ as well so we must check for connecting before + # either. + self._handle_connect() + if self.closed(): + return if events & self.io_loop.READ: self._handle_read() if self.closed(): return if events & self.io_loop.WRITE: - if self._connecting: - self._handle_connect() self._handle_write() if self.closed(): return @@ -334,13 +469,20 @@ class BaseIOStream(object): state |= self.io_loop.READ if self.writing(): state |= self.io_loop.WRITE - if state == self.io_loop.ERROR: + if state == self.io_loop.ERROR and self._read_buffer_size == 0: + # If the connection is idle, listen for reads too so + # we can tell if the connection is closed. If there is + # data in the read buffer we won't run the close callback + # yet anyway, so we don't need to listen in this case. state |= self.io_loop.READ if state != self._state: assert self._state is not None, \ "shouldn't happen: _handle_events without self._state" self._state = state self.io_loop.update_handler(self.fileno(), self._state) + except UnsatisfiableReadError as e: + gen_log.info("Unsatisfiable read, closing connection: %s" % e) + self.close(exc_info=True) except Exception: gen_log.error("Uncaught exception, closing connection.", exc_info=True) @@ -381,42 +523,108 @@ class BaseIOStream(object): self._pending_callbacks += 1 self.io_loop.add_callback(wrapper) + def _read_to_buffer_loop(self): + # This method is called from _handle_read and _try_inline_read. + try: + if self._read_bytes is not None: + target_bytes = self._read_bytes + elif self._read_max_bytes is not None: + target_bytes = self._read_max_bytes + elif self.reading(): + # For read_until without max_bytes, or + # read_until_close, read as much as we can before + # scanning for the delimiter. + target_bytes = None + else: + target_bytes = 0 + next_find_pos = 0 + # Pretend to have a pending callback so that an EOF in + # _read_to_buffer doesn't trigger an immediate close + # callback. At the end of this method we'll either + # estabilsh a real pending callback via + # _read_from_buffer or run the close callback. + # + # We need two try statements here so that + # pending_callbacks is decremented before the `except` + # clause below (which calls `close` and does need to + # trigger the callback) + self._pending_callbacks += 1 + while not self.closed(): + # Read from the socket until we get EWOULDBLOCK or equivalent. + # SSL sockets do some internal buffering, and if the data is + # sitting in the SSL object's buffer select() and friends + # can't see it; the only way to find out if it's there is to + # try to read it. + if self._read_to_buffer() == 0: + break + + self._run_streaming_callback() + + # If we've read all the bytes we can use, break out of + # this loop. We can't just call read_from_buffer here + # because of subtle interactions with the + # pending_callback and error_listener mechanisms. + # + # If we've reached target_bytes, we know we're done. + if (target_bytes is not None and + self._read_buffer_size >= target_bytes): + break + + # Otherwise, we need to call the more expensive find_read_pos. + # It's inefficient to do this on every read, so instead + # do it on the first read and whenever the read buffer + # size has doubled. + if self._read_buffer_size >= next_find_pos: + pos = self._find_read_pos() + if pos is not None: + return pos + next_find_pos = self._read_buffer_size * 2 + return self._find_read_pos() + finally: + self._pending_callbacks -= 1 + def _handle_read(self): try: - try: - # Pretend to have a pending callback so that an EOF in - # _read_to_buffer doesn't trigger an immediate close - # callback. At the end of this method we'll either - # estabilsh a real pending callback via - # _read_from_buffer or run the close callback. - # - # We need two try statements here so that - # pending_callbacks is decremented before the `except` - # clause below (which calls `close` and does need to - # trigger the callback) - self._pending_callbacks += 1 - while not self.closed(): - # Read from the socket until we get EWOULDBLOCK or equivalent. - # SSL sockets do some internal buffering, and if the data is - # sitting in the SSL object's buffer select() and friends - # can't see it; the only way to find out if it's there is to - # try to read it. - if self._read_to_buffer() == 0: - break - finally: - self._pending_callbacks -= 1 + pos = self._read_to_buffer_loop() + except UnsatisfiableReadError: + raise except Exception: gen_log.warning("error on read", exc_info=True) self.close(exc_info=True) return - if self._read_from_buffer(): + if pos is not None: + self._read_from_buffer(pos) return else: self._maybe_run_close_callback() def _set_read_callback(self, callback): - assert not self._read_callback, "Already reading" - self._read_callback = stack_context.wrap(callback) + assert self._read_callback is None, "Already reading" + assert self._read_future is None, "Already reading" + if callback is not None: + self._read_callback = stack_context.wrap(callback) + else: + self._read_future = TracebackFuture() + return self._read_future + + def _run_read_callback(self, size, streaming): + if streaming: + callback = self._streaming_callback + else: + callback = self._read_callback + self._read_callback = self._streaming_callback = None + if self._read_future is not None: + assert callback is None + future = self._read_future + self._read_future = None + future.set_result(self._consume(size)) + if callback is not None: + assert self._read_future is None + self._run_callback(callback, self._consume(size)) + else: + # If we scheduled a callback, we will add the error listener + # afterwards. If we didn't, we have to do it now. + self._maybe_add_error_listener() def _try_inline_read(self): """Attempt to complete the current read operation from buffered data. @@ -426,18 +634,14 @@ class BaseIOStream(object): listening for reads on the socket. """ # See if we've already got the data from a previous read - if self._read_from_buffer(): + self._run_streaming_callback() + pos = self._find_read_pos() + if pos is not None: + self._read_from_buffer(pos) return self._check_closed() try: - try: - # See comments in _handle_read about incrementing _pending_callbacks - self._pending_callbacks += 1 - while not self.closed(): - if self._read_to_buffer() == 0: - break - finally: - self._pending_callbacks -= 1 + pos = self._read_to_buffer_loop() except Exception: # If there was an in _read_to_buffer, we called close() already, # but couldn't run the close callback because of _pending_callbacks. @@ -445,9 +649,15 @@ class BaseIOStream(object): # applicable. self._maybe_run_close_callback() raise - if self._read_from_buffer(): + if pos is not None: + self._read_from_buffer(pos) return - self._maybe_add_error_listener() + # We couldn't satisfy the read inline, so either close the stream + # or listen for new data. + if self.closed(): + self._maybe_run_close_callback() + else: + self._add_io_state(ioloop.IOLoop.READ) def _read_to_buffer(self): """Reads from the socket and appends the result to the read buffer. @@ -472,32 +682,42 @@ class BaseIOStream(object): return 0 self._read_buffer.append(chunk) self._read_buffer_size += len(chunk) - if self._read_buffer_size >= self.max_buffer_size: + if self._read_buffer_size > self.max_buffer_size: gen_log.error("Reached maximum read buffer size") self.close() - raise IOError("Reached maximum read buffer size") + raise StreamBufferFullError("Reached maximum read buffer size") return len(chunk) - def _read_from_buffer(self): - """Attempts to complete the currently-pending read from the buffer. - - Returns True if the read was completed. - """ + def _run_streaming_callback(self): if self._streaming_callback is not None and self._read_buffer_size: bytes_to_consume = self._read_buffer_size if self._read_bytes is not None: bytes_to_consume = min(self._read_bytes, bytes_to_consume) self._read_bytes -= bytes_to_consume - self._run_callback(self._streaming_callback, - self._consume(bytes_to_consume)) - if self._read_bytes is not None and self._read_buffer_size >= self._read_bytes: - num_bytes = self._read_bytes - callback = self._read_callback - self._read_callback = None - self._streaming_callback = None - self._read_bytes = None - self._run_callback(callback, self._consume(num_bytes)) - return True + self._run_read_callback(bytes_to_consume, True) + + def _read_from_buffer(self, pos): + """Attempts to complete the currently-pending read from the buffer. + + The argument is either a position in the read buffer or None, + as returned by _find_read_pos. + """ + self._read_bytes = self._read_delimiter = self._read_regex = None + self._read_partial = False + self._run_read_callback(pos, False) + + def _find_read_pos(self): + """Attempts to find a position in the read buffer that satisfies + the currently-pending read. + + Returns a position in the buffer if the current read can be satisfied, + or None if it cannot. + """ + if (self._read_bytes is not None and + (self._read_buffer_size >= self._read_bytes or + (self._read_partial and self._read_buffer_size > 0))): + num_bytes = min(self._read_bytes, self._read_buffer_size) + return num_bytes elif self._read_delimiter is not None: # Multi-byte delimiters (e.g. '\r\n') may straddle two # chunks in the read buffer, so we can't easily find them @@ -506,37 +726,40 @@ class BaseIOStream(object): # length) tend to be "line" oriented, the delimiter is likely # to be in the first few chunks. Merge the buffer gradually # since large merges are relatively expensive and get undone in - # consume(). + # _consume(). if self._read_buffer: while True: loc = self._read_buffer[0].find(self._read_delimiter) if loc != -1: - callback = self._read_callback delimiter_len = len(self._read_delimiter) - self._read_callback = None - self._streaming_callback = None - self._read_delimiter = None - self._run_callback(callback, - self._consume(loc + delimiter_len)) - return True + self._check_max_bytes(self._read_delimiter, + loc + delimiter_len) + return loc + delimiter_len if len(self._read_buffer) == 1: break _double_prefix(self._read_buffer) + self._check_max_bytes(self._read_delimiter, + len(self._read_buffer[0])) elif self._read_regex is not None: if self._read_buffer: while True: m = self._read_regex.search(self._read_buffer[0]) if m is not None: - callback = self._read_callback - self._read_callback = None - self._streaming_callback = None - self._read_regex = None - self._run_callback(callback, self._consume(m.end())) - return True + self._check_max_bytes(self._read_regex, m.end()) + return m.end() if len(self._read_buffer) == 1: break _double_prefix(self._read_buffer) - return False + self._check_max_bytes(self._read_regex, + len(self._read_buffer[0])) + return None + + def _check_max_bytes(self, delimiter, size): + if (self._read_max_bytes is not None and + size > self._read_max_bytes): + raise UnsatisfiableReadError( + "delimiter %r not found within %d bytes" % ( + delimiter, self._read_max_bytes)) def _handle_write(self): while self._write_buffer: @@ -563,6 +786,7 @@ class BaseIOStream(object): self._write_buffer_frozen = False _merge_prefix(self._write_buffer, num_bytes) self._write_buffer.popleft() + self._write_buffer_size -= num_bytes except (socket.error, IOError, OSError) as e: if e.args[0] in _ERRNO_WOULDBLOCK: self._write_buffer_frozen = True @@ -572,14 +796,19 @@ class BaseIOStream(object): # Broken pipe errors are usually caused by connection # reset, and its better to not log EPIPE errors to # minimize log spam - gen_log.warning("Write error on %d: %s", + gen_log.warning("Write error on %s: %s", self.fileno(), e) self.close(exc_info=True) return - if not self._write_buffer and self._write_callback: - callback = self._write_callback - self._write_callback = None - self._run_callback(callback) + if not self._write_buffer: + if self._write_callback: + callback = self._write_callback + self._write_callback = None + self._run_callback(callback) + if self._write_future: + future = self._write_future + self._write_future = None + future.set_result(None) def _consume(self, loc): if loc == 0: @@ -593,10 +822,19 @@ class BaseIOStream(object): raise StreamClosedError("Stream is closed") def _maybe_add_error_listener(self): - if self._state is None and self._pending_callbacks == 0: + # This method is part of an optimization: to detect a connection that + # is closed when we're not actively reading or writing, we must listen + # for read events. However, it is inefficient to do this when the + # connection is first established because we are going to read or write + # immediately anyway. Instead, we insert checks at various times to + # see if the connection is idle and add the read listener then. + if self._pending_callbacks != 0: + return + if self._state is None or self._state == ioloop.IOLoop.ERROR: if self.closed(): self._maybe_run_close_callback() - else: + elif (self._read_buffer_size == 0 and + self._close_callback is not None): self._add_io_state(ioloop.IOLoop.READ) def _add_io_state(self, state): @@ -680,7 +918,7 @@ class IOStream(BaseIOStream): super(IOStream, self).__init__(*args, **kwargs) def fileno(self): - return self.socket.fileno() + return self.socket def close_fd(self): self.socket.close() @@ -712,9 +950,19 @@ class IOStream(BaseIOStream): May only be called if the socket passed to the constructor was not previously connected. The address parameter is in the - same format as for `socket.connect `, - i.e. a ``(host, port)`` tuple. If ``callback`` is specified, - it will be called when the connection is completed. + same format as for `socket.connect ` for + the type of socket passed to the IOStream constructor, + e.g. an ``(ip, port)`` tuple. Hostnames are accepted here, + but will be resolved synchronously and block the IOLoop. + If you have a hostname instead of an IP address, the `.TCPClient` + class is recommended instead of calling this method directly. + `.TCPClient` will do asynchronous DNS resolution and handle + both IPv4 and IPv6. + + If ``callback`` is specified, it will be called with no + arguments when the connection is completed; if not this method + returns a `.Future` (whose result after a successful + connection will be the stream itself). If specified, the ``server_hostname`` parameter will be used in SSL connections for certificate validation (if requested in @@ -726,6 +974,10 @@ class IOStream(BaseIOStream): which case the data will be written as soon as the connection is ready. Calling `IOStream` read methods before the socket is connected works on some platforms but is non-portable. + + .. versionchanged:: 4.0 + If no callback is given, returns a `.Future`. + """ self._connecting = True try: @@ -738,14 +990,83 @@ class IOStream(BaseIOStream): # returned immediately when attempting to connect to # localhost, so handle them the same way as an error # reported later in _handle_connect. - if (e.args[0] != errno.EINPROGRESS and - e.args[0] not in _ERRNO_WOULDBLOCK): - gen_log.warning("Connect error on fd %d: %s", + if (errno_from_exception(e) != errno.EINPROGRESS and + errno_from_exception(e) not in _ERRNO_WOULDBLOCK): + gen_log.warning("Connect error on fd %s: %s", self.socket.fileno(), e) self.close(exc_info=True) return - self._connect_callback = stack_context.wrap(callback) + if callback is not None: + self._connect_callback = stack_context.wrap(callback) + future = None + else: + future = self._connect_future = TracebackFuture() self._add_io_state(self.io_loop.WRITE) + return future + + def start_tls(self, server_side, ssl_options=None, server_hostname=None): + """Convert this `IOStream` to an `SSLIOStream`. + + This enables protocols that begin in clear-text mode and + switch to SSL after some initial negotiation (such as the + ``STARTTLS`` extension to SMTP and IMAP). + + This method cannot be used if there are outstanding reads + or writes on the stream, or if there is any data in the + IOStream's buffer (data in the operating system's socket + buffer is allowed). This means it must generally be used + immediately after reading or writing the last clear-text + data. It can also be used immediately after connecting, + before any reads or writes. + + The ``ssl_options`` argument may be either a dictionary + of options or an `ssl.SSLContext`. If a ``server_hostname`` + is given, it will be used for certificate verification + (as configured in the ``ssl_options``). + + This method returns a `.Future` whose result is the new + `SSLIOStream`. After this method has been called, + any other operation on the original stream is undefined. + + If a close callback is defined on this stream, it will be + transferred to the new stream. + + .. versionadded:: 4.0 + """ + if (self._read_callback or self._read_future or + self._write_callback or self._write_future or + self._connect_callback or self._connect_future or + self._pending_callbacks or self._closed or + self._read_buffer or self._write_buffer): + raise ValueError("IOStream is not idle; cannot convert to SSL") + if ssl_options is None: + ssl_options = {} + + socket = self.socket + self.io_loop.remove_handler(socket) + self.socket = None + socket = ssl_wrap_socket(socket, ssl_options, server_side=server_side, + do_handshake_on_connect=False) + orig_close_callback = self._close_callback + self._close_callback = None + + future = TracebackFuture() + ssl_stream = SSLIOStream(socket, ssl_options=ssl_options, + io_loop=self.io_loop) + # Wrap the original close callback so we can fail our Future as well. + # If we had an "unwrap" counterpart to this method we would need + # to restore the original callback after our Future resolves + # so that repeated wrap/unwrap calls don't build up layers. + def close_callback(): + if not future.done(): + future.set_exception(ssl_stream.error or StreamClosedError()) + if orig_close_callback is not None: + orig_close_callback() + ssl_stream.set_close_callback(close_callback) + ssl_stream._ssl_connect_callback = lambda: future.set_result(ssl_stream) + ssl_stream.max_buffer_size = self.max_buffer_size + ssl_stream.read_chunk_size = self.read_chunk_size + return future def _handle_connect(self): err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) @@ -755,14 +1076,19 @@ class IOStream(BaseIOStream): # an error state before the socket becomes writable, so # in that case a connection failure would be handled by the # error path in _handle_events instead of here. - gen_log.warning("Connect error on fd %d: %s", - self.socket.fileno(), errno.errorcode[err]) + if self._connect_future is None: + gen_log.warning("Connect error on fd %s: %s", + self.socket.fileno(), errno.errorcode[err]) self.close() return if self._connect_callback is not None: callback = self._connect_callback self._connect_callback = None self._run_callback(callback) + if self._connect_future is not None: + future = self._connect_future + self._connect_future = None + future.set_result(self) self._connecting = False def set_nodelay(self, value): @@ -841,7 +1167,7 @@ class SSLIOStream(IOStream): peer = self.socket.getpeername() except Exception: peer = '(not connected)' - gen_log.warning("SSL Error on %d %s: %s", + gen_log.warning("SSL Error on %s %s: %s", self.socket.fileno(), peer, err) return self.close(exc_info=True) raise @@ -907,19 +1233,33 @@ class SSLIOStream(IOStream): # has completed. self._ssl_connect_callback = stack_context.wrap(callback) self._server_hostname = server_hostname - super(SSLIOStream, self).connect(address, callback=None) + # Note: Since we don't pass our callback argument along to + # super.connect(), this will always return a Future. + # This is harmless, but a bit less efficient than it could be. + return super(SSLIOStream, self).connect(address, callback=None) def _handle_connect(self): + # Call the superclass method to check for errors. + super(SSLIOStream, self)._handle_connect() + if self.closed(): + return # When the connection is complete, wrap the socket for SSL # traffic. Note that we do this by overriding _handle_connect # instead of by passing a callback to super().connect because # user callbacks are enqueued asynchronously on the IOLoop, # but since _handle_events calls _handle_connect immediately # followed by _handle_write we need this to be synchronous. + # + # The IOLoop will get confused if we swap out self.socket while the + # fd is registered, so remove it now and re-register after + # wrap_socket(). + self.io_loop.remove_handler(self.socket) + old_state = self._state + self._state = None self.socket = ssl_wrap_socket(self.socket, self._ssl_options, server_hostname=self._server_hostname, do_handshake_on_connect=False) - super(SSLIOStream, self)._handle_connect() + self._add_io_state(old_state) def read_from_fd(self): if self._ssl_accepting: @@ -978,9 +1318,9 @@ class PipeIOStream(BaseIOStream): try: chunk = os.read(self.fd, self.read_chunk_size) except (IOError, OSError) as e: - if e.args[0] in _ERRNO_WOULDBLOCK: + if errno_from_exception(e) in _ERRNO_WOULDBLOCK: return None - elif e.args[0] == errno.EBADF: + elif errno_from_exception(e) == errno.EBADF: # If the writing half of a pipe is closed, select will # report it as readable but reads will fail with EBADF. self.close(exc_info=True) diff --git a/tornado/log.py b/tornado/log.py index 36c3dd40..70664664 100644 --- a/tornado/log.py +++ b/tornado/log.py @@ -83,10 +83,10 @@ class LogFormatter(logging.Formatter): DEFAULT_FORMAT = '%(color)s[%(levelname)1.1s %(asctime)s %(module)s:%(lineno)d]%(end_color)s %(message)s' DEFAULT_DATE_FORMAT = '%y%m%d %H:%M:%S' DEFAULT_COLORS = { - logging.DEBUG: 4, # Blue - logging.INFO: 2, # Green - logging.WARNING: 3, # Yellow - logging.ERROR: 1, # Red + logging.DEBUG: 4, # Blue + logging.INFO: 2, # Green + logging.WARNING: 3, # Yellow + logging.ERROR: 1, # Red } def __init__(self, color=True, fmt=DEFAULT_FORMAT, @@ -184,7 +184,7 @@ def enable_pretty_logging(options=None, logger=None): """ if options is None: from tornado.options import options - if options.logging == 'none': + if options.logging is None or options.logging.lower() == 'none': return if logger is None: logger = logging.getLogger() diff --git a/tornado/netutil.py b/tornado/netutil.py index 8ebe604d..a9e05d1e 100644 --- a/tornado/netutil.py +++ b/tornado/netutil.py @@ -20,18 +20,26 @@ from __future__ import absolute_import, division, print_function, with_statement import errno import os +import platform import socket -import ssl import stat from tornado.concurrent import dummy_executor, run_on_executor from tornado.ioloop import IOLoop from tornado.platform.auto import set_close_exec -from tornado.util import u, Configurable +from tornado.util import u, Configurable, errno_from_exception + +try: + import ssl +except ImportError: + # ssl is not available on Google App Engine + ssl = None if hasattr(ssl, 'match_hostname') and hasattr(ssl, 'CertificateError'): # python 3.2+ ssl_match_hostname = ssl.match_hostname SSLCertificateError = ssl.CertificateError +elif ssl is None: + ssl_match_hostname = SSLCertificateError = None else: import backports.ssl_match_hostname ssl_match_hostname = backports.ssl_match_hostname.match_hostname @@ -44,6 +52,11 @@ else: # thread now. u('foo').encode('idna') +# These errnos indicate that a non-blocking operation must be retried +# at a later time. On most platforms they're the same value, but on +# some they differ. +_ERRNO_WOULDBLOCK = (errno.EWOULDBLOCK, errno.EAGAIN) + def bind_sockets(port, address=None, family=socket.AF_UNSPEC, backlog=128, flags=None): """Creates listening sockets bound to the given port and address. @@ -77,13 +90,23 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC, backlog=128, flags family = socket.AF_INET if flags is None: flags = socket.AI_PASSIVE + bound_port = None for res in set(socket.getaddrinfo(address, port, family, socket.SOCK_STREAM, 0, flags)): af, socktype, proto, canonname, sockaddr = res + if (platform.system() == 'Darwin' and address == 'localhost' and + af == socket.AF_INET6 and sockaddr[3] != 0): + # Mac OS X includes a link-local address fe80::1%lo0 in the + # getaddrinfo results for 'localhost'. However, the firewall + # doesn't understand that this is a local address and will + # prompt for access (often repeatedly, due to an apparent + # bug in its ability to remember granting access to an + # application). Skip these addresses. + continue try: sock = socket.socket(af, socktype, proto) except socket.error as e: - if e.args[0] == errno.EAFNOSUPPORT: + if errno_from_exception(e) == errno.EAFNOSUPPORT: continue raise set_close_exec(sock.fileno()) @@ -100,8 +123,16 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC, backlog=128, flags # Python 2.x on windows doesn't have IPPROTO_IPV6. if hasattr(socket, "IPPROTO_IPV6"): sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1) + + # automatic port allocation with port=None + # should bind on the same port on IPv4 and IPv6 + host, requested_port = sockaddr[:2] + if requested_port == 0 and bound_port is not None: + sockaddr = tuple([host, bound_port] + list(sockaddr[2:])) + sock.setblocking(0) sock.bind(sockaddr) + bound_port = sock.getsockname()[1] sock.listen(backlog) sockets.append(sock) return sockets @@ -124,7 +155,7 @@ if hasattr(socket, 'AF_UNIX'): try: st = os.stat(file) except OSError as err: - if err.errno != errno.ENOENT: + if errno_from_exception(err) != errno.ENOENT: raise else: if stat.S_ISSOCK(st.st_mode): @@ -154,18 +185,18 @@ def add_accept_handler(sock, callback, io_loop=None): try: connection, address = sock.accept() except socket.error as e: - # EWOULDBLOCK and EAGAIN indicate we have accepted every + # _ERRNO_WOULDBLOCK indicate we have accepted every # connection that is available. - if e.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN): + if errno_from_exception(e) in _ERRNO_WOULDBLOCK: return # ECONNABORTED indicates that there was a connection # but it was closed while still in the accept queue. # (observed on FreeBSD). - if e.args[0] == errno.ECONNABORTED: + if errno_from_exception(e) == errno.ECONNABORTED: continue raise callback(connection, address) - io_loop.add_handler(sock.fileno(), accept_handler, IOLoop.READ) + io_loop.add_handler(sock, accept_handler, IOLoop.READ) def is_valid_ip(ip): @@ -381,6 +412,10 @@ def ssl_options_to_context(ssl_options): context.load_verify_locations(ssl_options['ca_certs']) if 'ciphers' in ssl_options: context.set_ciphers(ssl_options['ciphers']) + if hasattr(ssl, 'OP_NO_COMPRESSION'): + # Disable TLS compression to avoid CRIME and related attacks. + # This constant wasn't added until python 3.3. + context.options |= ssl.OP_NO_COMPRESSION return context diff --git a/tornado/options.py b/tornado/options.py index 1105c0e9..fa9c269e 100644 --- a/tornado/options.py +++ b/tornado/options.py @@ -56,6 +56,18 @@ We support `datetimes `, `timedeltas the top-level functions in this module (`define`, `parse_command_line`, etc) simply call methods on it. You may create additional `OptionParser` instances to define isolated sets of options, such as for subcommands. + +.. note:: + + By default, several options are defined that will configure the + standard `logging` module when `parse_command_line` or `parse_config_file` + are called. If you want Tornado to leave the logging configuration + alone so you can manage it yourself, either pass ``--logging=none`` + on the command line or do the following to disable it in code:: + + from tornado.options import options, parse_command_line + options.logging = None + parse_command_line() """ from __future__ import absolute_import, division, print_function, with_statement @@ -360,6 +372,8 @@ class _Mockable(object): class _Option(object): + UNSET = object() + def __init__(self, name, default=None, type=basestring_type, help=None, metavar=None, multiple=False, file_name=None, group_name=None, callback=None): @@ -374,10 +388,10 @@ class _Option(object): self.group_name = group_name self.callback = callback self.default = default - self._value = None + self._value = _Option.UNSET def value(self): - return self.default if self._value is None else self._value + return self.default if self._value is _Option.UNSET else self._value def parse(self, value): _parse = { diff --git a/tornado/platform/asyncio.py b/tornado/platform/asyncio.py index 5d8f3073..6518dea5 100644 --- a/tornado/platform/asyncio.py +++ b/tornado/platform/asyncio.py @@ -12,9 +12,9 @@ unfinished callbacks on the event loop that fail when it resumes) from __future__ import absolute_import, division, print_function, with_statement import datetime import functools -import os -from tornado.ioloop import IOLoop +# _Timeout is used for its timedelta_to_seconds method for py26 compatibility. +from tornado.ioloop import IOLoop, _Timeout from tornado import stack_context try: @@ -34,7 +34,7 @@ class BaseAsyncIOLoop(IOLoop): self.asyncio_loop = asyncio_loop self.close_loop = close_loop self.asyncio_loop.call_soon(self.make_current) - # Maps fd to handler function (as in IOLoop.add_handler) + # Maps fd to (fileobj, handler function) pair (as in IOLoop.add_handler) self.handlers = {} # Set of fds listening for reads/writes self.readers = set() @@ -44,19 +44,18 @@ class BaseAsyncIOLoop(IOLoop): def close(self, all_fds=False): self.closing = True for fd in list(self.handlers): + fileobj, handler_func = self.handlers[fd] self.remove_handler(fd) if all_fds: - try: - os.close(fd) - except OSError: - pass + self.close_fd(fileobj) if self.close_loop: self.asyncio_loop.close() def add_handler(self, fd, handler, events): + fd, fileobj = self.split_fd(fd) if fd in self.handlers: - raise ValueError("fd %d added twice" % fd) - self.handlers[fd] = stack_context.wrap(handler) + raise ValueError("fd %s added twice" % fd) + self.handlers[fd] = (fileobj, stack_context.wrap(handler)) if events & IOLoop.READ: self.asyncio_loop.add_reader( fd, self._handle_events, fd, IOLoop.READ) @@ -67,6 +66,7 @@ class BaseAsyncIOLoop(IOLoop): self.writers.add(fd) def update_handler(self, fd, events): + fd, fileobj = self.split_fd(fd) if events & IOLoop.READ: if fd not in self.readers: self.asyncio_loop.add_reader( @@ -87,6 +87,7 @@ class BaseAsyncIOLoop(IOLoop): self.writers.remove(fd) def remove_handler(self, fd): + fd, fileobj = self.split_fd(fd) if fd not in self.handlers: return if fd in self.readers: @@ -98,7 +99,8 @@ class BaseAsyncIOLoop(IOLoop): del self.handlers[fd] def _handle_events(self, fd, events): - self.handlers[fd](fd, events) + fileobj, handler_func = self.handlers[fd] + handler_func(fileobj, events) def start(self): self._setup_logging() @@ -107,17 +109,11 @@ class BaseAsyncIOLoop(IOLoop): def stop(self): self.asyncio_loop.stop() - def _run_callback(self, callback, *args, **kwargs): - try: - callback(*args, **kwargs) - except Exception: - self.handle_callback_exception(callback) - def add_timeout(self, deadline, callback): if isinstance(deadline, (int, float)): delay = max(deadline - self.time(), 0) elif isinstance(deadline, datetime.timedelta): - delay = deadline.total_seconds() + delay = _Timeout.timedelta_to_seconds(deadline) else: raise TypeError("Unsupported deadline %r", deadline) return self.asyncio_loop.call_later(delay, self._run_callback, @@ -129,13 +125,9 @@ class BaseAsyncIOLoop(IOLoop): def add_callback(self, callback, *args, **kwargs): if self.closing: raise RuntimeError("IOLoop is closing") - if kwargs: - self.asyncio_loop.call_soon_threadsafe(functools.partial( - self._run_callback, stack_context.wrap(callback), - *args, **kwargs)) - else: - self.asyncio_loop.call_soon_threadsafe( - self._run_callback, stack_context.wrap(callback), *args) + self.asyncio_loop.call_soon_threadsafe( + self._run_callback, + functools.partial(stack_context.wrap(callback), *args, **kwargs)) add_callback_from_signal = add_callback diff --git a/tornado/platform/auto.py b/tornado/platform/auto.py index e55725b3..ddfe06b4 100644 --- a/tornado/platform/auto.py +++ b/tornado/platform/auto.py @@ -30,6 +30,10 @@ import os if os.name == 'nt': from tornado.platform.common import Waker from tornado.platform.windows import set_close_exec +elif 'APPENGINE_RUNTIME' in os.environ: + from tornado.platform.common import Waker + def set_close_exec(fd): + pass else: from tornado.platform.posix import set_close_exec, Waker diff --git a/tornado/platform/common.py b/tornado/platform/common.py index d9c4cf9f..b409a903 100644 --- a/tornado/platform/common.py +++ b/tornado/platform/common.py @@ -15,7 +15,8 @@ class Waker(interface.Waker): and Jython. """ def __init__(self): - # Based on Zope async.py: http://svn.zope.org/zc.ngi/trunk/src/zc/ngi/async.py + # Based on Zope select_trigger.py: + # https://github.com/zopefoundation/Zope/blob/master/src/ZServer/medusa/thread/select_trigger.py self.writer = socket.socket() # Disable buffering -- pulling the trigger sends 1 byte, diff --git a/tornado/platform/kqueue.py b/tornado/platform/kqueue.py index ceff0a43..de8c046d 100644 --- a/tornado/platform/kqueue.py +++ b/tornado/platform/kqueue.py @@ -37,7 +37,7 @@ class _KQueue(object): def register(self, fd, events): if fd in self._active: - raise IOError("fd %d already registered" % fd) + raise IOError("fd %s already registered" % fd) self._control(fd, events, select.KQ_EV_ADD) self._active[fd] = events diff --git a/tornado/platform/select.py b/tornado/platform/select.py index 8bbb1f4f..9a879562 100644 --- a/tornado/platform/select.py +++ b/tornado/platform/select.py @@ -37,7 +37,7 @@ class _Select(object): def register(self, fd, events): if fd in self.read_fds or fd in self.write_fds or fd in self.error_fds: - raise IOError("fd %d already registered" % fd) + raise IOError("fd %s already registered" % fd) if events & IOLoop.READ: self.read_fds.add(fd) if events & IOLoop.WRITE: diff --git a/tornado/platform/twisted.py b/tornado/platform/twisted.py index 0c8a3105..18263dd9 100644 --- a/tornado/platform/twisted.py +++ b/tornado/platform/twisted.py @@ -91,6 +91,11 @@ from tornado.netutil import Resolver from tornado.stack_context import NullContext, wrap from tornado.ioloop import IOLoop +try: + long # py2 +except NameError: + long = int # py3 + @implementer(IDelayedCall) class TornadoDelayedCall(object): @@ -365,8 +370,9 @@ def install(io_loop=None): @implementer(IReadDescriptor, IWriteDescriptor) class _FD(object): - def __init__(self, fd, handler): + def __init__(self, fd, fileobj, handler): self.fd = fd + self.fileobj = fileobj self.handler = handler self.reading = False self.writing = False @@ -377,15 +383,15 @@ class _FD(object): def doRead(self): if not self.lost: - self.handler(self.fd, tornado.ioloop.IOLoop.READ) + self.handler(self.fileobj, tornado.ioloop.IOLoop.READ) def doWrite(self): if not self.lost: - self.handler(self.fd, tornado.ioloop.IOLoop.WRITE) + self.handler(self.fileobj, tornado.ioloop.IOLoop.WRITE) def connectionLost(self, reason): if not self.lost: - self.handler(self.fd, tornado.ioloop.IOLoop.ERROR) + self.handler(self.fileobj, tornado.ioloop.IOLoop.ERROR) self.lost = True def logPrefix(self): @@ -412,14 +418,19 @@ class TwistedIOLoop(tornado.ioloop.IOLoop): self.reactor.callWhenRunning(self.make_current) def close(self, all_fds=False): + fds = self.fds self.reactor.removeAll() for c in self.reactor.getDelayedCalls(): c.cancel() + if all_fds: + for fd in fds.values(): + self.close_fd(fd.fileobj) def add_handler(self, fd, handler, events): if fd in self.fds: - raise ValueError('fd %d added twice' % fd) - self.fds[fd] = _FD(fd, wrap(handler)) + raise ValueError('fd %s added twice' % fd) + fd, fileobj = self.split_fd(fd) + self.fds[fd] = _FD(fd, fileobj, wrap(handler)) if events & tornado.ioloop.IOLoop.READ: self.fds[fd].reading = True self.reactor.addReader(self.fds[fd]) @@ -428,6 +439,7 @@ class TwistedIOLoop(tornado.ioloop.IOLoop): self.reactor.addWriter(self.fds[fd]) def update_handler(self, fd, events): + fd, fileobj = self.split_fd(fd) if events & tornado.ioloop.IOLoop.READ: if not self.fds[fd].reading: self.fds[fd].reading = True @@ -446,6 +458,7 @@ class TwistedIOLoop(tornado.ioloop.IOLoop): self.reactor.removeWriter(self.fds[fd]) def remove_handler(self, fd): + fd, fileobj = self.split_fd(fd) if fd not in self.fds: return self.fds[fd].lost = True @@ -462,12 +475,6 @@ class TwistedIOLoop(tornado.ioloop.IOLoop): def stop(self): self.reactor.crash() - def _run_callback(self, callback, *args, **kwargs): - try: - callback(*args, **kwargs) - except Exception: - self.handle_callback_exception(callback) - def add_timeout(self, deadline, callback): if isinstance(deadline, (int, long, float)): delay = max(deadline - self.time(), 0) @@ -482,8 +489,9 @@ class TwistedIOLoop(tornado.ioloop.IOLoop): timeout.cancel() def add_callback(self, callback, *args, **kwargs): - self.reactor.callFromThread(self._run_callback, - wrap(callback), *args, **kwargs) + self.reactor.callFromThread( + self._run_callback, + functools.partial(wrap(callback), *args, **kwargs)) def add_callback_from_signal(self, callback, *args, **kwargs): self.add_callback(callback, *args, **kwargs) diff --git a/tornado/process.py b/tornado/process.py index 942c5c3f..0f38b856 100644 --- a/tornado/process.py +++ b/tornado/process.py @@ -21,7 +21,6 @@ the server into multiple processes and managing subprocesses. from __future__ import absolute_import, division, print_function, with_statement import errno -import multiprocessing import os import signal import subprocess @@ -35,6 +34,13 @@ from tornado.iostream import PipeIOStream from tornado.log import gen_log from tornado.platform.auto import set_close_exec from tornado import stack_context +from tornado.util import errno_from_exception + +try: + import multiprocessing +except ImportError: + # Multiprocessing is not availble on Google App Engine. + multiprocessing = None try: long # py2 @@ -44,6 +50,8 @@ except NameError: def cpu_count(): """Returns the number of processors on this machine.""" + if multiprocessing is None: + return 1 try: return multiprocessing.cpu_count() except NotImplementedError: @@ -136,7 +144,7 @@ def fork_processes(num_processes, max_restarts=100): try: pid, status = os.wait() except OSError as e: - if e.errno == errno.EINTR: + if errno_from_exception(e) == errno.EINTR: continue raise if pid not in children: @@ -283,7 +291,7 @@ class Subprocess(object): try: ret_pid, status = os.waitpid(pid, os.WNOHANG) except OSError as e: - if e.args[0] == errno.ECHILD: + if errno_from_exception(e) == errno.ECHILD: return if ret_pid == 0: return diff --git a/tornado/simple_httpclient.py b/tornado/simple_httpclient.py index 73bfee89..06d7ecfa 100644 --- a/tornado/simple_httpclient.py +++ b/tornado/simple_httpclient.py @@ -1,23 +1,23 @@ #!/usr/bin/env python from __future__ import absolute_import, division, print_function, with_statement -from tornado.escape import utf8, _unicode, native_str +from tornado.concurrent import is_future +from tornado.escape import utf8, _unicode from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main, _RequestProxy -from tornado.httputil import HTTPHeaders -from tornado.iostream import IOStream, SSLIOStream +from tornado import httputil +from tornado.http1connection import HTTP1Connection, HTTP1ConnectionParameters +from tornado.iostream import StreamClosedError from tornado.netutil import Resolver, OverrideResolver from tornado.log import gen_log from tornado import stack_context -from tornado.util import GzipDecompressor +from tornado.tcpclient import TCPClient import base64 import collections import copy import functools -import os.path import re import socket -import ssl import sys try: @@ -30,7 +30,23 @@ try: except ImportError: import urllib.parse as urlparse # py3 -_DEFAULT_CA_CERTS = os.path.dirname(__file__) + '/ca-certificates.crt' +try: + import ssl +except ImportError: + # ssl is not available on Google App Engine. + ssl = None + +try: + import certifi +except ImportError: + certifi = None + + +def _default_ca_certs(): + if certifi is None: + raise Exception("The 'certifi' package is required to use https " + "in simple_httpclient") + return certifi.where() class SimpleAsyncHTTPClient(AsyncHTTPClient): @@ -47,7 +63,7 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient): """ def initialize(self, io_loop, max_clients=10, hostname_mapping=None, max_buffer_size=104857600, - resolver=None, defaults=None): + resolver=None, defaults=None, max_header_size=None): """Creates a AsyncHTTPClient. Only a single AsyncHTTPClient instance exists per IOLoop @@ -74,6 +90,9 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient): self.active = {} self.waiting = {} self.max_buffer_size = max_buffer_size + self.max_header_size = max_header_size + # TCPClient could create a Resolver for us, but we have to do it + # ourselves to support hostname_mapping. if resolver: self.resolver = resolver self.own_resolver = False @@ -83,11 +102,13 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient): if hostname_mapping is not None: self.resolver = OverrideResolver(resolver=self.resolver, mapping=hostname_mapping) + self.tcp_client = TCPClient(resolver=self.resolver, io_loop=io_loop) def close(self): super(SimpleAsyncHTTPClient, self).close() if self.own_resolver: self.resolver.close() + self.tcp_client.close() def fetch_impl(self, request, callback): key = object() @@ -119,7 +140,8 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient): def _handle_request(self, request, release_callback, final_callback): _HTTPConnection(self.io_loop, self, request, release_callback, - final_callback, self.max_buffer_size, self.resolver) + final_callback, self.max_buffer_size, self.tcp_client, + self.max_header_size) def _release_fetch(self, key): del self.active[key] @@ -142,11 +164,12 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient): del self.waiting[key] -class _HTTPConnection(object): +class _HTTPConnection(httputil.HTTPMessageDelegate): _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]) def __init__(self, io_loop, client, request, release_callback, - final_callback, max_buffer_size, resolver): + final_callback, max_buffer_size, tcp_client, + max_header_size): self.start_time = io_loop.time() self.io_loop = io_loop self.client = client @@ -154,13 +177,15 @@ class _HTTPConnection(object): self.release_callback = release_callback self.final_callback = final_callback self.max_buffer_size = max_buffer_size - self.resolver = resolver + self.tcp_client = tcp_client + self.max_header_size = max_header_size self.code = None self.headers = None - self.chunks = None + self.chunks = [] self._decompressor = None # Timeout handle returned by IOLoop.add_timeout self._timeout = None + self._sockaddr = None with stack_context.ExceptionStackContext(self._handle_exception): self.parsed = urlparse.urlsplit(_unicode(self.request.url)) if self.parsed.scheme not in ("http", "https"): @@ -183,42 +208,31 @@ class _HTTPConnection(object): host = host[1:-1] self.parsed_hostname = host # save final host for _on_connect - if request.allow_ipv6: - af = socket.AF_UNSPEC - else: - # We only try the first IP we get from getaddrinfo, - # so restrict to ipv4 by default. + if request.allow_ipv6 is False: af = socket.AF_INET + else: + af = socket.AF_UNSPEC + + ssl_options = self._get_ssl_options(self.parsed.scheme) timeout = min(self.request.connect_timeout, self.request.request_timeout) if timeout: self._timeout = self.io_loop.add_timeout( self.start_time + timeout, stack_context.wrap(self._on_timeout)) - self.resolver.resolve(host, port, af, callback=self._on_resolve) + self.tcp_client.connect(host, port, af=af, + ssl_options=ssl_options, + callback=self._on_connect) - def _on_resolve(self, addrinfo): - if self.final_callback is None: - # final_callback is cleared if we've hit our timeout - return - self.stream = self._create_stream(addrinfo) - self.stream.set_close_callback(self._on_close) - # ipv6 addresses are broken (in self.parsed.hostname) until - # 2.7, here is correctly parsed value calculated in __init__ - sockaddr = addrinfo[0][1] - self.stream.connect(sockaddr, self._on_connect, - server_hostname=self.parsed_hostname) - - def _create_stream(self, addrinfo): - af = addrinfo[0][0] - if self.parsed.scheme == "https": + def _get_ssl_options(self, scheme): + if scheme == "https": ssl_options = {} if self.request.validate_cert: ssl_options["cert_reqs"] = ssl.CERT_REQUIRED if self.request.ca_certs is not None: ssl_options["ca_certs"] = self.request.ca_certs else: - ssl_options["ca_certs"] = _DEFAULT_CA_CERTS + ssl_options["ca_certs"] = _default_ca_certs() if self.request.client_key is not None: ssl_options["keyfile"] = self.request.client_key if self.request.client_cert is not None: @@ -236,21 +250,16 @@ class _HTTPConnection(object): # but nearly all servers support both SSLv3 and TLSv1: # http://blog.ivanristic.com/2011/09/ssl-survey-protocol-support.html if sys.version_info >= (2, 7): - ssl_options["ciphers"] = "DEFAULT:!SSLv2" + # In addition to disabling SSLv2, we also exclude certain + # classes of insecure ciphers. + ssl_options["ciphers"] = "DEFAULT:!SSLv2:!EXPORT:!DES" else: # This is really only necessary for pre-1.0 versions # of openssl, but python 2.6 doesn't expose version # information. ssl_options["ssl_version"] = ssl.PROTOCOL_TLSv1 - - return SSLIOStream(socket.socket(af), - io_loop=self.io_loop, - ssl_options=ssl_options, - max_buffer_size=self.max_buffer_size) - else: - return IOStream(socket.socket(af), - io_loop=self.io_loop, - max_buffer_size=self.max_buffer_size) + return ssl_options + return None def _on_timeout(self): self._timeout = None @@ -262,7 +271,13 @@ class _HTTPConnection(object): self.io_loop.remove_timeout(self._timeout) self._timeout = None - def _on_connect(self): + def _on_connect(self, stream): + if self.final_callback is None: + # final_callback is cleared if we've hit our timeout. + stream.close() + return + self.stream = stream + self.stream.set_close_callback(self._on_close) self._remove_timeout() if self.final_callback is None: return @@ -302,16 +317,22 @@ class _HTTPConnection(object): self.request.headers["User-Agent"] = self.request.user_agent if not self.request.allow_nonstandard_methods: if self.request.method in ("POST", "PATCH", "PUT"): - if self.request.body is None: + if (self.request.body is None and + self.request.body_producer is None): raise AssertionError( 'Body must not be empty for "%s" request' % self.request.method) else: - if self.request.body is not None: + if (self.request.body is not None or + self.request.body_producer is not None): raise AssertionError( 'Body must be empty for "%s" request' % self.request.method) + if self.request.expect_100_continue: + self.request.headers["Expect"] = "100-continue" if self.request.body is not None: + # When body_producer is used the caller is responsible for + # setting Content-Length (or else chunked encoding will be used). self.request.headers["Content-Length"] = str(len( self.request.body)) if (self.request.method == "POST" and @@ -320,20 +341,47 @@ class _HTTPConnection(object): if self.request.use_gzip: self.request.headers["Accept-Encoding"] = "gzip" req_path = ((self.parsed.path or '/') + - (('?' + self.parsed.query) if self.parsed.query else '')) - request_lines = [utf8("%s %s HTTP/1.1" % (self.request.method, - req_path))] - for k, v in self.request.headers.get_all(): - line = utf8(k) + b": " + utf8(v) - if b'\n' in line: - raise ValueError('Newline in header: ' + repr(line)) - request_lines.append(line) - request_str = b"\r\n".join(request_lines) + b"\r\n\r\n" - if self.request.body is not None: - request_str += self.request.body + (('?' + self.parsed.query) if self.parsed.query else '')) self.stream.set_nodelay(True) - self.stream.write(request_str) - self.stream.read_until_regex(b"\r?\n\r?\n", self._on_headers) + self.connection = HTTP1Connection( + self.stream, True, + HTTP1ConnectionParameters( + no_keep_alive=True, + max_header_size=self.max_header_size, + use_gzip=self.request.use_gzip), + self._sockaddr) + start_line = httputil.RequestStartLine(self.request.method, + req_path, 'HTTP/1.1') + self.connection.write_headers(start_line, self.request.headers) + if self.request.expect_100_continue: + self._read_response() + else: + self._write_body(True) + + def _write_body(self, start_read): + if self.request.body is not None: + self.connection.write(self.request.body) + self.connection.finish() + elif self.request.body_producer is not None: + fut = self.request.body_producer(self.connection.write) + if is_future(fut): + def on_body_written(fut): + fut.result() + self.connection.finish() + if start_read: + self._read_response() + self.io_loop.add_future(fut, on_body_written) + return + self.connection.finish() + if start_read: + self._read_response() + + def _read_response(self): + # Ensure that any exception raised in read_response ends up in our + # stack context. + self.io_loop.add_future( + self.connection.read_response(self), + lambda f: f.result()) def _release(self): if self.release_callback is not None: @@ -351,43 +399,39 @@ class _HTTPConnection(object): def _handle_exception(self, typ, value, tb): if self.final_callback: self._remove_timeout() + if isinstance(value, StreamClosedError): + value = HTTPError(599, "Stream closed") self._run_callback(HTTPResponse(self.request, 599, error=value, request_time=self.io_loop.time() - self.start_time, )) if hasattr(self, "stream"): + # TODO: this may cause a StreamClosedError to be raised + # by the connection's Future. Should we cancel the + # connection more gracefully? self.stream.close() return True else: # If our callback has already been called, we are probably # catching an exception that is not caused by us but rather # some child of our callback. Rather than drop it on the floor, - # pass it along. - return False + # pass it along, unless it's just the stream being closed. + return isinstance(value, StreamClosedError) def _on_close(self): if self.final_callback is not None: message = "Connection closed" if self.stream.error: - message = str(self.stream.error) + raise self.stream.error raise HTTPError(599, message) - def _handle_1xx(self, code): - self.stream.read_until_regex(b"\r?\n\r?\n", self._on_headers) - - def _on_headers(self, data): - data = native_str(data.decode("latin1")) - first_line, _, header_data = data.partition("\n") - match = re.match("HTTP/1.[01] ([0-9]+) ([^\r]*)", first_line) - assert match - code = int(match.group(1)) - self.headers = HTTPHeaders.parse(header_data) - if 100 <= code < 200: - self._handle_1xx(code) + def headers_received(self, first_line, headers): + if self.request.expect_100_continue and first_line.code == 100: + self._write_body(False) return - else: - self.code = code - self.reason = match.group(2) + self.headers = headers + self.code = first_line.code + self.reason = first_line.reason if "Content-Length" in self.headers: if "," in self.headers["Content-Length"]: @@ -404,17 +448,12 @@ class _HTTPConnection(object): content_length = None if self.request.header_callback is not None: - # re-attach the newline we split on earlier - self.request.header_callback(first_line + _) + # Reassemble the start line. + self.request.header_callback('%s %s %s\r\n' % first_line) for k, v in self.headers.get_all(): self.request.header_callback("%s: %s\r\n" % (k, v)) self.request.header_callback('\r\n') - if self.request.method == "HEAD" or self.code == 304: - # HEAD requests and 304 responses never have content, even - # though they may have content-length headers - self._on_body(b"") - return if 100 <= self.code < 200 or self.code == 204: # These response codes never have bodies # http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3 @@ -422,21 +461,9 @@ class _HTTPConnection(object): content_length not in (None, 0)): raise ValueError("Response with code %d should not have body" % self.code) - self._on_body(b"") - return - if (self.request.use_gzip and - self.headers.get("Content-Encoding") == "gzip"): - self._decompressor = GzipDecompressor() - if self.headers.get("Transfer-Encoding") == "chunked": - self.chunks = [] - self.stream.read_until(b"\r\n", self._on_chunk_length) - elif content_length is not None: - self.stream.read_bytes(content_length, self._on_body) - else: - self.stream.read_until_close(self._on_body) - - def _on_body(self, data): + def finish(self): + data = b''.join(self.chunks) self._remove_timeout() original_request = getattr(self.request, "original_request", self.request) @@ -472,19 +499,12 @@ class _HTTPConnection(object): self.client.fetch(new_request, final_callback) self._on_end_request() return - if self._decompressor: - data = (self._decompressor.decompress(data) + - self._decompressor.flush()) if self.request.streaming_callback: - if self.chunks is None: - # if chunks is not None, we already called streaming_callback - # in _on_chunk_data - self.request.streaming_callback(data) buffer = BytesIO() else: buffer = BytesIO(data) # TODO: don't require one big string? response = HTTPResponse(original_request, - self.code, reason=self.reason, + self.code, reason=getattr(self, 'reason', None), headers=self.headers, request_time=self.io_loop.time() - self.start_time, buffer=buffer, @@ -495,40 +515,11 @@ class _HTTPConnection(object): def _on_end_request(self): self.stream.close() - def _on_chunk_length(self, data): - # TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1 - length = int(data.strip(), 16) - if length == 0: - if self._decompressor is not None: - tail = self._decompressor.flush() - if tail: - # I believe the tail will always be empty (i.e. - # decompress will return all it can). The purpose - # of the flush call is to detect errors such - # as truncated input. But in case it ever returns - # anything, treat it as an extra chunk - if self.request.streaming_callback is not None: - self.request.streaming_callback(tail) - else: - self.chunks.append(tail) - # all the data has been decompressed, so we don't need to - # decompress again in _on_body - self._decompressor = None - self._on_body(b''.join(self.chunks)) - else: - self.stream.read_bytes(length + 2, # chunk ends with \r\n - self._on_chunk_data) - - def _on_chunk_data(self, data): - assert data[-2:] == b"\r\n" - chunk = data[:-2] - if self._decompressor: - chunk = self._decompressor.decompress(chunk) + def data_received(self, chunk): if self.request.streaming_callback is not None: self.request.streaming_callback(chunk) else: self.chunks.append(chunk) - self.stream.read_until(b"\r\n", self._on_chunk_length) if __name__ == "__main__": diff --git a/tornado/stack_context.py b/tornado/stack_context.py index b1e82b0e..2e845ab2 100644 --- a/tornado/stack_context.py +++ b/tornado/stack_context.py @@ -266,6 +266,18 @@ def wrap(fn): # TODO: Any other better way to store contexts and update them in wrapped function? cap_contexts = [_state.contexts] + if not cap_contexts[0][0] and not cap_contexts[0][1]: + # Fast path when there are no active contexts. + def null_wrapper(*args, **kwargs): + try: + current_state = _state.contexts + _state.contexts = cap_contexts[0] + return fn(*args, **kwargs) + finally: + _state.contexts = current_state + null_wrapper._wrapped = True + return null_wrapper + def wrapped(*args, **kwargs): ret = None try: diff --git a/tornado/tcpclient.py b/tornado/tcpclient.py new file mode 100644 index 00000000..d49eb5cd --- /dev/null +++ b/tornado/tcpclient.py @@ -0,0 +1,179 @@ +#!/usr/bin/env python +# +# Copyright 2014 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""A non-blocking TCP connection factory. +""" +from __future__ import absolute_import, division, print_function, with_statement + +import functools +import socket + +from tornado.concurrent import Future +from tornado.ioloop import IOLoop +from tornado.iostream import IOStream +from tornado import gen +from tornado.netutil import Resolver + +_INITIAL_CONNECT_TIMEOUT = 0.3 + + +class _Connector(object): + """A stateless implementation of the "Happy Eyeballs" algorithm. + + "Happy Eyeballs" is documented in RFC6555 as the recommended practice + for when both IPv4 and IPv6 addresses are available. + + In this implementation, we partition the addresses by family, and + make the first connection attempt to whichever address was + returned first by ``getaddrinfo``. If that connection fails or + times out, we begin a connection in parallel to the first address + of the other family. If there are additional failures we retry + with other addresses, keeping one connection attempt per family + in flight at a time. + + http://tools.ietf.org/html/rfc6555 + + """ + def __init__(self, addrinfo, io_loop, connect): + self.io_loop = io_loop + self.connect = connect + + self.future = Future() + self.timeout = None + self.last_error = None + self.remaining = len(addrinfo) + self.primary_addrs, self.secondary_addrs = self.split(addrinfo) + + @staticmethod + def split(addrinfo): + """Partition the ``addrinfo`` list by address family. + + Returns two lists. The first list contains the first entry from + ``addrinfo`` and all others with the same family, and the + second list contains all other addresses (normally one list will + be AF_INET and the other AF_INET6, although non-standard resolvers + may return additional families). + """ + primary = [] + secondary = [] + primary_af = addrinfo[0][0] + for af, addr in addrinfo: + if af == primary_af: + primary.append((af, addr)) + else: + secondary.append((af, addr)) + return primary, secondary + + def start(self, timeout=_INITIAL_CONNECT_TIMEOUT): + self.try_connect(iter(self.primary_addrs)) + self.set_timout(timeout) + return self.future + + def try_connect(self, addrs): + try: + af, addr = next(addrs) + except StopIteration: + # We've reached the end of our queue, but the other queue + # might still be working. Send a final error on the future + # only when both queues are finished. + if self.remaining == 0 and not self.future.done(): + self.future.set_exception(self.last_error or + IOError("connection failed")) + return + future = self.connect(af, addr) + future.add_done_callback(functools.partial(self.on_connect_done, + addrs, af, addr)) + + def on_connect_done(self, addrs, af, addr, future): + self.remaining -= 1 + try: + stream = future.result() + except Exception as e: + if self.future.done(): + return + # Error: try again (but remember what happened so we have an + # error to raise in the end) + self.last_error = e + self.try_connect(addrs) + if self.timeout is not None: + # If the first attempt failed, don't wait for the + # timeout to try an address from the secondary queue. + self.on_timeout() + return + self.clear_timeout() + if self.future.done(): + # This is a late arrival; just drop it. + stream.close() + else: + self.future.set_result((af, addr, stream)) + + def set_timout(self, timeout): + self.timeout = self.io_loop.add_timeout(self.io_loop.time() + timeout, + self.on_timeout) + + def on_timeout(self): + self.timeout = None + self.try_connect(iter(self.secondary_addrs)) + + def clear_timeout(self): + if self.timeout is not None: + self.io_loop.remove_timeout(self.timeout) + + +class TCPClient(object): + """A non-blocking TCP connection factory. + """ + def __init__(self, resolver=None, io_loop=None): + self.io_loop = io_loop or IOLoop.current() + if resolver is not None: + self.resolver = resolver + self._own_resolver = False + else: + self.resolver = Resolver(io_loop=io_loop) + self._own_resolver = True + + def close(self): + if self._own_resolver: + self.resolver.close() + + @gen.coroutine + def connect(self, host, port, af=socket.AF_UNSPEC, ssl_options=None, + max_buffer_size=None): + """Connect to the given host and port. + + Asynchronously returns an `.IOStream` (or `.SSLIOStream` if + ``ssl_options`` is not None). + """ + addrinfo = yield self.resolver.resolve(host, port, af) + connector = _Connector( + addrinfo, self.io_loop, + functools.partial(self._create_stream, max_buffer_size)) + af, addr, stream = yield connector.start() + # TODO: For better performance we could cache the (af, addr) + # information here and re-use it on sbusequent connections to + # the same host. (http://tools.ietf.org/html/rfc6555#section-4.2) + if ssl_options is not None: + stream = yield stream.start_tls(False, ssl_options=ssl_options, + server_hostname=host) + raise gen.Return(stream) + + def _create_stream(self, max_buffer_size, af, addr): + # Always connect in plaintext; we'll convert to ssl if necessary + # after one connection has completed. + stream = IOStream(socket.socket(af), + io_loop=self.io_loop, + max_buffer_size=max_buffer_size) + return stream.connect(addr) diff --git a/tornado/tcpserver.py b/tornado/tcpserver.py index 9370cbac..427acec5 100644 --- a/tornado/tcpserver.py +++ b/tornado/tcpserver.py @@ -20,13 +20,19 @@ from __future__ import absolute_import, division, print_function, with_statement import errno import os import socket -import ssl from tornado.log import app_log from tornado.ioloop import IOLoop from tornado.iostream import IOStream, SSLIOStream from tornado.netutil import bind_sockets, add_accept_handler, ssl_wrap_socket from tornado import process +from tornado.util import errno_from_exception + +try: + import ssl +except ImportError: + # ssl is not available on Google App Engine. + ssl = None class TCPServer(object): @@ -81,13 +87,15 @@ class TCPServer(object): .. versionadded:: 3.1 The ``max_buffer_size`` argument. """ - def __init__(self, io_loop=None, ssl_options=None, max_buffer_size=None): + def __init__(self, io_loop=None, ssl_options=None, max_buffer_size=None, + read_chunk_size=None): self.io_loop = io_loop self.ssl_options = ssl_options self._sockets = {} # fd -> socket object self._pending_sockets = [] self._started = False self.max_buffer_size = max_buffer_size + self.read_chunk_size = None # Verify the SSL options. Otherwise we don't get errors until clients # connect. This doesn't verify that the keys are legitimate, but @@ -231,15 +239,19 @@ class TCPServer(object): # SSLIOStream._do_ssl_handshake). # To test this behavior, try nmap with the -sT flag. # https://github.com/tornadoweb/tornado/pull/750 - if err.args[0] in (errno.ECONNABORTED, errno.EINVAL): + if errno_from_exception(err) in (errno.ECONNABORTED, errno.EINVAL): return connection.close() else: raise try: if self.ssl_options is not None: - stream = SSLIOStream(connection, io_loop=self.io_loop, max_buffer_size=self.max_buffer_size) + stream = SSLIOStream(connection, io_loop=self.io_loop, + max_buffer_size=self.max_buffer_size, + read_chunk_size=self.read_chunk_size) else: - stream = IOStream(connection, io_loop=self.io_loop, max_buffer_size=self.max_buffer_size) + stream = IOStream(connection, io_loop=self.io_loop, + max_buffer_size=self.max_buffer_size, + read_chunk_size=self.read_chunk_size) self.handle_stream(stream, address) except Exception: app_log.error("Error in connection callback", exc_info=True) diff --git a/tornado/template.py b/tornado/template.py index db5a528d..4dcec5d5 100644 --- a/tornado/template.py +++ b/tornado/template.py @@ -180,7 +180,7 @@ with ``{# ... #}``. ``{% set *x* = *y* %}`` Sets a local variable. -``{% try %}...{% except %}...{% finally %}...{% else %}...{% end %}`` +``{% try %}...{% except %}...{% else %}...{% finally %}...{% end %}`` Same as the python ``try`` statement. ``{% while *condition* %}... {% end %}`` @@ -367,10 +367,9 @@ class Loader(BaseLoader): def _create_template(self, name): path = os.path.join(self.root, name) - f = open(path, "rb") - template = Template(f.read(), name=name, loader=self) - f.close() - return template + with open(path, "rb") as f: + template = Template(f.read(), name=name, loader=self) + return template class DictLoader(BaseLoader): @@ -785,7 +784,7 @@ def _parse(reader, template, in_block=None, in_loop=None): if allowed_parents is not None: if not in_block: raise ParseError("%s outside %s block" % - (operator, allowed_parents)) + (operator, allowed_parents)) if in_block not in allowed_parents: raise ParseError("%s block cannot be attached to %s block" % (operator, in_block)) body.chunks.append(_IntermediateControlBlock(contents, line)) diff --git a/tornado/test/__main__.py b/tornado/test/__main__.py new file mode 100644 index 00000000..5953443b --- /dev/null +++ b/tornado/test/__main__.py @@ -0,0 +1,14 @@ +"""Shim to allow python -m tornado.test. + +This only works in python 2.7+. +""" +from __future__ import absolute_import, division, print_function, with_statement + +from tornado.test.runtests import all, main + +# tornado.testing.main autodiscovery relies on 'all' being present in +# the main module, so import it here even though it is not used directly. +# The following line prevents a pyflakes warning. +all = all + +main() diff --git a/tornado/test/auth_test.py b/tornado/test/auth_test.py index 1d6cb839..254e1ae1 100644 --- a/tornado/test/auth_test.py +++ b/tornado/test/auth_test.py @@ -67,11 +67,29 @@ class OAuth1ClientLoginHandler(RequestHandler, OAuthMixin): self.finish(user) def _oauth_get_user(self, access_token, callback): + if self.get_argument('fail_in_get_user', None): + raise Exception("failing in get_user") if access_token != dict(key='uiop', secret='5678'): raise Exception("incorrect access token %r" % access_token) callback(dict(email='foo@example.com')) +class OAuth1ClientLoginCoroutineHandler(OAuth1ClientLoginHandler): + """Replaces OAuth1ClientLoginCoroutineHandler's get() with a coroutine.""" + @gen.coroutine + def get(self): + if self.get_argument('oauth_token', None): + # Ensure that any exceptions are set on the returned Future, + # not simply thrown into the surrounding StackContext. + try: + yield self.get_authenticated_user() + except Exception as e: + self.set_status(503) + self.write("got exception: %s" % e) + else: + yield self.authorize_redirect() + + class OAuth1ClientRequestParametersHandler(RequestHandler, OAuthMixin): def initialize(self, version): self._OAUTH_VERSION = version @@ -255,6 +273,9 @@ class AuthTest(AsyncHTTPTestCase): dict(version='1.0')), ('/oauth10a/client/login', OAuth1ClientLoginHandler, dict(test=self, version='1.0a')), + ('/oauth10a/client/login_coroutine', + OAuth1ClientLoginCoroutineHandler, + dict(test=self, version='1.0a')), ('/oauth10a/client/request_params', OAuth1ClientRequestParametersHandler, dict(version='1.0a')), @@ -348,6 +369,12 @@ class AuthTest(AsyncHTTPTestCase): self.assertTrue('oauth_nonce' in parsed) self.assertTrue('oauth_signature' in parsed) + def test_oauth10a_get_user_coroutine_exception(self): + response = self.fetch( + '/oauth10a/client/login_coroutine?oauth_token=zxcv&fail_in_get_user=true', + headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='}) + self.assertEqual(response.code, 503) + def test_oauth2_redirect(self): response = self.fetch('/oauth2/client/login', follow_redirects=False) self.assertEqual(response.code, 302) diff --git a/tornado/test/concurrent_test.py b/tornado/test/concurrent_test.py index 849337ed..5e93ad6a 100644 --- a/tornado/test/concurrent_test.py +++ b/tornado/test/concurrent_test.py @@ -28,7 +28,6 @@ from tornado.iostream import IOStream from tornado import stack_context from tornado.tcpserver import TCPServer from tornado.testing import AsyncTestCase, LogTrapTestCase, bind_unused_port, gen_test -from tornado.test.util import unittest try: @@ -113,13 +112,6 @@ class ReturnFutureTest(AsyncTestCase): self.assertIs(future, future2) self.assertEqual(future.result(), 42) - @unittest.skipIf(futures is None, "futures module not present") - def test_timeout_future(self): - with self.assertRaises(futures.TimeoutError): - future = self.async_future() - # Do not call self.wait() - future.result(timeout=.1) - @gen_test def test_async_future_gen(self): result = yield self.async_future() diff --git a/tornado/test/curl_httpclient_test.py b/tornado/test/curl_httpclient_test.py index fb696564..3873cf1e 100644 --- a/tornado/test/curl_httpclient_test.py +++ b/tornado/test/curl_httpclient_test.py @@ -68,6 +68,16 @@ class DigestAuthHandler(RequestHandler): (realm, nonce, opaque)) +class CustomReasonHandler(RequestHandler): + def get(self): + self.set_status(200, "Custom reason") + + +class CustomFailReasonHandler(RequestHandler): + def get(self): + self.set_status(400, "Custom reason") + + @unittest.skipIf(pycurl is None, "pycurl module not present") class CurlHTTPClientTestCase(AsyncHTTPTestCase): def setUp(self): @@ -78,6 +88,8 @@ class CurlHTTPClientTestCase(AsyncHTTPTestCase): def get_app(self): return Application([ ('/digest', DigestAuthHandler), + ('/custom_reason', CustomReasonHandler), + ('/custom_fail_reason', CustomFailReasonHandler), ]) def test_prepare_curl_callback_stack_context(self): @@ -100,3 +112,11 @@ class CurlHTTPClientTestCase(AsyncHTTPTestCase): response = self.fetch('/digest', auth_mode='digest', auth_username='foo', auth_password='bar') self.assertEqual(response.body, b'ok') + + def test_custom_reason(self): + response = self.fetch('/custom_reason') + self.assertEqual(response.reason, "Custom reason") + + def test_fail_custom_reason(self): + response = self.fetch('/custom_fail_reason') + self.assertEqual(str(response.error), "HTTP 400: Custom reason") diff --git a/tornado/test/gen_test.py b/tornado/test/gen_test.py index 5a463f81..a15cdf73 100644 --- a/tornado/test/gen_test.py +++ b/tornado/test/gen_test.py @@ -1,6 +1,7 @@ from __future__ import absolute_import, division, print_function, with_statement import contextlib +import datetime import functools import sys import textwrap @@ -8,7 +9,7 @@ import time import platform import weakref -from tornado.concurrent import return_future +from tornado.concurrent import return_future, Future from tornado.escape import url_escape from tornado.httpclient import AsyncHTTPClient from tornado.ioloop import IOLoop @@ -20,6 +21,10 @@ from tornado.web import Application, RequestHandler, asynchronous, HTTPError from tornado import gen +try: + from concurrent import futures +except ImportError: + futures = None skipBefore33 = unittest.skipIf(sys.version_info < (3, 3), 'PEP 380 not available') skipNotCPython = unittest.skipIf(platform.python_implementation() != 'CPython', @@ -291,26 +296,53 @@ class GenEngineTest(AsyncTestCase): self.stop() self.run_gen(f) - def test_multi_delayed(self): + # The following tests explicitly run with both gen.Multi + # and gen.multi_future (Task returns a Future, so it can be used + # with either). + def test_multi_yieldpoint_delayed(self): @gen.engine def f(): # callbacks run at different times - responses = yield [ + responses = yield gen.Multi([ gen.Task(self.delay_callback, 3, arg="v1"), gen.Task(self.delay_callback, 1, arg="v2"), - ] + ]) self.assertEqual(responses, ["v1", "v2"]) self.stop() self.run_gen(f) - def test_multi_dict_delayed(self): + def test_multi_yieldpoint_dict_delayed(self): @gen.engine def f(): # callbacks run at different times - responses = yield dict( + responses = yield gen.Multi(dict( foo=gen.Task(self.delay_callback, 3, arg="v1"), bar=gen.Task(self.delay_callback, 1, arg="v2"), - ) + )) + self.assertEqual(responses, dict(foo="v1", bar="v2")) + self.stop() + self.run_gen(f) + + def test_multi_future_delayed(self): + @gen.engine + def f(): + # callbacks run at different times + responses = yield gen.multi_future([ + gen.Task(self.delay_callback, 3, arg="v1"), + gen.Task(self.delay_callback, 1, arg="v2"), + ]) + self.assertEqual(responses, ["v1", "v2"]) + self.stop() + self.run_gen(f) + + def test_multi_future_dict_delayed(self): + @gen.engine + def f(): + # callbacks run at different times + responses = yield gen.multi_future(dict( + foo=gen.Task(self.delay_callback, 3, arg="v1"), + bar=gen.Task(self.delay_callback, 1, arg="v2"), + )) self.assertEqual(responses, dict(foo="v1", bar="v2")) self.stop() self.run_gen(f) @@ -334,6 +366,15 @@ class GenEngineTest(AsyncTestCase): y = yield {} self.assertTrue(isinstance(y, dict)) + @gen_test + def test_multi_mixed_types(self): + # A YieldPoint (Wait) and Future (Task) can be combined + # (and use the YieldPoint codepath) + (yield gen.Callback("k1"))("v1") + responses = yield [gen.Wait("k1"), + gen.Task(self.delay_callback, 3, arg="v2")] + self.assertEqual(responses, ["v1", "v2"]) + @gen_test def test_future(self): result = yield self.async_future(1) @@ -733,8 +774,14 @@ class GenCoroutineTest(AsyncTestCase): def test_replace_context_exception(self): # Test exception handling: exceptions thrown into the stack context # can be caught and replaced. + # Note that this test and the following are for behavior that is + # not really supported any more: coroutines no longer create a + # stack context automatically; but one is created after the first + # YieldPoint (i.e. not a Future). @gen.coroutine def f2(): + (yield gen.Callback(1))() + yield gen.Wait(1) self.io_loop.add_callback(lambda: 1 / 0) try: yield gen.Task(self.io_loop.add_timeout, @@ -753,6 +800,8 @@ class GenCoroutineTest(AsyncTestCase): # can be caught and ignored. @gen.coroutine def f2(): + (yield gen.Callback(1))() + yield gen.Wait(1) self.io_loop.add_callback(lambda: 1 / 0) try: yield gen.Task(self.io_loop.add_timeout, @@ -764,6 +813,31 @@ class GenCoroutineTest(AsyncTestCase): self.assertEqual(result, 42) self.finished = True + @gen_test + def test_moment(self): + calls = [] + @gen.coroutine + def f(name, yieldable): + for i in range(5): + calls.append(name) + yield yieldable + # First, confirm the behavior without moment: each coroutine + # monopolizes the event loop until it finishes. + immediate = Future() + immediate.set_result(None) + yield [f('a', immediate), f('b', immediate)] + self.assertEqual(''.join(calls), 'aaaaabbbbb') + + # With moment, they take turns. + calls = [] + yield [f('a', gen.moment), f('b', gen.moment)] + self.assertEqual(''.join(calls), 'ababababab') + self.finished = True + + calls = [] + yield [f('a', gen.moment), f('b', immediate)] + self.assertEqual(''.join(calls), 'abbbbbaaaa') + class GenSequenceHandler(RequestHandler): @asynchronous @@ -943,5 +1017,55 @@ class GenWebTest(AsyncHTTPTestCase): response = self.fetch('/async_prepare_error') self.assertEqual(response.code, 403) + +class WithTimeoutTest(AsyncTestCase): + @gen_test + def test_timeout(self): + with self.assertRaises(gen.TimeoutError): + yield gen.with_timeout(datetime.timedelta(seconds=0.1), + Future()) + + @gen_test + def test_completes_before_timeout(self): + future = Future() + self.io_loop.add_timeout(datetime.timedelta(seconds=0.1), + lambda: future.set_result('asdf')) + result = yield gen.with_timeout(datetime.timedelta(seconds=3600), + future) + self.assertEqual(result, 'asdf') + + @gen_test + def test_fails_before_timeout(self): + future = Future() + self.io_loop.add_timeout( + datetime.timedelta(seconds=0.1), + lambda: future.set_exception(ZeroDivisionError)) + with self.assertRaises(ZeroDivisionError): + yield gen.with_timeout(datetime.timedelta(seconds=3600), future) + + @gen_test + def test_already_resolved(self): + future = Future() + future.set_result('asdf') + result = yield gen.with_timeout(datetime.timedelta(seconds=3600), + future) + self.assertEqual(result, 'asdf') + + @unittest.skipIf(futures is None, 'futures module not present') + @gen_test + def test_timeout_concurrent_future(self): + with futures.ThreadPoolExecutor(1) as executor: + with self.assertRaises(gen.TimeoutError): + yield gen.with_timeout(self.io_loop.time(), + executor.submit(time.sleep, 0.1)) + + @unittest.skipIf(futures is None, 'futures module not present') + @gen_test + def test_completed_concurrent_future(self): + with futures.ThreadPoolExecutor(1) as executor: + yield gen.with_timeout(datetime.timedelta(seconds=3600), + executor.submit(lambda: None)) + + if __name__ == '__main__': unittest.main() diff --git a/tornado/test/gettext_translations/extract_me.py b/tornado/test/gettext_translations/extract_me.py new file mode 100644 index 00000000..75406ecc --- /dev/null +++ b/tornado/test/gettext_translations/extract_me.py @@ -0,0 +1,11 @@ +# Dummy source file to allow creation of the initial .po file in the +# same way as a real project. I'm not entirely sure about the real +# workflow here, but this seems to work. +# +# 1) xgettext --language=Python --keyword=_:1,2 -d tornado_test extract_me.py -o tornado_test.po +# 2) Edit tornado_test.po, setting CHARSET and setting msgstr +# 3) msgfmt tornado_test.po -o tornado_test.mo +# 4) Put the file in the proper location: $LANG/LC_MESSAGES + +from __future__ import absolute_import, division, print_function, with_statement +_("school") diff --git a/tornado/test/httpclient_test.py b/tornado/test/httpclient_test.py index 569ea872..78daa74d 100644 --- a/tornado/test/httpclient_test.py +++ b/tornado/test/httpclient_test.py @@ -8,7 +8,6 @@ from contextlib import closing import functools import sys import threading -import time from tornado.escape import utf8 from tornado.httpclient import HTTPRequest, HTTPResponse, _RequestProxy, HTTPError, HTTPClient @@ -19,7 +18,7 @@ from tornado.log import gen_log from tornado import netutil from tornado.stack_context import ExceptionStackContext, NullContext from tornado.testing import AsyncHTTPTestCase, bind_unused_port, gen_test, ExpectLog -from tornado.test.util import unittest +from tornado.test.util import unittest, skipOnTravis from tornado.util import u, bytes_type from tornado.web import Application, RequestHandler, url @@ -111,6 +110,7 @@ class HTTPClientCommonTestCase(AsyncHTTPTestCase): url("/all_methods", AllMethodsHandler), ], gzip=True) + @skipOnTravis def test_hello_world(self): response = self.fetch("/hello") self.assertEqual(response.code, 200) @@ -356,11 +356,10 @@ Transfer-Encoding: chunked @gen_test def test_future_http_error(self): - try: + with self.assertRaises(HTTPError) as context: yield self.http_client.fetch(self.get_url('/notfound')) - except HTTPError as e: - self.assertEqual(e.code, 404) - self.assertEqual(e.response.code, 404) + self.assertEqual(context.exception.code, 404) + self.assertEqual(context.exception.response.code, 404) @gen_test def test_reuse_request_from_response(self): diff --git a/tornado/test/httpserver_test.py b/tornado/test/httpserver_test.py index 5ca29935..f5e5679d 100644 --- a/tornado/test/httpserver_test.py +++ b/tornado/test/httpserver_test.py @@ -2,20 +2,23 @@ from __future__ import absolute_import, division, print_function, with_statement -from tornado import httpclient, simple_httpclient, netutil -from tornado.escape import json_decode, utf8, _unicode, recursive_unicode, native_str +from tornado import netutil +from tornado.escape import json_decode, json_encode, utf8, _unicode, recursive_unicode, native_str +from tornado import gen +from tornado.http1connection import HTTP1Connection from tornado.httpserver import HTTPServer -from tornado.httputil import HTTPHeaders +from tornado.httputil import HTTPHeaders, HTTPMessageDelegate, HTTPServerConnectionDelegate, ResponseStartLine from tornado.iostream import IOStream -from tornado.log import gen_log -from tornado.netutil import ssl_options_to_context, Resolver +from tornado.log import gen_log, app_log +from tornado.netutil import ssl_options_to_context from tornado.simple_httpclient import SimpleAsyncHTTPClient -from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog -from tornado.test.util import unittest +from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog, gen_test +from tornado.test.util import unittest, skipOnTravis from tornado.util import u, bytes_type -from tornado.web import Application, RequestHandler, asynchronous +from tornado.web import Application, RequestHandler, asynchronous, stream_request_body from contextlib import closing import datetime +import gzip import os import shutil import socket @@ -23,6 +26,28 @@ import ssl import sys import tempfile +try: + from io import BytesIO # python 3 +except ImportError: + from cStringIO import StringIO as BytesIO # python 2 + + +def read_stream_body(stream, callback): + """Reads an HTTP response from `stream` and runs callback with its + headers and body.""" + chunks = [] + class Delegate(HTTPMessageDelegate): + def headers_received(self, start_line, headers): + self.headers = headers + + def data_received(self, chunk): + chunks.append(chunk) + + def finish(self): + callback((self.headers, b''.join(chunks))) + conn = HTTP1Connection(stream, True) + conn.read_response(Delegate()) + class HandlerBaseTestCase(AsyncHTTPTestCase): def get_app(self): @@ -86,11 +111,13 @@ class SSLTestMixin(object): # connection, rather than waiting for a timeout or otherwise # misbehaving. with ExpectLog(gen_log, '(SSL Error|uncaught exception)'): - self.http_client.fetch(self.get_url("/").replace('https:', 'http:'), - self.stop, - request_timeout=3600, - connect_timeout=3600) - response = self.wait() + with ExpectLog(gen_log, 'Uncaught exception', required=False): + self.http_client.fetch( + self.get_url("/").replace('https:', 'http:'), + self.stop, + request_timeout=3600, + connect_timeout=3600) + response = self.wait() self.assertEqual(response.code, 599) # Python's SSL implementation differs significantly between versions. @@ -163,18 +190,7 @@ class MultipartTestHandler(RequestHandler): }) -class RawRequestHTTPConnection(simple_httpclient._HTTPConnection): - def set_request(self, request): - self.__next_request = request - - def _on_connect(self): - self.stream.write(self.__next_request) - self.__next_request = None - self.stream.read_until(b"\r\n\r\n", self._on_headers) - # This test is also called from wsgi_test - - class HTTPConnectionTest(AsyncHTTPTestCase): def get_handlers(self): return [("/multipart", MultipartTestHandler), @@ -184,23 +200,16 @@ class HTTPConnectionTest(AsyncHTTPTestCase): return Application(self.get_handlers()) def raw_fetch(self, headers, body): - with closing(Resolver(io_loop=self.io_loop)) as resolver: - with closing(SimpleAsyncHTTPClient(self.io_loop, - resolver=resolver)) as client: - conn = RawRequestHTTPConnection( - self.io_loop, client, - httpclient._RequestProxy( - httpclient.HTTPRequest(self.get_url("/")), - dict(httpclient.HTTPRequest._DEFAULTS)), - None, self.stop, - 1024 * 1024, resolver) - conn.set_request( - b"\r\n".join(headers + - [utf8("Content-Length: %d\r\n" % len(body))]) + - b"\r\n" + body) - response = self.wait() - response.rethrow() - return response + with closing(IOStream(socket.socket())) as stream: + stream.connect(('127.0.0.1', self.get_http_port()), self.stop) + self.wait() + stream.write( + b"\r\n".join(headers + + [utf8("Content-Length: %d\r\n" % len(body))]) + + b"\r\n" + body) + read_stream_body(stream, self.stop) + headers, body = self.wait() + return body def test_multipart_form(self): # Encodings here are tricky: Headers are latin1, bodies can be @@ -221,7 +230,7 @@ class HTTPConnectionTest(AsyncHTTPTestCase): b"--1234567890--", b"", ])) - data = json_decode(response.body) + data = json_decode(response) self.assertEqual(u("\u00e9"), data["header"]) self.assertEqual(u("\u00e1"), data["argument"]) self.assertEqual(u("\u00f3"), data["filename"]) @@ -397,6 +406,25 @@ class HTTPServerRawTest(AsyncHTTPTestCase): self.stop) self.wait() + def test_chunked_request_body(self): + # Chunked requests are not widely supported and we don't have a way + # to generate them in AsyncHTTPClient, but HTTPServer will read them. + self.stream.write(b"""\ +POST /echo HTTP/1.1 +Transfer-Encoding: chunked +Content-Type: application/x-www-form-urlencoded + +4 +foo= +3 +bar +0 + +""".replace(b"\n", b"\r\n")) + read_stream_body(self.stream, self.stop) + headers, response = self.wait() + self.assertEqual(json_decode(response), {u('foo'): [u('bar')]}) + class XHeaderTest(HandlerBaseTestCase): class Handler(RequestHandler): @@ -541,7 +569,7 @@ class UnixSocketTest(AsyncTestCase): def test_unix_socket_bad_request(self): # Unix sockets don't have remote addresses so they just return an # empty string. - with ExpectLog(gen_log, "Malformed HTTP request from"): + with ExpectLog(gen_log, "Malformed HTTP message from"): self.stream.write(b"garbage\r\n\r\n") self.stream.read_until_close(self.stop) response = self.wait() @@ -610,8 +638,8 @@ class KeepAliveTest(AsyncHTTPTestCase): return headers def read_response(self): - headers = self.read_headers() - self.stream.read_bytes(int(headers['Content-Length']), self.stop) + self.headers = self.read_headers() + self.stream.read_bytes(int(self.headers['Content-Length']), self.stop) body = self.wait() self.assertEqual(b'Hello world', body) @@ -645,6 +673,7 @@ class KeepAliveTest(AsyncHTTPTestCase): self.stream.read_until_close(callback=self.stop) data = self.wait() self.assertTrue(not data) + self.assertTrue('Connection' not in self.headers) self.close() def test_http10_keepalive(self): @@ -652,8 +681,10 @@ class KeepAliveTest(AsyncHTTPTestCase): self.connect() self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n') self.read_response() + self.assertEqual(self.headers['Connection'], 'Keep-Alive') self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n') self.read_response() + self.assertEqual(self.headers['Connection'], 'Keep-Alive') self.close() def test_pipelined_requests(self): @@ -683,3 +714,322 @@ class KeepAliveTest(AsyncHTTPTestCase): self.stream.write(b'GET /finish_on_close HTTP/1.1\r\n\r\n') self.read_headers() self.close() + + +class GzipBaseTest(object): + def get_app(self): + return Application([('/', EchoHandler)]) + + def post_gzip(self, body): + bytesio = BytesIO() + gzip_file = gzip.GzipFile(mode='w', fileobj=bytesio) + gzip_file.write(utf8(body)) + gzip_file.close() + compressed_body = bytesio.getvalue() + return self.fetch('/', method='POST', body=compressed_body, + headers={'Content-Encoding': 'gzip'}) + + def test_uncompressed(self): + response = self.fetch('/', method='POST', body='foo=bar') + self.assertEquals(json_decode(response.body), {u('foo'): [u('bar')]}) + + +class GzipTest(GzipBaseTest, AsyncHTTPTestCase): + def get_httpserver_options(self): + return dict(gzip=True) + + def test_gzip(self): + response = self.post_gzip('foo=bar') + self.assertEquals(json_decode(response.body), {u('foo'): [u('bar')]}) + + +class GzipUnsupportedTest(GzipBaseTest, AsyncHTTPTestCase): + def test_gzip_unsupported(self): + # Gzip support is opt-in; without it the server fails to parse + # the body (but parsing form bodies is currently just a log message, + # not a fatal error). + with ExpectLog(gen_log, "Unsupported Content-Encoding"): + response = self.post_gzip('foo=bar') + self.assertEquals(json_decode(response.body), {}) + + +class StreamingChunkSizeTest(AsyncHTTPTestCase): + # 50 characters long, and repetitive so it can be compressed. + BODY = b'01234567890123456789012345678901234567890123456789' + CHUNK_SIZE = 16 + + def get_http_client(self): + # body_producer doesn't work on curl_httpclient, so override the + # configured AsyncHTTPClient implementation. + return SimpleAsyncHTTPClient(io_loop=self.io_loop) + + def get_httpserver_options(self): + return dict(chunk_size=self.CHUNK_SIZE, gzip=True) + + class MessageDelegate(HTTPMessageDelegate): + def __init__(self, connection): + self.connection = connection + + def headers_received(self, start_line, headers): + self.chunk_lengths = [] + + def data_received(self, chunk): + self.chunk_lengths.append(len(chunk)) + + def finish(self): + response_body = utf8(json_encode(self.chunk_lengths)) + self.connection.write_headers( + ResponseStartLine('HTTP/1.1', 200, 'OK'), + HTTPHeaders({'Content-Length': str(len(response_body))})) + self.connection.write(response_body) + self.connection.finish() + + def get_app(self): + class App(HTTPServerConnectionDelegate): + def start_request(self, connection): + return StreamingChunkSizeTest.MessageDelegate(connection) + return App() + + def fetch_chunk_sizes(self, **kwargs): + response = self.fetch('/', method='POST', **kwargs) + response.rethrow() + chunks = json_decode(response.body) + self.assertEqual(len(self.BODY), sum(chunks)) + for chunk_size in chunks: + self.assertLessEqual(chunk_size, self.CHUNK_SIZE, + 'oversized chunk: ' + str(chunks)) + self.assertGreater(chunk_size, 0, + 'empty chunk: ' + str(chunks)) + return chunks + + def compress(self, body): + bytesio = BytesIO() + gzfile = gzip.GzipFile(mode='w', fileobj=bytesio) + gzfile.write(body) + gzfile.close() + compressed = bytesio.getvalue() + if len(compressed) >= len(body): + raise Exception("body did not shrink when compressed") + return compressed + + def test_regular_body(self): + chunks = self.fetch_chunk_sizes(body=self.BODY) + # Without compression we know exactly what to expect. + self.assertEqual([16, 16, 16, 2], chunks) + + def test_compressed_body(self): + self.fetch_chunk_sizes(body=self.compress(self.BODY), + headers={'Content-Encoding': 'gzip'}) + # Compression creates irregular boundaries so the assertions + # in fetch_chunk_sizes are as specific as we can get. + + def test_chunked_body(self): + def body_producer(write): + write(self.BODY[:20]) + write(self.BODY[20:]) + chunks = self.fetch_chunk_sizes(body_producer=body_producer) + # HTTP chunk boundaries translate to application-visible breaks + self.assertEqual([16, 4, 16, 14], chunks) + + def test_chunked_compressed(self): + compressed = self.compress(self.BODY) + self.assertGreater(len(compressed), 20) + def body_producer(write): + write(compressed[:20]) + write(compressed[20:]) + self.fetch_chunk_sizes(body_producer=body_producer, + headers={'Content-Encoding': 'gzip'}) + + +class MaxHeaderSizeTest(AsyncHTTPTestCase): + def get_app(self): + return Application([('/', HelloWorldRequestHandler)]) + + def get_httpserver_options(self): + return dict(max_header_size=1024) + + def test_small_headers(self): + response = self.fetch("/", headers={'X-Filler': 'a' * 100}) + response.rethrow() + self.assertEqual(response.body, b"Hello world") + + def test_large_headers(self): + with ExpectLog(gen_log, "Unsatisfiable read"): + response = self.fetch("/", headers={'X-Filler': 'a' * 1000}) + self.assertEqual(response.code, 599) + + +@skipOnTravis +class IdleTimeoutTest(AsyncHTTPTestCase): + def get_app(self): + return Application([('/', HelloWorldRequestHandler)]) + + def get_httpserver_options(self): + return dict(idle_connection_timeout=0.1) + + def setUp(self): + super(IdleTimeoutTest, self).setUp() + self.streams = [] + + def tearDown(self): + super(IdleTimeoutTest, self).tearDown() + for stream in self.streams: + stream.close() + + def connect(self): + stream = IOStream(socket.socket()) + stream.connect(('localhost', self.get_http_port()), self.stop) + self.wait() + self.streams.append(stream) + return stream + + def test_unused_connection(self): + stream = self.connect() + stream.set_close_callback(self.stop) + self.wait() + + def test_idle_after_use(self): + stream = self.connect() + stream.set_close_callback(lambda: self.stop("closed")) + + # Use the connection twice to make sure keep-alives are working + for i in range(2): + stream.write(b"GET / HTTP/1.1\r\n\r\n") + stream.read_until(b"\r\n\r\n", self.stop) + self.wait() + stream.read_bytes(11, self.stop) + data = self.wait() + self.assertEqual(data, b"Hello world") + + # Now let the timeout trigger and close the connection. + data = self.wait() + self.assertEqual(data, "closed") + + +class BodyLimitsTest(AsyncHTTPTestCase): + def get_app(self): + class BufferedHandler(RequestHandler): + def put(self): + self.write(str(len(self.request.body))) + + @stream_request_body + class StreamingHandler(RequestHandler): + def initialize(self): + self.bytes_read = 0 + + def prepare(self): + if 'expected_size' in self.request.arguments: + self.request.connection.set_max_body_size( + int(self.get_argument('expected_size'))) + if 'body_timeout' in self.request.arguments: + self.request.connection.set_body_timeout( + float(self.get_argument('body_timeout'))) + + def data_received(self, data): + self.bytes_read += len(data) + + def put(self): + self.write(str(self.bytes_read)) + + return Application([('/buffered', BufferedHandler), + ('/streaming', StreamingHandler)]) + + def get_httpserver_options(self): + return dict(body_timeout=3600, max_body_size=4096) + + def get_http_client(self): + # body_producer doesn't work on curl_httpclient, so override the + # configured AsyncHTTPClient implementation. + return SimpleAsyncHTTPClient(io_loop=self.io_loop) + + def test_small_body(self): + response = self.fetch('/buffered', method='PUT', body=b'a' * 4096) + self.assertEqual(response.body, b'4096') + response = self.fetch('/streaming', method='PUT', body=b'a' * 4096) + self.assertEqual(response.body, b'4096') + + def test_large_body_buffered(self): + with ExpectLog(gen_log, '.*Content-Length too long'): + response = self.fetch('/buffered', method='PUT', body=b'a' * 10240) + self.assertEqual(response.code, 599) + + def test_large_body_buffered_chunked(self): + with ExpectLog(gen_log, '.*chunked body too large'): + response = self.fetch('/buffered', method='PUT', + body_producer=lambda write: write(b'a' * 10240)) + self.assertEqual(response.code, 599) + + def test_large_body_streaming(self): + with ExpectLog(gen_log, '.*Content-Length too long'): + response = self.fetch('/streaming', method='PUT', body=b'a' * 10240) + self.assertEqual(response.code, 599) + + def test_large_body_streaming_chunked(self): + with ExpectLog(gen_log, '.*chunked body too large'): + response = self.fetch('/streaming', method='PUT', + body_producer=lambda write: write(b'a' * 10240)) + self.assertEqual(response.code, 599) + + def test_large_body_streaming_override(self): + response = self.fetch('/streaming?expected_size=10240', method='PUT', + body=b'a' * 10240) + self.assertEqual(response.body, b'10240') + + def test_large_body_streaming_chunked_override(self): + response = self.fetch('/streaming?expected_size=10240', method='PUT', + body_producer=lambda write: write(b'a' * 10240)) + self.assertEqual(response.body, b'10240') + + @gen_test + def test_timeout(self): + stream = IOStream(socket.socket()) + try: + yield stream.connect(('127.0.0.1', self.get_http_port())) + # Use a raw stream because AsyncHTTPClient won't let us read a + # response without finishing a body. + stream.write(b'PUT /streaming?body_timeout=0.1 HTTP/1.0\r\n' + b'Content-Length: 42\r\n\r\n') + with ExpectLog(gen_log, 'Timeout reading body'): + response = yield stream.read_until_close() + self.assertEqual(response, b'') + finally: + stream.close() + + @gen_test + def test_body_size_override_reset(self): + # The max_body_size override is reset between requests. + stream = IOStream(socket.socket()) + try: + yield stream.connect(('127.0.0.1', self.get_http_port())) + # Use a raw stream so we can make sure it's all on one connection. + stream.write(b'PUT /streaming?expected_size=10240 HTTP/1.1\r\n' + b'Content-Length: 10240\r\n\r\n') + stream.write(b'a' * 10240) + headers, response = yield gen.Task(read_stream_body, stream) + self.assertEqual(response, b'10240') + # Without the ?expected_size parameter, we get the old default value + stream.write(b'PUT /streaming HTTP/1.1\r\n' + b'Content-Length: 10240\r\n\r\n') + with ExpectLog(gen_log, '.*Content-Length too long'): + data = yield stream.read_until_close() + self.assertEqual(data, b'') + finally: + stream.close() + + +class LegacyInterfaceTest(AsyncHTTPTestCase): + def get_app(self): + # The old request_callback interface does not implement the + # delegate interface, and writes its response via request.write + # instead of request.connection.write_headers. + def handle_request(request): + message = b"Hello world" + request.write(utf8("HTTP/1.1 200 OK\r\n" + "Content-Length: %d\r\n\r\n" % len(message))) + request.write(message) + request.finish() + return handle_request + + def test_legacy_interface(self): + response = self.fetch('/') + self.assertEqual(response.body, b"Hello world") diff --git a/tornado/test/import_test.py b/tornado/test/import_test.py index ccd6ef35..de7cc0b9 100644 --- a/tornado/test/import_test.py +++ b/tornado/test/import_test.py @@ -13,6 +13,7 @@ class ImportTest(unittest.TestCase): # import tornado.curl_httpclient # depends on pycurl import tornado.escape import tornado.gen + import tornado.http1connection import tornado.httpclient import tornado.httpserver import tornado.httputil diff --git a/tornado/test/ioloop_test.py b/tornado/test/ioloop_test.py index fa863e61..e4f07338 100644 --- a/tornado/test/ioloop_test.py +++ b/tornado/test/ioloop_test.py @@ -5,16 +5,16 @@ from __future__ import absolute_import, division, print_function, with_statement import contextlib import datetime import functools -import logging import socket import sys import threading import time from tornado import gen -from tornado.ioloop import IOLoop, PollIOLoop, TimeoutError +from tornado.ioloop import IOLoop, TimeoutError +from tornado.log import app_log from tornado.stack_context import ExceptionStackContext, StackContext, wrap, NullContext -from tornado.testing import AsyncTestCase, bind_unused_port +from tornado.testing import AsyncTestCase, bind_unused_port, ExpectLog from tornado.test.util import unittest, skipIfNonUnix, skipOnTravis try: @@ -52,7 +52,8 @@ class TestIOLoop(AsyncTestCase): thread = threading.Thread(target=target) self.io_loop.add_callback(thread.start) self.wait() - self.assertAlmostEqual(time.time(), self.stop_time, places=2) + delta = time.time() - self.stop_time + self.assertLess(delta, 0.1) thread.join() def test_add_timeout_timedelta(self): @@ -172,6 +173,119 @@ class TestIOLoop(AsyncTestCase): self.io_loop.add_callback(lambda: self.io_loop.add_callback(self.stop)) self.wait() + def test_close_file_object(self): + """When a file object is used instead of a numeric file descriptor, + the object should be closed (by IOLoop.close(all_fds=True), + not just the fd. + """ + # Use a socket since they are supported by IOLoop on all platforms. + # Unfortunately, sockets don't support the .closed attribute for + # inspecting their close status, so we must use a wrapper. + class SocketWrapper(object): + def __init__(self, sockobj): + self.sockobj = sockobj + self.closed = False + + def fileno(self): + return self.sockobj.fileno() + + def close(self): + self.closed = True + self.sockobj.close() + sockobj, port = bind_unused_port() + socket_wrapper = SocketWrapper(sockobj) + io_loop = IOLoop() + io_loop.add_handler(socket_wrapper, lambda fd, events: None, + IOLoop.READ) + io_loop.close(all_fds=True) + self.assertTrue(socket_wrapper.closed) + + def test_handler_callback_file_object(self): + """The handler callback receives the same fd object it passed in.""" + server_sock, port = bind_unused_port() + fds = [] + def handle_connection(fd, events): + fds.append(fd) + conn, addr = server_sock.accept() + conn.close() + self.stop() + self.io_loop.add_handler(server_sock, handle_connection, IOLoop.READ) + with contextlib.closing(socket.socket()) as client_sock: + client_sock.connect(('127.0.0.1', port)) + self.wait() + self.io_loop.remove_handler(server_sock) + self.io_loop.add_handler(server_sock.fileno(), handle_connection, + IOLoop.READ) + with contextlib.closing(socket.socket()) as client_sock: + client_sock.connect(('127.0.0.1', port)) + self.wait() + self.assertIs(fds[0], server_sock) + self.assertEqual(fds[1], server_sock.fileno()) + self.io_loop.remove_handler(server_sock.fileno()) + server_sock.close() + + def test_mixed_fd_fileobj(self): + server_sock, port = bind_unused_port() + def f(fd, events): + pass + self.io_loop.add_handler(server_sock, f, IOLoop.READ) + with self.assertRaises(Exception): + # The exact error is unspecified - some implementations use + # IOError, others use ValueError. + self.io_loop.add_handler(server_sock.fileno(), f, IOLoop.READ) + self.io_loop.remove_handler(server_sock.fileno()) + server_sock.close() + + def test_reentrant(self): + """Calling start() twice should raise an error, not deadlock.""" + returned_from_start = [False] + got_exception = [False] + def callback(): + try: + self.io_loop.start() + returned_from_start[0] = True + except Exception: + got_exception[0] = True + self.stop() + self.io_loop.add_callback(callback) + self.wait() + self.assertTrue(got_exception[0]) + self.assertFalse(returned_from_start[0]) + + def test_exception_logging(self): + """Uncaught exceptions get logged by the IOLoop.""" + # Use a NullContext to keep the exception from being caught by + # AsyncTestCase. + with NullContext(): + self.io_loop.add_callback(lambda: 1/0) + self.io_loop.add_callback(self.stop) + with ExpectLog(app_log, "Exception in callback"): + self.wait() + + def test_exception_logging_future(self): + """The IOLoop examines exceptions from Futures and logs them.""" + with NullContext(): + @gen.coroutine + def callback(): + self.io_loop.add_callback(self.stop) + 1/0 + self.io_loop.add_callback(callback) + with ExpectLog(app_log, "Exception in callback"): + self.wait() + + def test_spawn_callback(self): + # An added callback runs in the test's stack_context, so will be + # re-arised in wait(). + self.io_loop.add_callback(lambda: 1/0) + with self.assertRaises(ZeroDivisionError): + self.wait() + # A spawned callback is run directly on the IOLoop, so it will be + # logged without stopping the test. + self.io_loop.spawn_callback(lambda: 1/0) + self.io_loop.add_callback(self.stop) + with ExpectLog(app_log, "Exception in callback"): + self.wait() + # Deliberately not a subclass of AsyncTestCase so the IOLoop isn't # automatically set as current. diff --git a/tornado/test/iostream_test.py b/tornado/test/iostream_test.py index 0675c4f7..e9d241a5 100644 --- a/tornado/test/iostream_test.py +++ b/tornado/test/iostream_test.py @@ -1,13 +1,16 @@ from __future__ import absolute_import, division, print_function, with_statement +from tornado.concurrent import Future +from tornado import gen from tornado import netutil -from tornado.ioloop import IOLoop -from tornado.iostream import IOStream, SSLIOStream, PipeIOStream +from tornado.iostream import IOStream, SSLIOStream, PipeIOStream, StreamClosedError +from tornado.httputil import HTTPHeaders from tornado.log import gen_log, app_log from tornado.netutil import ssl_wrap_socket from tornado.stack_context import NullContext -from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog +from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog, gen_test from tornado.test.util import unittest, skipIfNonUnix from tornado.web import RequestHandler, Application +import certifi import errno import logging import os @@ -17,6 +20,13 @@ import ssl import sys +def _server_ssl_options(): + return dict( + certfile=os.path.join(os.path.dirname(__file__), 'test.crt'), + keyfile=os.path.join(os.path.dirname(__file__), 'test.key'), + ) + + class HelloHandler(RequestHandler): def get(self): self.write("Hello") @@ -106,6 +116,48 @@ class TestIOStreamWebMixin(object): stream.close() + @gen_test + def test_future_interface(self): + """Basic test of IOStream's ability to return Futures.""" + stream = self._make_client_iostream() + connect_result = yield stream.connect( + ("localhost", self.get_http_port())) + self.assertIs(connect_result, stream) + yield stream.write(b"GET / HTTP/1.0\r\n\r\n") + first_line = yield stream.read_until(b"\r\n") + self.assertEqual(first_line, b"HTTP/1.0 200 OK\r\n") + # callback=None is equivalent to no callback. + header_data = yield stream.read_until(b"\r\n\r\n", callback=None) + headers = HTTPHeaders.parse(header_data.decode('latin1')) + content_length = int(headers['Content-Length']) + body = yield stream.read_bytes(content_length) + self.assertEqual(body, b'Hello') + stream.close() + + @gen_test + def test_future_close_while_reading(self): + stream = self._make_client_iostream() + yield stream.connect(("localhost", self.get_http_port())) + yield stream.write(b"GET / HTTP/1.0\r\n\r\n") + with self.assertRaises(StreamClosedError): + yield stream.read_bytes(1024 * 1024) + stream.close() + + @gen_test + def test_future_read_until_close(self): + # Ensure that the data comes through before the StreamClosedError. + stream = self._make_client_iostream() + yield stream.connect(("localhost", self.get_http_port())) + yield stream.write(b"GET / HTTP/1.0\r\nConnection: close\r\n\r\n") + yield stream.read_until(b"\r\n\r\n") + body = yield stream.read_until_close() + self.assertEqual(body, b"Hello") + + # Nothing else to read; the error comes immediately without waiting + # for yield. + with self.assertRaises(StreamClosedError): + stream.read_bytes(1) + class TestIOStreamMixin(object): def _make_server_iostream(self, connection, **kwargs): @@ -158,9 +210,6 @@ class TestIOStreamMixin(object): server, client = self.make_iostream_pair() server.write(b'', callback=self.stop) self.wait() - # As a side effect, the stream is now listening for connection - # close (if it wasn't already), but is not listening for writes - self.assertEqual(server._state, IOLoop.READ | IOLoop.ERROR) server.close() client.close() @@ -298,6 +347,25 @@ class TestIOStreamMixin(object): server.close() client.close() + def test_future_delayed_close_callback(self): + # Same as test_delayed_close_callback, but with the future interface. + server, client = self.make_iostream_pair() + # We can't call make_iostream_pair inside a gen_test function + # because the ioloop is not reentrant. + @gen_test + def f(self): + server.write(b"12") + chunks = [] + chunks.append((yield client.read_bytes(1))) + server.close() + chunks.append((yield client.read_bytes(1))) + self.assertEqual(chunks, [b"1", b"2"]) + try: + f(self) + finally: + server.close() + client.close() + def test_close_buffered_data(self): # Similar to the previous test, but with data stored in the OS's # socket buffers instead of the IOStream's read buffer. Out-of-band @@ -330,14 +398,18 @@ class TestIOStreamMixin(object): # Similar to test_delayed_close_callback, but read_until_close takes # a separate code path so test it separately. server, client = self.make_iostream_pair() - client.set_close_callback(self.stop) try: server.write(b"1234") server.close() - self.wait() + # Read one byte to make sure the client has received the data. + # It won't run the close callback as long as there is more buffered + # data that could satisfy a later read. + client.read_bytes(1, self.stop) + data = self.wait() + self.assertEqual(data, b"1") client.read_until_close(self.stop) data = self.wait() - self.assertEqual(data, b"1234") + self.assertEqual(data, b"234") finally: server.close() client.close() @@ -347,17 +419,18 @@ class TestIOStreamMixin(object): # All data should go through the streaming callback, # and the final read callback just gets an empty string. server, client = self.make_iostream_pair() - client.set_close_callback(self.stop) try: server.write(b"1234") server.close() - self.wait() + client.read_bytes(1, self.stop) + data = self.wait() + self.assertEqual(data, b"1") streaming_data = [] client.read_until_close(self.stop, streaming_callback=streaming_data.append) data = self.wait() self.assertEqual(b'', data) - self.assertEqual(b''.join(streaming_data), b"1234") + self.assertEqual(b''.join(streaming_data), b"234") finally: server.close() client.close() @@ -451,6 +524,203 @@ class TestIOStreamMixin(object): server.close() client.close() + def test_future_close_callback(self): + # Regression test for interaction between the Future read interfaces + # and IOStream._maybe_add_error_listener. + server, client = self.make_iostream_pair() + closed = [False] + def close_callback(): + closed[0] = True + self.stop() + server.set_close_callback(close_callback) + try: + client.write(b'a') + future = server.read_bytes(1) + self.io_loop.add_future(future, self.stop) + self.assertEqual(self.wait().result(), b'a') + self.assertFalse(closed[0]) + client.close() + self.wait() + self.assertTrue(closed[0]) + finally: + server.close() + client.close() + + def test_read_bytes_partial(self): + server, client = self.make_iostream_pair() + try: + # Ask for more than is available with partial=True + client.read_bytes(50, self.stop, partial=True) + server.write(b"hello") + data = self.wait() + self.assertEqual(data, b"hello") + + # Ask for less than what is available; num_bytes is still + # respected. + client.read_bytes(3, self.stop, partial=True) + server.write(b"world") + data = self.wait() + self.assertEqual(data, b"wor") + + # Partial reads won't return an empty string, but read_bytes(0) + # will. + client.read_bytes(0, self.stop, partial=True) + data = self.wait() + self.assertEqual(data, b'') + finally: + server.close() + client.close() + + def test_read_until_max_bytes(self): + server, client = self.make_iostream_pair() + client.set_close_callback(lambda: self.stop("closed")) + try: + # Extra room under the limit + client.read_until(b"def", self.stop, max_bytes=50) + server.write(b"abcdef") + data = self.wait() + self.assertEqual(data, b"abcdef") + + # Just enough space + client.read_until(b"def", self.stop, max_bytes=6) + server.write(b"abcdef") + data = self.wait() + self.assertEqual(data, b"abcdef") + + # Not enough space, but we don't know it until all we can do is + # log a warning and close the connection. + with ExpectLog(gen_log, "Unsatisfiable read"): + client.read_until(b"def", self.stop, max_bytes=5) + server.write(b"123456") + data = self.wait() + self.assertEqual(data, "closed") + finally: + server.close() + client.close() + + def test_read_until_max_bytes_inline(self): + server, client = self.make_iostream_pair() + client.set_close_callback(lambda: self.stop("closed")) + try: + # Similar to the error case in the previous test, but the + # server writes first so client reads are satisfied + # inline. For consistency with the out-of-line case, we + # do not raise the error synchronously. + server.write(b"123456") + with ExpectLog(gen_log, "Unsatisfiable read"): + client.read_until(b"def", self.stop, max_bytes=5) + data = self.wait() + self.assertEqual(data, "closed") + finally: + server.close() + client.close() + + def test_read_until_max_bytes_ignores_extra(self): + server, client = self.make_iostream_pair() + client.set_close_callback(lambda: self.stop("closed")) + try: + # Even though data that matches arrives the same packet that + # puts us over the limit, we fail the request because it was not + # found within the limit. + server.write(b"abcdef") + with ExpectLog(gen_log, "Unsatisfiable read"): + client.read_until(b"def", self.stop, max_bytes=5) + data = self.wait() + self.assertEqual(data, "closed") + finally: + server.close() + client.close() + + def test_read_until_regex_max_bytes(self): + server, client = self.make_iostream_pair() + client.set_close_callback(lambda: self.stop("closed")) + try: + # Extra room under the limit + client.read_until_regex(b"def", self.stop, max_bytes=50) + server.write(b"abcdef") + data = self.wait() + self.assertEqual(data, b"abcdef") + + # Just enough space + client.read_until_regex(b"def", self.stop, max_bytes=6) + server.write(b"abcdef") + data = self.wait() + self.assertEqual(data, b"abcdef") + + # Not enough space, but we don't know it until all we can do is + # log a warning and close the connection. + with ExpectLog(gen_log, "Unsatisfiable read"): + client.read_until_regex(b"def", self.stop, max_bytes=5) + server.write(b"123456") + data = self.wait() + self.assertEqual(data, "closed") + finally: + server.close() + client.close() + + def test_read_until_regex_max_bytes_inline(self): + server, client = self.make_iostream_pair() + client.set_close_callback(lambda: self.stop("closed")) + try: + # Similar to the error case in the previous test, but the + # server writes first so client reads are satisfied + # inline. For consistency with the out-of-line case, we + # do not raise the error synchronously. + server.write(b"123456") + with ExpectLog(gen_log, "Unsatisfiable read"): + client.read_until_regex(b"def", self.stop, max_bytes=5) + data = self.wait() + self.assertEqual(data, "closed") + finally: + server.close() + client.close() + + def test_read_until_regex_max_bytes_ignores_extra(self): + server, client = self.make_iostream_pair() + client.set_close_callback(lambda: self.stop("closed")) + try: + # Even though data that matches arrives the same packet that + # puts us over the limit, we fail the request because it was not + # found within the limit. + server.write(b"abcdef") + with ExpectLog(gen_log, "Unsatisfiable read"): + client.read_until_regex(b"def", self.stop, max_bytes=5) + data = self.wait() + self.assertEqual(data, "closed") + finally: + server.close() + client.close() + + def test_small_reads_from_large_buffer(self): + # 10KB buffer size, 100KB available to read. + # Read 1KB at a time and make sure that the buffer is not eagerly + # filled. + server, client = self.make_iostream_pair(max_buffer_size=10 * 1024) + try: + server.write(b"a" * 1024 * 100) + for i in range(100): + client.read_bytes(1024, self.stop) + data = self.wait() + self.assertEqual(data, b"a" * 1024) + finally: + server.close() + client.close() + + def test_small_read_untils_from_large_buffer(self): + # 10KB buffer size, 100KB available to read. + # Read 1KB at a time and make sure that the buffer is not eagerly + # filled. + server, client = self.make_iostream_pair(max_buffer_size=10 * 1024) + try: + server.write((b"a" * 1023 + b"\n") * 100) + for i in range(100): + client.read_until(b"\n", self.stop, max_bytes=4096) + data = self.wait() + self.assertEqual(data, b"a" * 1023 + b"\n") + finally: + server.close() + client.close() + class TestIOStreamWebHTTP(TestIOStreamWebMixin, AsyncHTTPTestCase): def _make_client_iostream(self): @@ -472,14 +742,10 @@ class TestIOStream(TestIOStreamMixin, AsyncTestCase): class TestIOStreamSSL(TestIOStreamMixin, AsyncTestCase): def _make_server_iostream(self, connection, **kwargs): - ssl_options = dict( - certfile=os.path.join(os.path.dirname(__file__), 'test.crt'), - keyfile=os.path.join(os.path.dirname(__file__), 'test.key'), - ) connection = ssl.wrap_socket(connection, server_side=True, do_handshake_on_connect=False, - **ssl_options) + **_server_ssl_options()) return SSLIOStream(connection, io_loop=self.io_loop, **kwargs) def _make_client_iostream(self, connection, **kwargs): @@ -507,6 +773,91 @@ class TestIOStreamSSLContext(TestIOStreamMixin, AsyncTestCase): ssl_options=context, **kwargs) +class TestIOStreamStartTLS(AsyncTestCase): + def setUp(self): + try: + super(TestIOStreamStartTLS, self).setUp() + self.listener, self.port = bind_unused_port() + self.server_stream = None + self.server_accepted = Future() + netutil.add_accept_handler(self.listener, self.accept) + self.client_stream = IOStream(socket.socket()) + self.io_loop.add_future(self.client_stream.connect( + ('127.0.0.1', self.port)), self.stop) + self.wait() + self.io_loop.add_future(self.server_accepted, self.stop) + self.wait() + except Exception as e: + print(e) + raise + + def tearDown(self): + if self.server_stream is not None: + self.server_stream.close() + if self.client_stream is not None: + self.client_stream.close() + self.listener.close() + super(TestIOStreamStartTLS, self).tearDown() + + def accept(self, connection, address): + if self.server_stream is not None: + self.fail("should only get one connection") + self.server_stream = IOStream(connection) + self.server_accepted.set_result(None) + + @gen.coroutine + def client_send_line(self, line): + self.client_stream.write(line) + recv_line = yield self.server_stream.read_until(b"\r\n") + self.assertEqual(line, recv_line) + + @gen.coroutine + def server_send_line(self, line): + self.server_stream.write(line) + recv_line = yield self.client_stream.read_until(b"\r\n") + self.assertEqual(line, recv_line) + + def client_start_tls(self, ssl_options=None): + client_stream = self.client_stream + self.client_stream = None + return client_stream.start_tls(False, ssl_options) + + def server_start_tls(self, ssl_options=None): + server_stream = self.server_stream + self.server_stream = None + return server_stream.start_tls(True, ssl_options) + + @gen_test + def test_start_tls_smtp(self): + # This flow is simplified from RFC 3207 section 5. + # We don't really need all of this, but it helps to make sure + # that after realistic back-and-forth traffic the buffers end up + # in a sane state. + yield self.server_send_line(b"220 mail.example.com ready\r\n") + yield self.client_send_line(b"EHLO mail.example.com\r\n") + yield self.server_send_line(b"250-mail.example.com welcome\r\n") + yield self.server_send_line(b"250 STARTTLS\r\n") + yield self.client_send_line(b"STARTTLS\r\n") + yield self.server_send_line(b"220 Go ahead\r\n") + client_future = self.client_start_tls() + server_future = self.server_start_tls(_server_ssl_options()) + self.client_stream = yield client_future + self.server_stream = yield server_future + self.assertTrue(isinstance(self.client_stream, SSLIOStream)) + self.assertTrue(isinstance(self.server_stream, SSLIOStream)) + yield self.client_send_line(b"EHLO mail.example.com\r\n") + yield self.server_send_line(b"250 mail.example.com welcome\r\n") + + @gen_test + def test_handshake_fail(self): + self.server_start_tls(_server_ssl_options()) + client_future = self.client_start_tls( + dict(cert_reqs=ssl.CERT_REQUIRED, ca_certs=certifi.where())) + with ExpectLog(gen_log, "SSL Error"): + with self.assertRaises(ssl.SSLError): + yield client_future + + @skipIfNonUnix class TestPipeIOStream(AsyncTestCase): def test_pipe_iostream(self): diff --git a/tornado/test/log_test.py b/tornado/test/log_test.py index d60cbad4..ee832c54 100644 --- a/tornado/test/log_test.py +++ b/tornado/test/log_test.py @@ -20,6 +20,8 @@ import glob import logging import os import re +import subprocess +import sys import tempfile import warnings @@ -156,3 +158,50 @@ class EnablePrettyLoggingTest(unittest.TestCase): for filename in glob.glob(tmpdir + '/test_log*'): os.unlink(filename) os.rmdir(tmpdir) + + +class LoggingOptionTest(unittest.TestCase): + """Test the ability to enable and disable Tornado's logging hooks.""" + def logs_present(self, statement, args=None): + # Each test may manipulate and/or parse the options and then logs + # a line at the 'info' level. This level is ignored in the + # logging module by default, but Tornado turns it on by default + # so it is the easiest way to tell whether tornado's logging hooks + # ran. + IMPORT = 'from tornado.options import options, parse_command_line' + LOG_INFO = 'import logging; logging.info("hello")' + program = ';'.join([IMPORT, statement, LOG_INFO]) + proc = subprocess.Popen( + [sys.executable, '-c', program] + (args or []), + stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + stdout, stderr = proc.communicate() + self.assertEqual(proc.returncode, 0, 'process failed: %r' % stdout) + return b'hello' in stdout + + def test_default(self): + self.assertFalse(self.logs_present('pass')) + + def test_tornado_default(self): + self.assertTrue(self.logs_present('parse_command_line()')) + + def test_disable_command_line(self): + self.assertFalse(self.logs_present('parse_command_line()', + ['--logging=none'])) + + def test_disable_command_line_case_insensitive(self): + self.assertFalse(self.logs_present('parse_command_line()', + ['--logging=None'])) + + def test_disable_code_string(self): + self.assertFalse(self.logs_present( + 'options.logging = "none"; parse_command_line()')) + + def test_disable_code_none(self): + self.assertFalse(self.logs_present( + 'options.logging = None; parse_command_line()')) + + def test_disable_override(self): + # command line trumps code defaults + self.assertTrue(self.logs_present( + 'options.logging = None; parse_command_line()', + ['--logging=info'])) diff --git a/tornado/test/netutil_test.py b/tornado/test/netutil_test.py index ea8d51a5..94e5e4d2 100644 --- a/tornado/test/netutil_test.py +++ b/tornado/test/netutil_test.py @@ -1,15 +1,16 @@ from __future__ import absolute_import, division, print_function, with_statement +import os import signal import socket from subprocess import Popen import sys import time -from tornado.netutil import BlockingResolver, ThreadedResolver, is_valid_ip +from tornado.netutil import BlockingResolver, ThreadedResolver, is_valid_ip, bind_sockets from tornado.stack_context import ExceptionStackContext from tornado.testing import AsyncTestCase, gen_test -from tornado.test.util import unittest +from tornado.test.util import unittest, skipIfNoNetwork try: from concurrent import futures @@ -25,6 +26,7 @@ else: try: import twisted + import twisted.names except ImportError: twisted = None else: @@ -73,12 +75,14 @@ class _ResolverTestMixin(object): socket.AF_UNSPEC) +@skipIfNoNetwork class BlockingResolverTest(AsyncTestCase, _ResolverTestMixin): def setUp(self): super(BlockingResolverTest, self).setUp() self.resolver = BlockingResolver(io_loop=self.io_loop) +@skipIfNoNetwork @unittest.skipIf(futures is None, "futures module not present") class ThreadedResolverTest(AsyncTestCase, _ResolverTestMixin): def setUp(self): @@ -90,7 +94,9 @@ class ThreadedResolverTest(AsyncTestCase, _ResolverTestMixin): super(ThreadedResolverTest, self).tearDown() +@skipIfNoNetwork @unittest.skipIf(futures is None, "futures module not present") +@unittest.skipIf(sys.platform == 'win32', "preexec_fn not available on win32") class ThreadedResolverImportTest(unittest.TestCase): def test_import(self): TIMEOUT = 5 @@ -115,6 +121,7 @@ class ThreadedResolverImportTest(unittest.TestCase): self.fail("import timed out") +@skipIfNoNetwork @unittest.skipIf(pycares is None, "pycares module not present") class CaresResolverTest(AsyncTestCase, _ResolverTestMixin): def setUp(self): @@ -122,6 +129,7 @@ class CaresResolverTest(AsyncTestCase, _ResolverTestMixin): self.resolver = CaresResolver(io_loop=self.io_loop) +@skipIfNoNetwork @unittest.skipIf(twisted is None, "twisted module not present") @unittest.skipIf(getattr(twisted, '__version__', '0.0') < "12.1", "old version of twisted") class TwistedResolverTest(AsyncTestCase, _ResolverTestMixin): @@ -144,3 +152,17 @@ class IsValidIPTest(unittest.TestCase): self.assertTrue(not is_valid_ip(' ')) self.assertTrue(not is_valid_ip('\n')) self.assertTrue(not is_valid_ip('\x00')) + + +class TestPortAllocation(unittest.TestCase): + def test_same_port_allocation(self): + if 'TRAVIS' in os.environ: + self.skipTest("dual-stack servers often have port conflicts on travis") + sockets = bind_sockets(None, 'localhost') + try: + port = sockets[0].getsockname()[1] + self.assertTrue(all(s.getsockname()[1] == port + for s in sockets[1:])) + finally: + for sock in sockets: + sock.close() diff --git a/tornado/test/runtests.py b/tornado/test/runtests.py index 37efc205..a80b80b9 100644 --- a/tornado/test/runtests.py +++ b/tornado/test/runtests.py @@ -40,6 +40,7 @@ TEST_MODULES = [ 'tornado.test.process_test', 'tornado.test.simple_httpclient_test', 'tornado.test.stack_context_test', + 'tornado.test.tcpclient_test', 'tornado.test.template_test', 'tornado.test.testing_test', 'tornado.test.twisted_test', @@ -65,7 +66,8 @@ class TornadoTextTestRunner(unittest.TextTestRunner): self.stream.write("\n") return result -if __name__ == '__main__': + +def main(): # The -W command-line option does not work in a virtualenv with # python 3 (as of virtualenv 1.7), so configure warnings # programmatically instead. @@ -82,6 +84,9 @@ if __name__ == '__main__': warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings("error", category=DeprecationWarning, module=r"tornado\..*") + warnings.filterwarnings("ignore", category=PendingDeprecationWarning) + warnings.filterwarnings("error", category=PendingDeprecationWarning, + module=r"tornado\..*") # The unittest module is aggressive about deprecating redundant methods, # leaving some without non-deprecated spellings that work on both # 2.7 and 3.2 @@ -127,3 +132,6 @@ if __name__ == '__main__': kwargs['warnings'] = False kwargs['testRunner'] = TornadoTextTestRunner tornado.testing.main(**kwargs) + +if __name__ == '__main__': + main() diff --git a/tornado/test/simple_httpclient_test.py b/tornado/test/simple_httpclient_test.py index ac98aaae..2ba9f75d 100644 --- a/tornado/test/simple_httpclient_test.py +++ b/tornado/test/simple_httpclient_test.py @@ -10,17 +10,18 @@ import re import socket import sys +from tornado import gen from tornado.httpclient import AsyncHTTPClient from tornado.httputil import HTTPHeaders from tornado.ioloop import IOLoop -from tornado.log import gen_log -from tornado.netutil import Resolver -from tornado.simple_httpclient import SimpleAsyncHTTPClient, _DEFAULT_CA_CERTS +from tornado.log import gen_log, app_log +from tornado.netutil import Resolver, bind_sockets +from tornado.simple_httpclient import SimpleAsyncHTTPClient, _default_ca_certs from tornado.test.httpclient_test import ChunkHandler, CountdownHandler, HelloWorldHandler from tornado.test import httpclient_test from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog -from tornado.test.util import unittest, skipOnTravis -from tornado.web import RequestHandler, Application, asynchronous, url +from tornado.test.util import skipOnTravis, skipIfNoIPv6 +from tornado.web import RequestHandler, Application, asynchronous, url, stream_request_body class SimpleHTTPClientCommonTestCase(httpclient_test.HTTPClientCommonTestCase): @@ -70,7 +71,8 @@ class OptionsHandler(RequestHandler): class NoContentHandler(RequestHandler): def get(self): if self.get_argument("error", None): - self.set_header("Content-Length", "7") + self.set_header("Content-Length", "5") + self.write("hello") self.set_status(204) @@ -94,6 +96,30 @@ class HostEchoHandler(RequestHandler): self.write(self.request.headers["Host"]) +class NoContentLengthHandler(RequestHandler): + @gen.coroutine + def get(self): + # Emulate the old HTTP/1.0 behavior of returning a body with no + # content-length. Tornado handles content-length at the framework + # level so we have to go around it. + stream = self.request.connection.stream + yield stream.write(b"HTTP/1.0 200 OK\r\n\r\n" + b"hello") + stream.close() + + +class EchoPostHandler(RequestHandler): + def post(self): + self.write(self.request.body) + + +@stream_request_body +class RespondInPrepareHandler(RequestHandler): + def prepare(self): + self.set_status(403) + self.finish("forbidden") + + class SimpleHTTPClientTestMixin(object): def get_app(self): # callable objects to finish pending /trigger requests @@ -112,6 +138,9 @@ class SimpleHTTPClientTestMixin(object): url("/see_other_post", SeeOtherPostHandler), url("/see_other_get", SeeOtherGetHandler), url("/host_echo", HostEchoHandler), + url("/no_content_length", NoContentLengthHandler), + url("/echo_post", EchoPostHandler), + url("/respond_in_prepare", RespondInPrepareHandler), ], gzip=True) def test_singleton(self): @@ -163,7 +192,7 @@ class SimpleHTTPClientTestMixin(object): response.rethrow() def test_default_certificates_exist(self): - open(_DEFAULT_CA_CERTS).close() + open(_default_ca_certs()).close() def test_gzip(self): # All the tests in this file should be using gzip, but this test @@ -213,28 +242,30 @@ class SimpleHTTPClientTestMixin(object): # trigger the hanging request to let it clean up after itself self.triggers.popleft()() - @unittest.skipIf(not socket.has_ipv6, 'ipv6 support not present') + @skipIfNoIPv6 def test_ipv6(self): try: - self.http_server.listen(self.get_http_port(), address='::1') + [sock] = bind_sockets(None, '::1', family=socket.AF_INET6) + port = sock.getsockname()[1] + self.http_server.add_socket(sock) except socket.gaierror as e: if e.args[0] == socket.EAI_ADDRFAMILY: # python supports ipv6, but it's not configured on the network # interface, so skip this test. return raise - url = self.get_url("/hello").replace("localhost", "[::1]") + url = '%s://[::1]:%d/hello' % (self.get_protocol(), port) - # ipv6 is currently disabled by default and must be explicitly requested - self.http_client.fetch(url, self.stop) + # ipv6 is currently enabled by default but can be disabled + self.http_client.fetch(url, self.stop, allow_ipv6=False) response = self.wait() self.assertEqual(response.code, 599) - self.http_client.fetch(url, self.stop, allow_ipv6=True) + self.http_client.fetch(url, self.stop) response = self.wait() self.assertEqual(response.body, b"Hello world!") - def test_multiple_content_length_accepted(self): + def xtest_multiple_content_length_accepted(self): response = self.fetch("/content_length?value=2,2") self.assertEqual(response.body, b"ok") response = self.fetch("/content_length?value=2,%202,2") @@ -266,7 +297,8 @@ class SimpleHTTPClientTestMixin(object): self.assertEqual(response.headers["Content-length"], "0") # 204 status with non-zero content length is malformed - response = self.fetch("/no_content?error=1") + with ExpectLog(app_log, "Uncaught exception"): + response = self.fetch("/no_content?error=1") self.assertEqual(response.code, 599) def test_host_header(self): @@ -313,6 +345,60 @@ class SimpleHTTPClientTestMixin(object): self.triggers.popleft()() self.wait() + def test_no_content_length(self): + response = self.fetch("/no_content_length") + self.assertEquals(b"hello", response.body) + + def sync_body_producer(self, write): + write(b'1234') + write(b'5678') + + @gen.coroutine + def async_body_producer(self, write): + yield write(b'1234') + yield gen.Task(IOLoop.current().add_callback) + yield write(b'5678') + + def test_sync_body_producer_chunked(self): + response = self.fetch("/echo_post", method="POST", + body_producer=self.sync_body_producer) + response.rethrow() + self.assertEqual(response.body, b"12345678") + + def test_sync_body_producer_content_length(self): + response = self.fetch("/echo_post", method="POST", + body_producer=self.sync_body_producer, + headers={'Content-Length': '8'}) + response.rethrow() + self.assertEqual(response.body, b"12345678") + + def test_async_body_producer_chunked(self): + response = self.fetch("/echo_post", method="POST", + body_producer=self.async_body_producer) + response.rethrow() + self.assertEqual(response.body, b"12345678") + + def test_async_body_producer_content_length(self): + response = self.fetch("/echo_post", method="POST", + body_producer=self.async_body_producer, + headers={'Content-Length': '8'}) + response.rethrow() + self.assertEqual(response.body, b"12345678") + + def test_100_continue(self): + response = self.fetch("/echo_post", method="POST", + body=b"1234", + expect_100_continue=True) + self.assertEqual(response.body, b"1234") + + def test_100_continue_early_response(self): + def body_producer(write): + raise Exception("should not be called") + response = self.fetch("/respond_in_prepare", method="POST", + body_producer=body_producer, + expect_100_continue=True) + self.assertEqual(response.code, 403) + class SimpleHTTPClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPTestCase): def setUp(self): @@ -433,3 +519,32 @@ class ResolveTimeoutTestCase(AsyncHTTPTestCase): def test_resolve_timeout(self): response = self.fetch('/hello', connect_timeout=0.1) self.assertEqual(response.code, 599) + + +class MaxHeaderSizeTest(AsyncHTTPTestCase): + def get_app(self): + class SmallHeaders(RequestHandler): + def get(self): + self.set_header("X-Filler", "a" * 100) + self.write("ok") + + class LargeHeaders(RequestHandler): + def get(self): + self.set_header("X-Filler", "a" * 1000) + self.write("ok") + + return Application([('/small', SmallHeaders), + ('/large', LargeHeaders)]) + + def get_http_client(self): + return SimpleAsyncHTTPClient(io_loop=self.io_loop, max_header_size=1024) + + def test_small_headers(self): + response = self.fetch('/small') + response.rethrow() + self.assertEqual(response.body, b'ok') + + def test_large_headers(self): + with ExpectLog(gen_log, "Unsatisfiable read"): + response = self.fetch('/large') + self.assertEqual(response.code, 599) diff --git a/tornado/test/stack_context_test.py b/tornado/test/stack_context_test.py index 29193305..d65a5b21 100644 --- a/tornado/test/stack_context_test.py +++ b/tornado/test/stack_context_test.py @@ -219,22 +219,13 @@ class StackContextTest(AsyncTestCase): def test_yield_in_with(self): @gen.engine def f(): - try: - self.callback = yield gen.Callback('a') - with StackContext(functools.partial(self.context, 'c1')): - # This yield is a problem: the generator will be suspended - # and the StackContext's __exit__ is not called yet, so - # the context will be left on _state.contexts for anything - # that runs before the yield resolves. - yield gen.Wait('a') - except StackContextInconsistentError: - # In python <= 3.3, this suspended generator is never garbage - # collected, so it remains suspended in the 'yield' forever. - # Starting in 3.4, it is made collectable by raising - # a GeneratorExit exception from the yield, which gets - # converted into a StackContextInconsistentError by the - # exit of the 'with' block. - pass + self.callback = yield gen.Callback('a') + with StackContext(functools.partial(self.context, 'c1')): + # This yield is a problem: the generator will be suspended + # and the StackContext's __exit__ is not called yet, so + # the context will be left on _state.contexts for anything + # that runs before the yield resolves. + yield gen.Wait('a') with self.assertRaises(StackContextInconsistentError): f() @@ -257,11 +248,8 @@ class StackContextTest(AsyncTestCase): # As above, but with ExceptionStackContext instead of StackContext. @gen.engine def f(): - try: - with ExceptionStackContext(lambda t, v, tb: False): - yield gen.Task(self.io_loop.add_callback) - except StackContextInconsistentError: - pass + with ExceptionStackContext(lambda t, v, tb: False): + yield gen.Task(self.io_loop.add_callback) with self.assertRaises(StackContextInconsistentError): f() diff --git a/tornado/test/tcpclient_test.py b/tornado/test/tcpclient_test.py new file mode 100644 index 00000000..a9dfe5a5 --- /dev/null +++ b/tornado/test/tcpclient_test.py @@ -0,0 +1,278 @@ +#!/usr/bin/env python +# +# Copyright 2014 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, with_statement + +from contextlib import closing +import os +import socket + +from tornado.concurrent import Future +from tornado.netutil import bind_sockets, Resolver +from tornado.tcpclient import TCPClient, _Connector +from tornado.tcpserver import TCPServer +from tornado.testing import AsyncTestCase, bind_unused_port, gen_test +from tornado.test.util import skipIfNoIPv6, unittest + +# Fake address families for testing. Used in place of AF_INET +# and AF_INET6 because some installations do not have AF_INET6. +AF1, AF2 = 1, 2 + + +class TestTCPServer(TCPServer): + def __init__(self, family): + super(TestTCPServer, self).__init__() + self.streams = [] + sockets = bind_sockets(None, 'localhost', family) + self.add_sockets(sockets) + self.port = sockets[0].getsockname()[1] + + def handle_stream(self, stream, address): + self.streams.append(stream) + + def stop(self): + super(TestTCPServer, self).stop() + for stream in self.streams: + stream.close() + + +class TCPClientTest(AsyncTestCase): + def setUp(self): + super(TCPClientTest, self).setUp() + self.server = None + self.client = TCPClient() + + def start_server(self, family): + if family == socket.AF_UNSPEC and 'TRAVIS' in os.environ: + self.skipTest("dual-stack servers often have port conflicts on travis") + self.server = TestTCPServer(family) + return self.server.port + + def stop_server(self): + if self.server is not None: + self.server.stop() + self.server = None + + def tearDown(self): + self.client.close() + self.stop_server() + super(TCPClientTest, self).tearDown() + + def skipIfLocalhostV4(self): + Resolver().resolve('localhost', 0, callback=self.stop) + addrinfo = self.wait() + families = set(addr[0] for addr in addrinfo) + if socket.AF_INET6 not in families: + self.skipTest("localhost does not resolve to ipv6") + + @gen_test + def do_test_connect(self, family, host): + port = self.start_server(family) + stream = yield self.client.connect(host, port) + with closing(stream): + stream.write(b"hello") + data = yield self.server.streams[0].read_bytes(5) + self.assertEqual(data, b"hello") + + def test_connect_ipv4_ipv4(self): + self.do_test_connect(socket.AF_INET, '127.0.0.1') + + def test_connect_ipv4_dual(self): + self.do_test_connect(socket.AF_INET, 'localhost') + + @skipIfNoIPv6 + def test_connect_ipv6_ipv6(self): + self.skipIfLocalhostV4() + self.do_test_connect(socket.AF_INET6, '::1') + + @skipIfNoIPv6 + def test_connect_ipv6_dual(self): + self.skipIfLocalhostV4() + if Resolver.configured_class().__name__.endswith('TwistedResolver'): + self.skipTest('TwistedResolver does not support multiple addresses') + self.do_test_connect(socket.AF_INET6, 'localhost') + + def test_connect_unspec_ipv4(self): + self.do_test_connect(socket.AF_UNSPEC, '127.0.0.1') + + @skipIfNoIPv6 + def test_connect_unspec_ipv6(self): + self.skipIfLocalhostV4() + self.do_test_connect(socket.AF_UNSPEC, '::1') + + def test_connect_unspec_dual(self): + self.do_test_connect(socket.AF_UNSPEC, 'localhost') + + @gen_test + def test_refused_ipv4(self): + sock, port = bind_unused_port() + sock.close() + with self.assertRaises(IOError): + yield self.client.connect('127.0.0.1', port) + + +class TestConnectorSplit(unittest.TestCase): + def test_one_family(self): + # These addresses aren't in the right format, but split doesn't care. + primary, secondary = _Connector.split( + [(AF1, 'a'), + (AF1, 'b')]) + self.assertEqual(primary, [(AF1, 'a'), + (AF1, 'b')]) + self.assertEqual(secondary, []) + + def test_mixed(self): + primary, secondary = _Connector.split( + [(AF1, 'a'), + (AF2, 'b'), + (AF1, 'c'), + (AF2, 'd')]) + self.assertEqual(primary, [(AF1, 'a'), (AF1, 'c')]) + self.assertEqual(secondary, [(AF2, 'b'), (AF2, 'd')]) + + +class ConnectorTest(AsyncTestCase): + class FakeStream(object): + def __init__(self): + self.closed = False + + def close(self): + self.closed = True + + def setUp(self): + super(ConnectorTest, self).setUp() + self.connect_futures = {} + self.streams = {} + self.addrinfo = [(AF1, 'a'), (AF1, 'b'), + (AF2, 'c'), (AF2, 'd')] + + def tearDown(self): + # Unless explicitly checked (and popped) in the test, we shouldn't + # be closing any streams + for stream in self.streams.values(): + self.assertFalse(stream.closed) + super(ConnectorTest, self).tearDown() + + def create_stream(self, af, addr): + future = Future() + self.connect_futures[(af, addr)] = future + return future + + def assert_pending(self, *keys): + self.assertEqual(sorted(self.connect_futures.keys()), sorted(keys)) + + def resolve_connect(self, af, addr, success): + future = self.connect_futures.pop((af, addr)) + if success: + self.streams[addr] = ConnectorTest.FakeStream() + future.set_result(self.streams[addr]) + else: + future.set_exception(IOError()) + + def start_connect(self, addrinfo): + conn = _Connector(addrinfo, self.io_loop, self.create_stream) + # Give it a huge timeout; we'll trigger timeouts manually. + future = conn.start(3600) + return conn, future + + def test_immediate_success(self): + conn, future = self.start_connect(self.addrinfo) + self.assertEqual(list(self.connect_futures.keys()), + [(AF1, 'a')]) + self.resolve_connect(AF1, 'a', True) + self.assertEqual(future.result(), (AF1, 'a', self.streams['a'])) + + def test_immediate_failure(self): + # Fail with just one address. + conn, future = self.start_connect([(AF1, 'a')]) + self.assert_pending((AF1, 'a')) + self.resolve_connect(AF1, 'a', False) + self.assertRaises(IOError, future.result) + + def test_one_family_second_try(self): + conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')]) + self.assert_pending((AF1, 'a')) + self.resolve_connect(AF1, 'a', False) + self.assert_pending((AF1, 'b')) + self.resolve_connect(AF1, 'b', True) + self.assertEqual(future.result(), (AF1, 'b', self.streams['b'])) + + def test_one_family_second_try_failure(self): + conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')]) + self.assert_pending((AF1, 'a')) + self.resolve_connect(AF1, 'a', False) + self.assert_pending((AF1, 'b')) + self.resolve_connect(AF1, 'b', False) + self.assertRaises(IOError, future.result) + + def test_one_family_second_try_timeout(self): + conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')]) + self.assert_pending((AF1, 'a')) + # trigger the timeout while the first lookup is pending; + # nothing happens. + conn.on_timeout() + self.assert_pending((AF1, 'a')) + self.resolve_connect(AF1, 'a', False) + self.assert_pending((AF1, 'b')) + self.resolve_connect(AF1, 'b', True) + self.assertEqual(future.result(), (AF1, 'b', self.streams['b'])) + + def test_two_families_immediate_failure(self): + conn, future = self.start_connect(self.addrinfo) + self.assert_pending((AF1, 'a')) + self.resolve_connect(AF1, 'a', False) + self.assert_pending((AF1, 'b'), (AF2, 'c')) + self.resolve_connect(AF1, 'b', False) + self.resolve_connect(AF2, 'c', True) + self.assertEqual(future.result(), (AF2, 'c', self.streams['c'])) + + def test_two_families_timeout(self): + conn, future = self.start_connect(self.addrinfo) + self.assert_pending((AF1, 'a')) + conn.on_timeout() + self.assert_pending((AF1, 'a'), (AF2, 'c')) + self.resolve_connect(AF2, 'c', True) + self.assertEqual(future.result(), (AF2, 'c', self.streams['c'])) + # resolving 'a' after the connection has completed doesn't start 'b' + self.resolve_connect(AF1, 'a', False) + self.assert_pending() + + def test_success_after_timeout(self): + conn, future = self.start_connect(self.addrinfo) + self.assert_pending((AF1, 'a')) + conn.on_timeout() + self.assert_pending((AF1, 'a'), (AF2, 'c')) + self.resolve_connect(AF1, 'a', True) + self.assertEqual(future.result(), (AF1, 'a', self.streams['a'])) + # resolving 'c' after completion closes the connection. + self.resolve_connect(AF2, 'c', True) + self.assertTrue(self.streams.pop('c').closed) + + def test_all_fail(self): + conn, future = self.start_connect(self.addrinfo) + self.assert_pending((AF1, 'a')) + conn.on_timeout() + self.assert_pending((AF1, 'a'), (AF2, 'c')) + self.resolve_connect(AF2, 'c', False) + self.assert_pending((AF1, 'a'), (AF2, 'd')) + self.resolve_connect(AF2, 'd', False) + # one queue is now empty + self.assert_pending((AF1, 'a')) + self.resolve_connect(AF1, 'a', False) + self.assert_pending((AF1, 'b')) + self.assertFalse(future.done()) + self.resolve_connect(AF1, 'b', False) + self.assertRaises(IOError, future.result) diff --git a/tornado/test/template_test.py b/tornado/test/template_test.py index f3a9e059..6d8b624e 100644 --- a/tornado/test/template_test.py +++ b/tornado/test/template_test.py @@ -182,6 +182,7 @@ three """}) try: loader.load("test.html").generate() + self.fail("did not get expected exception") except ZeroDivisionError: self.assertTrue("# test.html:2" in traceback.format_exc()) @@ -192,6 +193,7 @@ three{%end%} """}) try: loader.load("test.html").generate() + self.fail("did not get expected exception") except ZeroDivisionError: self.assertTrue("# test.html:2" in traceback.format_exc()) @@ -202,6 +204,7 @@ three{%end%} }, namespace={"_tt_modules": ObjectDict({"Template": lambda path, **kwargs: loader.load(path).generate(**kwargs)})}) try: loader.load("base.html").generate() + self.fail("did not get expected exception") except ZeroDivisionError: exc_stack = traceback.format_exc() self.assertTrue('# base.html:1' in exc_stack) @@ -214,6 +217,7 @@ three{%end%} }) try: loader.load("base.html").generate() + self.fail("did not get expected exception") except ZeroDivisionError: self.assertTrue("# sub.html:1 (via base.html:1)" in traceback.format_exc()) @@ -225,6 +229,7 @@ three{%end%} }) try: loader.load("sub.html").generate() + self.fail("did not get expected exception") except ZeroDivisionError: exc_stack = traceback.format_exc() self.assertTrue("# base.html:1" in exc_stack) @@ -240,6 +245,7 @@ three{%end%} """}) try: loader.load("sub.html").generate() + self.fail("did not get expected exception") except ZeroDivisionError: self.assertTrue("# sub.html:4 (via base.html:1)" in traceback.format_exc()) @@ -252,6 +258,7 @@ three{%end%} }) try: loader.load("a.html").generate() + self.fail("did not get expected exception") except ZeroDivisionError: self.assertTrue("# c.html:1 (via b.html:1, a.html:1)" in traceback.format_exc()) @@ -380,6 +387,20 @@ raw: {% raw name %}""", self.assertEqual(render("foo.py", ["not a string"]), b"""s = "['not a string']"\n""") + def test_minimize_whitespace(self): + # Whitespace including newlines is allowed within template tags + # and directives, and this is one way to avoid long lines while + # keeping extra whitespace out of the rendered output. + loader = DictLoader({'foo.txt': """\ +{% for i in items + %}{% if i > 0 %}, {% end %}{# + #}{{i + }}{% end +%}""", + }) + self.assertEqual(loader.load("foo.txt").generate(items=range(5)), + b"0, 1, 2, 3, 4") + class TemplateLoaderTest(unittest.TestCase): def setUp(self): diff --git a/tornado/test/testing_test.py b/tornado/test/testing_test.py index 64e5683e..1c8a8650 100644 --- a/tornado/test/testing_test.py +++ b/tornado/test/testing_test.py @@ -8,6 +8,7 @@ from tornado.test.util import unittest import contextlib import os +import traceback @contextlib.contextmanager @@ -62,6 +63,39 @@ class AsyncTestCaseTest(AsyncTestCase): self.wait(timeout=0.15) +class AsyncTestCaseWrapperTest(unittest.TestCase): + def test_undecorated_generator(self): + class Test(AsyncTestCase): + def test_gen(self): + yield + test = Test('test_gen') + result = unittest.TestResult() + test.run(result) + self.assertEqual(len(result.errors), 1) + self.assertIn("should be decorated", result.errors[0][1]) + + def test_undecorated_generator_with_skip(self): + class Test(AsyncTestCase): + @unittest.skip("don't run this") + def test_gen(self): + yield + test = Test('test_gen') + result = unittest.TestResult() + test.run(result) + self.assertEqual(len(result.errors), 0) + self.assertEqual(len(result.skipped), 1) + + def test_other_return(self): + class Test(AsyncTestCase): + def test_other_return(self): + return 42 + test = Test('test_other_return') + result = unittest.TestResult() + test.run(result) + self.assertEqual(len(result.errors), 1) + self.assertIn("Return value from test method ignored", result.errors[0][1]) + + class SetUpTearDownTest(unittest.TestCase): def test_set_up_tear_down(self): """ @@ -115,8 +149,17 @@ class GenTest(AsyncTestCase): def test(self): yield gen.Task(self.io_loop.add_timeout, self.io_loop.time() + 1) - with self.assertRaises(ioloop.TimeoutError): + # This can't use assertRaises because we need to inspect the + # exc_info triple (and not just the exception object) + try: test(self) + self.fail("did not get expected exception") + except ioloop.TimeoutError: + # The stack trace should blame the add_timeout line, not just + # unrelated IOLoop/testing internals. + self.assertIn( + "gen.Task(self.io_loop.add_timeout, self.io_loop.time() + 1)", + traceback.format_exc()) self.finished = True @@ -155,5 +198,23 @@ class GenTest(AsyncTestCase): self.finished = True + def test_with_method_args(self): + @gen_test + def test_with_args(self, *args): + self.assertEqual(args, ('test',)) + yield gen.Task(self.io_loop.add_callback) + + test_with_args(self, 'test') + self.finished = True + + def test_with_method_kwargs(self): + @gen_test + def test_with_kwargs(self, **kwargs): + self.assertDictEqual(kwargs, {'test': 'test'}) + yield gen.Task(self.io_loop.add_callback) + + test_with_kwargs(self, test='test') + self.finished = True + if __name__ == '__main__': unittest.main() diff --git a/tornado/test/util.py b/tornado/test/util.py index 36043104..d31bbba3 100644 --- a/tornado/test/util.py +++ b/tornado/test/util.py @@ -1,14 +1,18 @@ from __future__ import absolute_import, division, print_function, with_statement import os +import socket import sys # Encapsulate the choice of unittest or unittest2 here. # To be used as 'from tornado.test.util import unittest'. -if sys.version_info >= (2, 7): - import unittest -else: +if sys.version_info < (2, 7): + # In py26, we must always use unittest2. import unittest2 as unittest +else: + # Otherwise, use whichever version of unittest was imported in + # tornado.testing. + from tornado.testing import unittest skipIfNonUnix = unittest.skipIf(os.name != 'posix' or sys.platform == 'cygwin', "non-unix platform") @@ -17,3 +21,10 @@ skipIfNonUnix = unittest.skipIf(os.name != 'posix' or sys.platform == 'cygwin', # timing-related tests unreliable. skipOnTravis = unittest.skipIf('TRAVIS' in os.environ, 'timing tests unreliable on travis') + +# Set the environment variable NO_NETWORK=1 to disable any tests that +# depend on an external network. +skipIfNoNetwork = unittest.skipIf('NO_NETWORK' in os.environ, + 'network access disabled') + +skipIfNoIPv6 = unittest.skipIf(not socket.has_ipv6, 'ipv6 support not present') diff --git a/tornado/test/util_test.py b/tornado/test/util_test.py index 5df54f5e..41ccbb9a 100644 --- a/tornado/test/util_test.py +++ b/tornado/test/util_test.py @@ -151,14 +151,22 @@ class ArgReplacerTest(unittest.TestCase): self.replacer = ArgReplacer(function, 'callback') def test_omitted(self): - self.assertEqual(self.replacer.replace('new', (1, 2), dict()), + args = (1, 2) + kwargs = dict() + self.assertIs(self.replacer.get_old_value(args, kwargs), None) + self.assertEqual(self.replacer.replace('new', args, kwargs), (None, (1, 2), dict(callback='new'))) def test_position(self): - self.assertEqual(self.replacer.replace('new', (1, 2, 'old', 3), dict()), + args = (1, 2, 'old', 3) + kwargs = dict() + self.assertEqual(self.replacer.get_old_value(args, kwargs), 'old') + self.assertEqual(self.replacer.replace('new', args, kwargs), ('old', [1, 2, 'new', 3], dict())) def test_keyword(self): - self.assertEqual(self.replacer.replace('new', (1,), - dict(y=2, callback='old', z=3)), + args = (1,) + kwargs = dict(y=2, callback='old', z=3) + self.assertEqual(self.replacer.get_old_value(args, kwargs), 'old') + self.assertEqual(self.replacer.replace('new', args, kwargs), ('old', (1,), dict(y=2, callback='new', z=3))) diff --git a/tornado/test/web_test.py b/tornado/test/web_test.py index c475520b..cbb62b9b 100644 --- a/tornado/test/web_test.py +++ b/tornado/test/web_test.py @@ -1,4 +1,5 @@ from __future__ import absolute_import, division, print_function, with_statement +from tornado.concurrent import Future from tornado import gen from tornado.escape import json_decode, utf8, to_unicode, recursive_unicode, native_str, to_basestring from tornado.httputil import format_timestamp @@ -6,14 +7,16 @@ from tornado.iostream import IOStream from tornado.log import app_log, gen_log from tornado.simple_httpclient import SimpleAsyncHTTPClient from tornado.template import DictLoader -from tornado.testing import AsyncHTTPTestCase, ExpectLog +from tornado.testing import AsyncHTTPTestCase, ExpectLog, gen_test from tornado.test.util import unittest from tornado.util import u, bytes_type, ObjectDict, unicode_type -from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError +from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body import binascii +import contextlib import datetime import email.utils +import itertools import logging import os import re @@ -100,14 +103,14 @@ class SecureCookieV1Test(unittest.TestCase): sig = match.group(2) self.assertEqual( _create_signature_v1(handler.application.settings["cookie_secret"], - 'foo', '12345678', timestamp), + 'foo', '12345678', timestamp), sig) # shifting digits from payload to timestamp doesn't alter signature # (this is not desirable behavior, just confirming that that's how it # works) self.assertEqual( _create_signature_v1(handler.application.settings["cookie_secret"], - 'foo', '1234', b'5678' + timestamp), + 'foo', '1234', b'5678' + timestamp), sig) # tamper with the cookie handler._cookies['foo'] = utf8('1234|5678%s|%s' % ( @@ -471,12 +474,13 @@ class EmptyFlushCallbackHandler(RequestHandler): @asynchronous def get(self): # Ensure that the flush callback is run whether or not there - # was any output. + # was any output. The gen.Task and direct yield forms are + # equivalent. yield gen.Task(self.flush) # "empty" flush, but writes headers yield gen.Task(self.flush) # empty flush self.write("o") - yield gen.Task(self.flush) # flushes the "o" - yield gen.Task(self.flush) # empty flush + yield self.flush() # flushes the "o" + yield self.flush() # empty flush self.finish("k") @@ -575,8 +579,8 @@ class WSGISafeWebTest(WebTestCase): "/decode_arg/%E9?foo=%E9&encoding=latin1", "/decode_arg_kw/%E9?foo=%E9&encoding=latin1", ] - for url in urls: - response = self.fetch(url) + for req_url in urls: + response = self.fetch(req_url) response.rethrow() data = json_decode(response.body) self.assertEqual(data, {u('path'): [u('unicode'), u('\u00e9')], @@ -602,8 +606,8 @@ class WSGISafeWebTest(WebTestCase): # These urls are all equivalent. urls = ["/decode_arg/1%20%2B%201?foo=1%20%2B%201&encoding=utf-8", "/decode_arg/1%20+%201?foo=1+%2B+1&encoding=utf-8"] - for url in urls: - response = self.fetch(url) + for req_url in urls: + response = self.fetch(req_url) response.rethrow() data = json_decode(response.body) self.assertEqual(data, {u('path'): [u('unicode'), u('1 + 1')], @@ -915,17 +919,37 @@ class StaticFileTest(WebTestCase): response = self.fetch(path % int(include_host)) self.assertEqual(response.body, utf8(str(True))) + def get_and_head(self, *args, **kwargs): + """Performs a GET and HEAD request and returns the GET response. + + Fails if any ``Content-*`` headers returned by the two requests + differ. + """ + head_response = self.fetch(*args, method="HEAD", **kwargs) + get_response = self.fetch(*args, method="GET", **kwargs) + content_headers = set() + for h in itertools.chain(head_response.headers, get_response.headers): + if h.startswith('Content-'): + content_headers.add(h) + for h in content_headers: + self.assertEqual(head_response.headers.get(h), + get_response.headers.get(h), + "%s differs between GET (%s) and HEAD (%s)" % + (h, head_response.headers.get(h), + get_response.headers.get(h))) + return get_response + def test_static_304_if_modified_since(self): - response1 = self.fetch("/static/robots.txt") - response2 = self.fetch("/static/robots.txt", headers={ + response1 = self.get_and_head("/static/robots.txt") + response2 = self.get_and_head("/static/robots.txt", headers={ 'If-Modified-Since': response1.headers['Last-Modified']}) self.assertEqual(response2.code, 304) self.assertTrue('Content-Length' not in response2.headers) self.assertTrue('Last-Modified' not in response2.headers) def test_static_304_if_none_match(self): - response1 = self.fetch("/static/robots.txt") - response2 = self.fetch("/static/robots.txt", headers={ + response1 = self.get_and_head("/static/robots.txt") + response2 = self.get_and_head("/static/robots.txt", headers={ 'If-None-Match': response1.headers['Etag']}) self.assertEqual(response2.code, 304) @@ -933,7 +957,7 @@ class StaticFileTest(WebTestCase): # On windows, the functions that work with time_t do not accept # negative values, and at least one client (processing.js) seems # to use if-modified-since 1/1/1960 as a cache-busting technique. - response = self.fetch("/static/robots.txt", headers={ + response = self.get_and_head("/static/robots.txt", headers={ 'If-Modified-Since': 'Fri, 01 Jan 1960 00:00:00 GMT'}) self.assertEqual(response.code, 200) @@ -944,20 +968,20 @@ class StaticFileTest(WebTestCase): # when parsing If-Modified-Since. stat = os.stat(relpath('static/robots.txt')) - response = self.fetch('/static/robots.txt', headers={ + response = self.get_and_head('/static/robots.txt', headers={ 'If-Modified-Since': format_timestamp(stat.st_mtime - 1)}) self.assertEqual(response.code, 200) - response = self.fetch('/static/robots.txt', headers={ + response = self.get_and_head('/static/robots.txt', headers={ 'If-Modified-Since': format_timestamp(stat.st_mtime + 1)}) self.assertEqual(response.code, 304) def test_static_etag(self): - response = self.fetch('/static/robots.txt') + response = self.get_and_head('/static/robots.txt') self.assertEqual(utf8(response.headers.get("Etag")), b'"' + self.robots_txt_hash + b'"') def test_static_with_range(self): - response = self.fetch('/static/robots.txt', headers={ + response = self.get_and_head('/static/robots.txt', headers={ 'Range': 'bytes=0-9'}) self.assertEqual(response.code, 206) self.assertEqual(response.body, b"User-agent") @@ -968,7 +992,7 @@ class StaticFileTest(WebTestCase): "bytes 0-9/26") def test_static_with_range_full_file(self): - response = self.fetch('/static/robots.txt', headers={ + response = self.get_and_head('/static/robots.txt', headers={ 'Range': 'bytes=0-'}) # Note: Chrome refuses to play audio if it gets an HTTP 206 in response # to ``Range: bytes=0-`` :( @@ -980,7 +1004,7 @@ class StaticFileTest(WebTestCase): self.assertEqual(response.headers.get("Content-Range"), None) def test_static_with_range_full_past_end(self): - response = self.fetch('/static/robots.txt', headers={ + response = self.get_and_head('/static/robots.txt', headers={ 'Range': 'bytes=0-10000000'}) self.assertEqual(response.code, 200) robots_file_path = os.path.join(self.static_dir, "robots.txt") @@ -990,7 +1014,7 @@ class StaticFileTest(WebTestCase): self.assertEqual(response.headers.get("Content-Range"), None) def test_static_with_range_partial_past_end(self): - response = self.fetch('/static/robots.txt', headers={ + response = self.get_and_head('/static/robots.txt', headers={ 'Range': 'bytes=1-10000000'}) self.assertEqual(response.code, 206) robots_file_path = os.path.join(self.static_dir, "robots.txt") @@ -1000,7 +1024,7 @@ class StaticFileTest(WebTestCase): self.assertEqual(response.headers.get("Content-Range"), "bytes 1-25/26") def test_static_with_range_end_edge(self): - response = self.fetch('/static/robots.txt', headers={ + response = self.get_and_head('/static/robots.txt', headers={ 'Range': 'bytes=22-'}) self.assertEqual(response.body, b": /\n") self.assertEqual(response.headers.get("Content-Length"), "4") @@ -1008,7 +1032,7 @@ class StaticFileTest(WebTestCase): "bytes 22-25/26") def test_static_with_range_neg_end(self): - response = self.fetch('/static/robots.txt', headers={ + response = self.get_and_head('/static/robots.txt', headers={ 'Range': 'bytes=-4'}) self.assertEqual(response.body, b": /\n") self.assertEqual(response.headers.get("Content-Length"), "4") @@ -1016,19 +1040,19 @@ class StaticFileTest(WebTestCase): "bytes 22-25/26") def test_static_invalid_range(self): - response = self.fetch('/static/robots.txt', headers={ + response = self.get_and_head('/static/robots.txt', headers={ 'Range': 'asdf'}) self.assertEqual(response.code, 200) def test_static_unsatisfiable_range_zero_suffix(self): - response = self.fetch('/static/robots.txt', headers={ + response = self.get_and_head('/static/robots.txt', headers={ 'Range': 'bytes=-0'}) self.assertEqual(response.headers.get("Content-Range"), "bytes */26") self.assertEqual(response.code, 416) def test_static_unsatisfiable_range_invalid_start(self): - response = self.fetch('/static/robots.txt', headers={ + response = self.get_and_head('/static/robots.txt', headers={ 'Range': 'bytes=26'}) self.assertEqual(response.code, 416) self.assertEqual(response.headers.get("Content-Range"), @@ -1053,7 +1077,7 @@ class StaticFileTest(WebTestCase): b'"' + self.robots_txt_hash + b'"') def test_static_range_if_none_match(self): - response = self.fetch('/static/robots.txt', headers={ + response = self.get_and_head('/static/robots.txt', headers={ 'Range': 'bytes=1-4', 'If-None-Match': b'"' + self.robots_txt_hash + b'"'}) self.assertEqual(response.code, 304) @@ -1063,7 +1087,7 @@ class StaticFileTest(WebTestCase): b'"' + self.robots_txt_hash + b'"') def test_static_404(self): - response = self.fetch('/static/blarg') + response = self.get_and_head('/static/blarg') self.assertEqual(response.code, 404) @@ -1136,6 +1160,11 @@ class CustomStaticFileTest(WebTestCase): return b'bar' raise Exception("unexpected path %r" % path) + def get_content_size(self): + if self.absolute_path == 'CustomStaticFileTest:foo.txt': + return 3 + raise Exception("unexpected path %r" % self.absolute_path) + def get_modified_time(self): return None @@ -1335,6 +1364,7 @@ class ErrorHandlerXSRFTest(WebTestCase): self.assertEqual(response.code, 404) +@wsgi_safe class GzipTestCase(SimpleHandlerTestCase): class Handler(RequestHandler): def get(self): @@ -1347,7 +1377,13 @@ class GzipTestCase(SimpleHandlerTestCase): def test_gzip(self): response = self.fetch('/') - self.assertEqual(response.headers['Content-Encoding'], 'gzip') + # simple_httpclient renames the content-encoding header; + # curl_httpclient doesn't. + self.assertEqual( + response.headers.get( + 'Content-Encoding', + response.headers.get('X-Consumed-Content-Encoding')), + 'gzip') self.assertEqual(response.headers['Vary'], 'Accept-Encoding') def test_gzip_not_requested(self): @@ -1799,6 +1835,227 @@ class HandlerByNameTest(WebTestCase): self.assertEqual(resp.body, b'hello') +class StreamingRequestBodyTest(WebTestCase): + def get_handlers(self): + @stream_request_body + class StreamingBodyHandler(RequestHandler): + def initialize(self, test): + self.test = test + + def prepare(self): + self.test.prepared.set_result(None) + + def data_received(self, data): + self.test.data.set_result(data) + + def get(self): + self.test.finished.set_result(None) + self.write({}) + + @stream_request_body + class EarlyReturnHandler(RequestHandler): + def prepare(self): + # If we finish the response in prepare, it won't continue to + # the (non-existent) data_received. + raise HTTPError(401) + + @stream_request_body + class CloseDetectionHandler(RequestHandler): + def initialize(self, test): + self.test = test + + def on_connection_close(self): + super(CloseDetectionHandler, self).on_connection_close() + self.test.close_future.set_result(None) + + return [('/stream_body', StreamingBodyHandler, dict(test=self)), + ('/early_return', EarlyReturnHandler), + ('/close_detection', CloseDetectionHandler, dict(test=self))] + + def connect(self, url, connection_close): + # Use a raw connection so we can control the sending of data. + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) + s.connect(("localhost", self.get_http_port())) + stream = IOStream(s, io_loop=self.io_loop) + stream.write(b"GET " + url + b" HTTP/1.1\r\n") + if connection_close: + stream.write(b"Connection: close\r\n") + stream.write(b"Transfer-Encoding: chunked\r\n\r\n") + return stream + + @gen_test + def test_streaming_body(self): + self.prepared = Future() + self.data = Future() + self.finished = Future() + + stream = self.connect(b"/stream_body", connection_close=True) + yield self.prepared + stream.write(b"4\r\nasdf\r\n") + # Ensure the first chunk is received before we send the second. + data = yield self.data + self.assertEqual(data, b"asdf") + self.data = Future() + stream.write(b"4\r\nqwer\r\n") + data = yield self.data + self.assertEquals(data, b"qwer") + stream.write(b"0\r\n") + yield self.finished + data = yield gen.Task(stream.read_until_close) + # This would ideally use an HTTP1Connection to read the response. + self.assertTrue(data.endswith(b"{}")) + stream.close() + + @gen_test + def test_early_return(self): + stream = self.connect(b"/early_return", connection_close=False) + data = yield gen.Task(stream.read_until_close) + self.assertTrue(data.startswith(b"HTTP/1.1 401")) + + @gen_test + def test_early_return_with_data(self): + stream = self.connect(b"/early_return", connection_close=False) + stream.write(b"4\r\nasdf\r\n") + data = yield gen.Task(stream.read_until_close) + self.assertTrue(data.startswith(b"HTTP/1.1 401")) + + @gen_test + def test_close_during_upload(self): + self.close_future = Future() + stream = self.connect(b"/close_detection", connection_close=False) + stream.close() + yield self.close_future + + +class StreamingRequestFlowControlTest(WebTestCase): + def get_handlers(self): + from tornado.ioloop import IOLoop + + # Each method in this handler returns a Future and yields to the + # IOLoop so the future is not immediately ready. Ensure that the + # Futures are respected and no method is called before the previous + # one has completed. + @stream_request_body + class FlowControlHandler(RequestHandler): + def initialize(self, test): + self.test = test + self.method = None + self.methods = [] + + @contextlib.contextmanager + def in_method(self, method): + if self.method is not None: + self.test.fail("entered method %s while in %s" % + (method, self.method)) + self.method = method + self.methods.append(method) + try: + yield + finally: + self.method = None + + @gen.coroutine + def prepare(self): + with self.in_method('prepare'): + yield gen.Task(IOLoop.current().add_callback) + + @gen.coroutine + def data_received(self, data): + with self.in_method('data_received'): + yield gen.Task(IOLoop.current().add_callback) + + @gen.coroutine + def post(self): + with self.in_method('post'): + yield gen.Task(IOLoop.current().add_callback) + self.write(dict(methods=self.methods)) + + return [('/', FlowControlHandler, dict(test=self))] + + def get_httpserver_options(self): + # Use a small chunk size so flow control is relevant even though + # all the data arrives at once. + return dict(chunk_size=10) + + def test_flow_control(self): + response = self.fetch('/', body='abcdefghijklmnopqrstuvwxyz', + method='POST') + response.rethrow() + self.assertEqual(json_decode(response.body), + dict(methods=['prepare', 'data_received', + 'data_received', 'data_received', + 'post'])) + + +@wsgi_safe +class IncorrectContentLengthTest(SimpleHandlerTestCase): + def get_handlers(self): + test = self + self.server_error = None + + # Manually set a content-length that doesn't match the actual content. + class TooHigh(RequestHandler): + def get(self): + self.set_header("Content-Length", "42") + try: + self.finish("ok") + except Exception as e: + test.server_error = e + raise + + class TooLow(RequestHandler): + def get(self): + self.set_header("Content-Length", "2") + try: + self.finish("hello") + except Exception as e: + test.server_error = e + raise + + return [('/high', TooHigh), + ('/low', TooLow)] + + def test_content_length_too_high(self): + # When the content-length is too high, the connection is simply + # closed without completing the response. An error is logged on + # the server. + with ExpectLog(app_log, "Uncaught exception"): + with ExpectLog(gen_log, + "Cannot send error response after headers written"): + response = self.fetch("/high") + self.assertEqual(response.code, 599) + self.assertEqual(str(self.server_error), + "Tried to write 40 bytes less than Content-Length") + + def test_content_length_too_low(self): + # When the content-length is too low, the connection is closed + # without writing the last chunk, so the client never sees the request + # complete (which would be a framing error). + with ExpectLog(app_log, "Uncaught exception"): + with ExpectLog(gen_log, + "Cannot send error response after headers written"): + response = self.fetch("/low") + self.assertEqual(response.code, 599) + self.assertEqual(str(self.server_error), + "Tried to write more data than Content-Length") + + +class ClientCloseTest(SimpleHandlerTestCase): + class Handler(RequestHandler): + def get(self): + # Simulate a connection closed by the client during + # request processing. The client will see an error, but the + # server should respond gracefully (without logging errors + # because we were unable to write out as many bytes as + # Content-Length said we would) + self.request.connection.stream.close() + self.write('hello') + + def test_client_close(self): + response = self.fetch('/') + self.assertEqual(response.code, 599) + + class SignedValueTest(unittest.TestCase): SECRET = "It's a secret to everybody" diff --git a/tornado/test/websocket_test.py b/tornado/test/websocket_test.py index 3233e59d..7b3c34ce 100644 --- a/tornado/test/websocket_test.py +++ b/tornado/test/websocket_test.py @@ -6,7 +6,7 @@ from tornado.concurrent import Future from tornado.httpclient import HTTPError, HTTPRequest from tornado.log import gen_log from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog -from tornado.test.util import unittest, skipOnTravis +from tornado.test.util import unittest from tornado.web import Application, RequestHandler try: @@ -37,7 +37,7 @@ class TestWebSocketHandler(WebSocketHandler): self.close_future = close_future def on_close(self): - self.close_future.set_result(None) + self.close_future.set_result((self.close_code, self.close_reason)) class EchoHandler(TestWebSocketHandler): @@ -47,6 +47,13 @@ class EchoHandler(TestWebSocketHandler): class HeaderHandler(TestWebSocketHandler): def open(self): + try: + # In a websocket context, many RequestHandler methods + # raise RuntimeErrors. + self.set_status(503) + raise Exception("did not get expected exception") + except RuntimeError: + pass self.write_message(self.request.headers.get('X-Test', '')) @@ -55,6 +62,11 @@ class NonWebSocketHandler(RequestHandler): self.write('ok') +class CloseReasonHandler(TestWebSocketHandler): + def open(self): + self.close(1001, "goodbye") + + class WebSocketTest(AsyncHTTPTestCase): def get_app(self): self.close_future = Future() @@ -62,8 +74,15 @@ class WebSocketTest(AsyncHTTPTestCase): ('/echo', EchoHandler, dict(close_future=self.close_future)), ('/non_ws', NonWebSocketHandler), ('/header', HeaderHandler, dict(close_future=self.close_future)), + ('/close_reason', CloseReasonHandler, + dict(close_future=self.close_future)), ]) + def test_http_request(self): + # WS server, HTTP client. + response = self.fetch('/echo') + self.assertEqual(response.code, 400) + @gen_test def test_websocket_gen(self): ws = yield websocket_connect( @@ -84,8 +103,9 @@ class WebSocketTest(AsyncHTTPTestCase): ws.read_message(self.stop) response = self.wait().result() self.assertEqual(response, 'hello') + self.close_future.add_done_callback(lambda f: self.stop()) ws.close() - yield self.close_future + self.wait() @gen_test def test_websocket_http_fail(self): @@ -102,30 +122,16 @@ class WebSocketTest(AsyncHTTPTestCase): 'ws://localhost:%d/non_ws' % self.get_http_port(), io_loop=self.io_loop) - @skipOnTravis - @gen_test - def test_websocket_network_timeout(self): - sock, port = bind_unused_port() - sock.close() - with self.assertRaises(HTTPError) as cm: - with ExpectLog(gen_log, ".*"): - yield websocket_connect( - 'ws://localhost:%d/' % port, - io_loop=self.io_loop, - connect_timeout=0.01) - self.assertEqual(cm.exception.code, 599) - @gen_test def test_websocket_network_fail(self): sock, port = bind_unused_port() sock.close() - with self.assertRaises(HTTPError) as cm: + with self.assertRaises(IOError): with ExpectLog(gen_log, ".*"): yield websocket_connect( 'ws://localhost:%d/' % port, io_loop=self.io_loop, connect_timeout=3600) - self.assertEqual(cm.exception.code, 599) @gen_test def test_websocket_close_buffered_data(self): @@ -147,6 +153,97 @@ class WebSocketTest(AsyncHTTPTestCase): ws.close() yield self.close_future + @gen_test + def test_server_close_reason(self): + ws = yield websocket_connect( + 'ws://localhost:%d/close_reason' % self.get_http_port()) + msg = yield ws.read_message() + # A message of None means the other side closed the connection. + self.assertIs(msg, None) + self.assertEqual(ws.close_code, 1001) + self.assertEqual(ws.close_reason, "goodbye") + + @gen_test + def test_client_close_reason(self): + ws = yield websocket_connect( + 'ws://localhost:%d/echo' % self.get_http_port()) + ws.close(1001, 'goodbye') + code, reason = yield self.close_future + self.assertEqual(code, 1001) + self.assertEqual(reason, 'goodbye') + + @gen_test + def test_check_origin_valid_no_path(self): + port = self.get_http_port() + + url = 'ws://localhost:%d/echo' % port + headers = {'Origin': 'http://localhost:%d' % port} + + ws = yield websocket_connect(HTTPRequest(url, headers=headers), + io_loop=self.io_loop) + ws.write_message('hello') + response = yield ws.read_message() + self.assertEqual(response, 'hello') + ws.close() + yield self.close_future + + @gen_test + def test_check_origin_valid_with_path(self): + port = self.get_http_port() + + url = 'ws://localhost:%d/echo' % port + headers = {'Origin': 'http://localhost:%d/something' % port} + + ws = yield websocket_connect(HTTPRequest(url, headers=headers), + io_loop=self.io_loop) + ws.write_message('hello') + response = yield ws.read_message() + self.assertEqual(response, 'hello') + ws.close() + yield self.close_future + + @gen_test + def test_check_origin_invalid_partial_url(self): + port = self.get_http_port() + + url = 'ws://localhost:%d/echo' % port + headers = {'Origin': 'localhost:%d' % port} + + with self.assertRaises(HTTPError) as cm: + yield websocket_connect(HTTPRequest(url, headers=headers), + io_loop=self.io_loop) + self.assertEqual(cm.exception.code, 403) + + @gen_test + def test_check_origin_invalid(self): + port = self.get_http_port() + + url = 'ws://localhost:%d/echo' % port + # Host is localhost, which should not be accessible from some other + # domain + headers = {'Origin': 'http://somewhereelse.com'} + + with self.assertRaises(HTTPError) as cm: + yield websocket_connect(HTTPRequest(url, headers=headers), + io_loop=self.io_loop) + + self.assertEqual(cm.exception.code, 403) + + @gen_test + def test_check_origin_invalid_subdomains(self): + port = self.get_http_port() + + url = 'ws://localhost:%d/echo' % port + # Subdomains should be disallowed by default. If we could pass a + # resolver to websocket_connect we could test sibling domains as well. + headers = {'Origin': 'http://subtenant.localhost'} + + with self.assertRaises(HTTPError) as cm: + yield websocket_connect(HTTPRequest(url, headers=headers), + io_loop=self.io_loop) + + self.assertEqual(cm.exception.code, 403) + class MaskFunctionMixin(object): # Subclasses should define self.mask(mask, data) diff --git a/tornado/test/wsgi_test.py b/tornado/test/wsgi_test.py index 8dc35650..42d74b88 100644 --- a/tornado/test/wsgi_test.py +++ b/tornado/test/wsgi_test.py @@ -5,8 +5,8 @@ from tornado.escape import json_decode from tornado.test.httpserver_test import TypeCheckHandler from tornado.testing import AsyncHTTPTestCase from tornado.util import u -from tornado.web import RequestHandler -from tornado.wsgi import WSGIApplication, WSGIContainer +from tornado.web import RequestHandler, Application +from tornado.wsgi import WSGIApplication, WSGIContainer, WSGIAdapter class WSGIContainerTest(AsyncHTTPTestCase): @@ -74,14 +74,27 @@ class WSGIConnectionTest(httpserver_test.HTTPConnectionTest): return WSGIContainer(validator(WSGIApplication(self.get_handlers()))) -def wrap_web_tests(): +def wrap_web_tests_application(): result = {} for cls in web_test.wsgi_safe_tests: - class WSGIWrappedTest(cls): + class WSGIApplicationWrappedTest(cls): def get_app(self): self.app = WSGIApplication(self.get_handlers(), **self.get_app_kwargs()) return WSGIContainer(validator(self.app)) - result["WSGIWrapped_" + cls.__name__] = WSGIWrappedTest + result["WSGIApplication_" + cls.__name__] = WSGIApplicationWrappedTest return result -globals().update(wrap_web_tests()) +globals().update(wrap_web_tests_application()) + + +def wrap_web_tests_adapter(): + result = {} + for cls in web_test.wsgi_safe_tests: + class WSGIAdapterWrappedTest(cls): + def get_app(self): + self.app = Application(self.get_handlers(), + **self.get_app_kwargs()) + return WSGIContainer(validator(WSGIAdapter(self.app))) + result["WSGIAdapter_" + cls.__name__] = WSGIAdapterWrappedTest + return result +globals().update(wrap_web_tests_adapter()) diff --git a/tornado/testing.py b/tornado/testing.py index 8355dcfc..b1564aa6 100644 --- a/tornado/testing.py +++ b/tornado/testing.py @@ -17,7 +17,7 @@ try: from tornado.httpclient import AsyncHTTPClient from tornado.httpserver import HTTPServer from tornado.simple_httpclient import SimpleAsyncHTTPClient - from tornado.ioloop import IOLoop + from tornado.ioloop import IOLoop, TimeoutError from tornado import netutil except ImportError: # These modules are not importable on app engine. Parts of this module @@ -38,6 +38,7 @@ import re import signal import socket import sys +import types try: from cStringIO import StringIO # py2 @@ -48,10 +49,16 @@ except ImportError: # (either py27+ or unittest2) so tornado.test.util enforces # this requirement, but for other users of tornado.testing we want # to allow the older version if unitest2 is not available. -try: - import unittest2 as unittest -except ImportError: +if sys.version_info >= (3,): + # On python 3, mixing unittest2 and unittest (including doctest) + # doesn't seem to work, so always use unittest. import unittest +else: + # On python 2, prefer unittest2 when available. + try: + import unittest2 as unittest + except ImportError: + import unittest _next_port = 10000 @@ -95,6 +102,36 @@ def get_async_test_timeout(): return 5 +class _TestMethodWrapper(object): + """Wraps a test method to raise an error if it returns a value. + + This is mainly used to detect undecorated generators (if a test + method yields it must use a decorator to consume the generator), + but will also detect other kinds of return values (these are not + necessarily errors, but we alert anyway since there is no good + reason to return a value from a test. + """ + def __init__(self, orig_method): + self.orig_method = orig_method + + def __call__(self): + result = self.orig_method() + if isinstance(result, types.GeneratorType): + raise TypeError("Generator test methods should be decorated with " + "tornado.testing.gen_test") + elif result is not None: + raise ValueError("Return value from test method ignored: %r" % + result) + + def __getattr__(self, name): + """Proxy all unknown attributes to the original method. + + This is important for some of the decorators in the `unittest` + module, such as `unittest.skipIf`. + """ + return getattr(self.orig_method, name) + + class AsyncTestCase(unittest.TestCase): """`~unittest.TestCase` subclass for testing `.IOLoop`-based asynchronous code. @@ -157,14 +194,20 @@ class AsyncTestCase(unittest.TestCase): self.assertIn("FriendFeed", response.body) self.stop() """ - def __init__(self, *args, **kwargs): - super(AsyncTestCase, self).__init__(*args, **kwargs) + def __init__(self, methodName='runTest', **kwargs): + super(AsyncTestCase, self).__init__(methodName, **kwargs) self.__stopped = False self.__running = False self.__failure = None self.__stop_args = None self.__timeout = None + # It's easy to forget the @gen_test decorator, but if you do + # the test will silently be ignored because nothing will consume + # the generator. Replace the test method with a wrapper that will + # make sure it's not an undecorated generator. + setattr(self, methodName, _TestMethodWrapper(getattr(self, methodName))) + def setUp(self): super(AsyncTestCase, self).setUp() self.io_loop = self.get_new_ioloop() @@ -352,6 +395,7 @@ class AsyncHTTPTestCase(AsyncTestCase): def tearDown(self): self.http_server.stop() + self.io_loop.run_sync(self.http_server.close_all_connections) if (not IOLoop.initialized() or self.http_client.io_loop is not IOLoop.instance()): self.http_client.close() @@ -414,18 +458,50 @@ def gen_test(func=None, timeout=None): .. versionadded:: 3.1 The ``timeout`` argument and ``ASYNC_TEST_TIMEOUT`` environment variable. + + .. versionchanged:: 4.0 + The wrapper now passes along ``*args, **kwargs`` so it can be used + on functions with arguments. """ if timeout is None: timeout = get_async_test_timeout() def wrap(f): - f = gen.coroutine(f) - + # Stack up several decorators to allow us to access the generator + # object itself. In the innermost wrapper, we capture the generator + # and save it in an attribute of self. Next, we run the wrapped + # function through @gen.coroutine. Finally, the coroutine is + # wrapped again to make it synchronous with run_sync. + # + # This is a good case study arguing for either some sort of + # extensibility in the gen decorators or cancellation support. @functools.wraps(f) - def wrapper(self): - return self.io_loop.run_sync( - functools.partial(f, self), timeout=timeout) - return wrapper + def pre_coroutine(self, *args, **kwargs): + result = f(self, *args, **kwargs) + if isinstance(result, types.GeneratorType): + self._test_generator = result + else: + self._test_generator = None + return result + + coro = gen.coroutine(pre_coroutine) + + @functools.wraps(coro) + def post_coroutine(self, *args, **kwargs): + try: + return self.io_loop.run_sync( + functools.partial(coro, self, *args, **kwargs), + timeout=timeout) + except TimeoutError as e: + # run_sync raises an error with an unhelpful traceback. + # If we throw it back into the generator the stack trace + # will be replaced by the point where the test is stopped. + self._test_generator.throw(e) + # In case the test contains an overly broad except clause, + # we may get back here. In this case re-raise the original + # exception, which is better than nothing. + raise + return post_coroutine if func is not None: # Used like: diff --git a/tornado/util.py b/tornado/util.py index 469f19ea..49eea2c3 100644 --- a/tornado/util.py +++ b/tornado/util.py @@ -41,7 +41,7 @@ class ObjectDict(dict): class GzipDecompressor(object): """Streaming gzip decompressor. - The interface is like that of `zlib.decompressobj` (without the + The interface is like that of `zlib.decompressobj` (without some of the optional arguments, but it understands gzip headers and checksums. """ def __init__(self): @@ -50,14 +50,24 @@ class GzipDecompressor(object): # This works on cpython and pypy, but not jython. self.decompressobj = zlib.decompressobj(16 + zlib.MAX_WBITS) - def decompress(self, value): + def decompress(self, value, max_length=None): """Decompress a chunk, returning newly-available data. Some data may be buffered for later processing; `flush` must be called when there is no more input data to ensure that all data was processed. + + If ``max_length`` is given, some input data may be left over + in ``unconsumed_tail``; you must retrieve this value and pass + it back to a future call to `decompress` if it is not empty. """ - return self.decompressobj.decompress(value) + return self.decompressobj.decompress(value, max_length) + + @property + def unconsumed_tail(self): + """Returns the unconsumed portion left over + """ + return self.decompressobj.unconsumed_tail def flush(self): """Return any remaining buffered data not yet returned by decompress. @@ -90,10 +100,6 @@ def import_object(name): return __import__(name, None, None) parts = name.split('.') - - imp = 'from ' + '.'.join(parts[:-1]) + ' import ' + parts[-1] - #exec(imp) - obj = __import__('.'.join(parts[:-1]), None, None, [parts[-1]], 0) try: return getattr(obj, parts[-1]) @@ -144,6 +150,24 @@ def exec_in(code, glob, loc=None): """) +def errno_from_exception(e): + """Provides the errno from an Exception object. + + There are cases that the errno attribute was not set so we pull + the errno out of the args but if someone instatiates an Exception + without any args you will get a tuple error. So this function + abstracts all that behavior to give you a safe way to get the + errno. + """ + + if hasattr(e, 'errno'): + return e.errno + elif e.args: + return e.args[0] + else: + return None + + class Configurable(object): """Base class for configurable interfaces. @@ -255,6 +279,16 @@ class ArgReplacer(object): # Not a positional parameter self.arg_pos = None + def get_old_value(self, args, kwargs, default=None): + """Returns the old value of the named argument without replacing it. + + Returns ``default`` if the argument is not present. + """ + if self.arg_pos is not None and len(args) > self.arg_pos: + return args[self.arg_pos] + else: + return kwargs.get(self.name, default) + def replace(self, new_value, args, kwargs): """Replace the named argument in ``args, kwargs`` with ``new_value``. diff --git a/tornado/web.py b/tornado/web.py index ed89e7f9..209b7ecd 100644 --- a/tornado/web.py +++ b/tornado/web.py @@ -73,9 +73,11 @@ import tornado import traceback import types -from tornado.concurrent import Future +from tornado.concurrent import Future, is_future from tornado import escape +from tornado import gen from tornado import httputil +from tornado import iostream from tornado import locale from tornado.log import access_log, app_log, gen_log from tornado import stack_context @@ -160,6 +162,7 @@ class RequestHandler(object): self._finished = False self._auto_finish = True self._transforms = None # will be set in _execute + self._prepared_future = None self.path_args = None self.path_kwargs = None self.ui = ObjectDict((n, self._ui_method(m)) for n, m in @@ -173,10 +176,7 @@ class RequestHandler(object): application.ui_modules) self.ui["modules"] = self.ui["_tt_modules"] self.clear() - # Check since connection is not available in WSGI - if getattr(self.request, "connection", None): - self.request.connection.set_close_callback( - self.on_connection_close) + self.request.connection.set_close_callback(self.on_connection_close) self.initialize(**kwargs) def initialize(self): @@ -267,7 +267,9 @@ class RequestHandler(object): may not be called promptly after the end user closes their connection. """ - pass + if _has_stream_request_body(self.__class__): + if not self.request.body.done(): + self.request.body.set_exception(iostream.StreamClosedError()) def clear(self): """Resets all headers and content for this response.""" @@ -277,12 +279,6 @@ class RequestHandler(object): "Date": httputil.format_timestamp(time.time()), }) self.set_default_headers() - if (not self.request.supports_http_1_1() and - getattr(self.request, 'connection', None) and - not self.request.connection.no_keep_alive): - conn_header = self.request.headers.get("Connection") - if conn_header and (conn_header.lower() == "keep-alive"): - self._headers["Connection"] = "Keep-Alive" self._write_buffer = [] self._status_code = 200 self._reason = httputil.responses[200] @@ -487,7 +483,7 @@ class RequestHandler(object): @property def cookies(self): - """An alias for `self.request.cookies <.httpserver.HTTPRequest.cookies>`.""" + """An alias for `self.request.cookies <.httputil.HTTPServerRequest.cookies>`.""" return self.request.cookies def get_cookie(self, name, default=None): @@ -649,12 +645,15 @@ class RequestHandler(object): Note that lists are not converted to JSON because of a potential cross-site security vulnerability. All JSON output should be wrapped in a dictionary. More details at - http://haacked.com/archive/2008/11/20/anatomy-of-a-subtle-json-vulnerability.aspx + http://haacked.com/archive/2009/06/25/json-hijacking.aspx/ and + https://github.com/facebook/tornado/issues/1009 """ if self._finished: raise RuntimeError("Cannot write() after finish(). May be caused " "by using async operations without the " "@asynchronous decorator.") + if not isinstance(chunk, (bytes_type, unicode_type, dict)): + raise TypeError("write() only accepts bytes, unicode, and dict objects") if isinstance(chunk, dict): chunk = escape.json_encode(chunk) self.set_header("Content-Type", "application/json; charset=UTF-8") @@ -820,35 +819,44 @@ class RequestHandler(object): Note that only one flush callback can be outstanding at a time; if another flush occurs before the previous flush's callback has been run, the previous callback will be discarded. - """ - if self.application._wsgi: - # WSGI applications cannot usefully support flush, so just make - # it a no-op (and run the callback immediately). - if callback is not None: - callback() - return + .. versionchanged:: 4.0 + Now returns a `.Future` if no callback is given. + """ chunk = b"".join(self._write_buffer) self._write_buffer = [] if not self._headers_written: self._headers_written = True - for transform in self._transforms or []: + for transform in self._transforms: self._status_code, self._headers, chunk = \ transform.transform_first_chunk( self._status_code, self._headers, chunk, include_footers) - headers = self._generate_headers() + # Ignore the chunk and only write the headers for HEAD requests + if self.request.method == "HEAD": + chunk = None + + # Finalize the cookie headers (which have been stored in a side + # object so an outgoing cookie could be overwritten before it + # is sent). + if hasattr(self, "_new_cookie"): + for cookie in self._new_cookie.values(): + self.add_header("Set-Cookie", cookie.OutputString(None)) + + start_line = httputil.ResponseStartLine(self.request.version, + self._status_code, + self._reason) + return self.request.connection.write_headers( + start_line, self._headers, chunk, callback=callback) else: for transform in self._transforms: chunk = transform.transform_chunk(chunk, include_footers) - headers = b"" - - # Ignore the chunk and only write the headers for HEAD requests - if self.request.method == "HEAD": - if headers: - self.request.write(headers, callback=callback) - return - - self.request.write(headers + chunk, callback=callback) + # Ignore the chunk and only write the headers for HEAD requests + if self.request.method != "HEAD": + return self.request.connection.write(chunk, callback=callback) + else: + future = Future() + future.set_result(None) + return future def finish(self, chunk=None): """Finishes this response, ending the HTTP request.""" @@ -884,10 +892,9 @@ class RequestHandler(object): # are keepalive connections) self.request.connection.set_close_callback(None) - if not self.application._wsgi: - self.flush(include_footers=True) - self.request.finish() - self._log() + self.flush(include_footers=True) + self.request.finish() + self._log() self._finished = True self.on_finish() # Break up a reference cycle between this handler and the @@ -1235,27 +1242,6 @@ class RequestHandler(object): return base + get_url(self.settings, path, **kwargs) - def async_callback(self, callback, *args, **kwargs): - """Obsolete - catches exceptions from the wrapped function. - - This function is unnecessary since Tornado 1.1. - """ - if callback is None: - return None - if args or kwargs: - callback = functools.partial(callback, *args, **kwargs) - - def wrapper(*args, **kwargs): - try: - return callback(*args, **kwargs) - except Exception as e: - if self._headers_written: - app_log.error("Exception after headers written", - exc_info=True) - else: - self._handle_request_exception(e) - return wrapper - def require_setting(self, name, feature="this feature"): """Raises an exception if the given app setting is not defined.""" if not self.application.settings.get(name): @@ -1322,6 +1308,7 @@ class RequestHandler(object): self._handle_request_exception(value) return True + @gen.coroutine def _execute(self, transforms, *args, **kwargs): """Executes this request with the given output transforms.""" self._transforms = transforms @@ -1336,52 +1323,52 @@ class RequestHandler(object): if self.request.method not in ("GET", "HEAD", "OPTIONS") and \ self.application.settings.get("xsrf_cookies"): self.check_xsrf_cookie() - self._when_complete(self.prepare(), self._execute_method) - except Exception as e: - self._handle_request_exception(e) - def _when_complete(self, result, callback): - try: - if result is None: - callback() - elif isinstance(result, Future): - if result.done(): - if result.result() is not None: - raise ValueError('Expected None, got %r' % result.result()) - callback() - else: - # Delayed import of IOLoop because it's not available - # on app engine - from tornado.ioloop import IOLoop - IOLoop.current().add_future( - result, functools.partial(self._when_complete, - callback=callback)) - else: - raise ValueError("Expected Future or None, got %r" % result) - except Exception as e: - self._handle_request_exception(e) + result = self.prepare() + if is_future(result): + result = yield result + if result is not None: + raise TypeError("Expected None, got %r" % result) + if self._prepared_future is not None: + # Tell the Application we've finished with prepare() + # and are ready for the body to arrive. + self._prepared_future.set_result(None) + if self._finished: + return + + if _has_stream_request_body(self.__class__): + # In streaming mode request.body is a Future that signals + # the body has been completely received. The Future has no + # result; the data has been passed to self.data_received + # instead. + try: + yield self.request.body + except iostream.StreamClosedError: + return - def _execute_method(self): - if not self._finished: method = getattr(self, self.request.method.lower()) - self._when_complete(method(*self.path_args, **self.path_kwargs), - self._execute_finish) + result = method(*self.path_args, **self.path_kwargs) + if is_future(result): + result = yield result + if result is not None: + raise TypeError("Expected None, got %r" % result) + if self._auto_finish and not self._finished: + self.finish() + except Exception as e: + self._handle_request_exception(e) + if (self._prepared_future is not None and + not self._prepared_future.done()): + # In case we failed before setting _prepared_future, do it + # now (to unblock the HTTP server). Note that this is not + # in a finally block to avoid GC issues prior to Python 3.4. + self._prepared_future.set_result(None) - def _execute_finish(self): - if self._auto_finish and not self._finished: - self.finish() + def data_received(self, chunk): + """Implement this method to handle streamed request data. - def _generate_headers(self): - reason = self._reason - lines = [utf8(self.request.version + " " + - str(self._status_code) + - " " + reason)] - lines.extend([utf8(n) + b": " + utf8(v) for n, v in self._headers.get_all()]) - - if hasattr(self, "_new_cookie"): - for cookie in self._new_cookie.values(): - lines.append(utf8("Set-Cookie: " + cookie.OutputString(None))) - return b"\r\n".join(lines) + b"\r\n\r\n" + Requires the `.stream_request_body` decorator. + """ + raise NotImplementedError() def _log(self): """Logs the current request. @@ -1495,8 +1482,6 @@ def asynchronous(method): from tornado.ioloop import IOLoop @functools.wraps(method) def wrapper(self, *args, **kwargs): - if self.application._wsgi: - raise Exception("@asynchronous is not supported for WSGI apps") self._auto_finish = False with stack_context.ExceptionStackContext( self._stack_context_handle_exception): @@ -1523,6 +1508,40 @@ def asynchronous(method): return wrapper +def stream_request_body(cls): + """Apply to `RequestHandler` subclasses to enable streaming body support. + + This decorator implies the following changes: + + * `.HTTPServerRequest.body` is undefined, and body arguments will not + be included in `RequestHandler.get_argument`. + * `RequestHandler.prepare` is called when the request headers have been + read instead of after the entire body has been read. + * The subclass must define a method ``data_received(self, data):``, which + will be called zero or more times as data is available. Note that + if the request has an empty body, ``data_received`` may not be called. + * ``prepare`` and ``data_received`` may return Futures (such as via + ``@gen.coroutine``, in which case the next method will not be called + until those futures have completed. + * The regular HTTP method (``post``, ``put``, etc) will be called after + the entire body has been read. + + There is a subtle interaction between ``data_received`` and asynchronous + ``prepare``: The first call to ``data_recieved`` may occur at any point + after the call to ``prepare`` has returned *or yielded*. + """ + if not issubclass(cls, RequestHandler): + raise TypeError("expected subclass of RequestHandler, got %r", cls) + cls._stream_request_body = True + return cls + + +def _has_stream_request_body(cls): + if not issubclass(cls, RequestHandler): + raise TypeError("expected subclass of RequestHandler, got %r", cls) + return getattr(cls, '_stream_request_body', False) + + def removeslash(method): """Use this decorator to remove trailing slashes from the request path. @@ -1567,7 +1586,7 @@ def addslash(method): return wrapper -class Application(object): +class Application(httputil.HTTPServerConnectionDelegate): """A collection of request handlers that make up a web application. Instances of this class are callable and can be passed directly to @@ -1619,12 +1638,11 @@ class Application(object): """ def __init__(self, handlers=None, default_host="", transforms=None, - wsgi=False, **settings): + **settings): if transforms is None: self.transforms = [] if settings.get("gzip"): self.transforms.append(GZipContentEncoding) - self.transforms.append(ChunkedTransferEncoding) else: self.transforms = transforms self.handlers = [] @@ -1636,7 +1654,6 @@ class Application(object): 'Template': TemplateModule, } self.ui_methods = {} - self._wsgi = wsgi self._load_ui_modules(settings.get("ui_modules", {})) self._load_ui_methods(settings.get("ui_methods", {})) if self.settings.get("static_path"): @@ -1662,7 +1679,7 @@ class Application(object): self.settings.setdefault('serve_traceback', True) # Automatically reload modified modules - if self.settings.get('autoreload') and not wsgi: + if self.settings.get('autoreload'): from tornado import autoreload autoreload.start() @@ -1762,64 +1779,15 @@ class Application(object): except TypeError: pass + def start_request(self, connection): + # Modern HTTPServer interface + return _RequestDispatcher(self, connection) + def __call__(self, request): - """Called by HTTPServer to execute the request.""" - transforms = [t(request) for t in self.transforms] - handler = None - args = [] - kwargs = {} - handlers = self._get_host_handlers(request) - if not handlers: - handler = RedirectHandler( - self, request, url="http://" + self.default_host + "/") - else: - for spec in handlers: - match = spec.regex.match(request.path) - if match: - handler = spec.handler_class(self, request, **spec.kwargs) - if spec.regex.groups: - # None-safe wrapper around url_unescape to handle - # unmatched optional groups correctly - def unquote(s): - if s is None: - return s - return escape.url_unescape(s, encoding=None, - plus=False) - # Pass matched groups to the handler. Since - # match.groups() includes both named and unnamed groups, - # we want to use either groups or groupdict but not both. - # Note that args are passed as bytes so the handler can - # decide what encoding to use. - - if spec.regex.groupindex: - kwargs = dict( - (str(k), unquote(v)) - for (k, v) in match.groupdict().items()) - else: - args = [unquote(s) for s in match.groups()] - break - if not handler: - if self.settings.get('default_handler_class'): - handler_class = self.settings['default_handler_class'] - handler_args = self.settings.get( - 'default_handler_args', {}) - else: - handler_class = ErrorHandler - handler_args = dict(status_code=404) - handler = handler_class(self, request, **handler_args) - - # If template cache is disabled (usually in the debug mode), - # re-compile templates and reload static files on every - # request so you don't need to restart to see changes - if not self.settings.get("compiled_template_cache", True): - with RequestHandler._template_loader_lock: - for loader in RequestHandler._template_loaders.values(): - loader.reset() - if not self.settings.get('static_hash_cache', True): - StaticFileHandler.reset() - - handler._execute(transforms, *args, **kwargs) - return handler + # Legacy HTTPServer interface + dispatcher = _RequestDispatcher(self, None) + dispatcher.set_request(request) + return dispatcher.execute() def reverse_url(self, name, *args): """Returns a URL path for handler named ``name`` @@ -1856,6 +1824,113 @@ class Application(object): handler._request_summary(), request_time) +class _RequestDispatcher(httputil.HTTPMessageDelegate): + def __init__(self, application, connection): + self.application = application + self.connection = connection + self.request = None + self.chunks = [] + self.handler_class = None + self.handler_kwargs = None + self.path_args = [] + self.path_kwargs = {} + + def headers_received(self, start_line, headers): + self.set_request(httputil.HTTPServerRequest( + connection=self.connection, start_line=start_line, headers=headers)) + if self.stream_request_body: + self.request.body = Future() + return self.execute() + + def set_request(self, request): + self.request = request + self._find_handler() + self.stream_request_body = _has_stream_request_body(self.handler_class) + + def _find_handler(self): + # Identify the handler to use as soon as we have the request. + # Save url path arguments for later. + app = self.application + handlers = app._get_host_handlers(self.request) + if not handlers: + self.handler_class = RedirectHandler + self.handler_kwargs = dict(url="http://" + app.default_host + "/") + return + for spec in handlers: + match = spec.regex.match(self.request.path) + if match: + self.handler_class = spec.handler_class + self.handler_kwargs = spec.kwargs + if spec.regex.groups: + # Pass matched groups to the handler. Since + # match.groups() includes both named and + # unnamed groups, we want to use either groups + # or groupdict but not both. + if spec.regex.groupindex: + self.path_kwargs = dict( + (str(k), _unquote_or_none(v)) + for (k, v) in match.groupdict().items()) + else: + self.path_args = [_unquote_or_none(s) + for s in match.groups()] + return + if app.settings.get('default_handler_class'): + self.handler_class = app.settings['default_handler_class'] + self.handler_kwargs = app.settings.get( + 'default_handler_args', {}) + else: + self.handler_class = ErrorHandler + self.handler_kwargs = dict(status_code=404) + + def data_received(self, data): + if self.stream_request_body: + return self.handler.data_received(data) + else: + self.chunks.append(data) + + def finish(self): + if self.stream_request_body: + self.request.body.set_result(None) + else: + self.request.body = b''.join(self.chunks) + self.request._parse_body() + self.execute() + + def on_connection_close(self): + if self.stream_request_body: + self.handler.on_connection_close() + else: + self.chunks = None + + def execute(self): + # If template cache is disabled (usually in the debug mode), + # re-compile templates and reload static files on every + # request so you don't need to restart to see changes + if not self.application.settings.get("compiled_template_cache", True): + with RequestHandler._template_loader_lock: + for loader in RequestHandler._template_loaders.values(): + loader.reset() + if not self.application.settings.get('static_hash_cache', True): + StaticFileHandler.reset() + + self.handler = self.handler_class(self.application, self.request, + **self.handler_kwargs) + transforms = [t(self.request) for t in self.application.transforms] + + if self.stream_request_body: + self.handler._prepared_future = Future() + # Note that if an exception escapes handler._execute it will be + # trapped in the Future it returns (which we are ignoring here). + # However, that shouldn't happen because _execute has a blanket + # except handler, and we cannot easily access the IOLoop here to + # call add_future. + self.handler._execute(transforms, *self.path_args, **self.path_kwargs) + # If we are streaming the request body, then execute() is finished + # when the handler has prepared to receive the body. If not, + # it doesn't matter when execute() finishes (so we return None) + return self.handler._prepared_future + + class HTTPError(Exception): """An exception that will turn into an HTTP error response. @@ -2014,8 +2089,9 @@ class StaticFileHandler(RequestHandler): cls._static_hashes = {} def head(self, path): - self.get(path, include_body=False) + return self.get(path, include_body=False) + @gen.coroutine def get(self, path, include_body=True): # Set up our path instance variables. self.path = self.parse_url_path(path) @@ -2040,9 +2116,9 @@ class StaticFileHandler(RequestHandler): # the request will be treated as if the header didn't exist. request_range = httputil._parse_request_range(range_header) + size = self.get_content_size() if request_range: start, end = request_range - size = self.get_content_size() if (start is not None and start >= size) or end == 0: # As per RFC 2616 14.35.1, a range is not satisfiable only: if # the first requested byte is equal to or greater than the @@ -2067,18 +2143,26 @@ class StaticFileHandler(RequestHandler): httputil._get_content_range(start, end, size)) else: start = end = None - content = self.get_content(self.absolute_path, start, end) - if isinstance(content, bytes_type): - content = [content] - content_length = 0 - for chunk in content: - if include_body: + + if start is not None and end is not None: + content_length = end - start + elif end is not None: + content_length = end + elif start is not None: + content_length = size - start + else: + content_length = size + self.set_header("Content-Length", content_length) + + if include_body: + content = self.get_content(self.absolute_path, start, end) + if isinstance(content, bytes_type): + content = [content] + for chunk in content: self.write(chunk) - else: - content_length += len(chunk) - if not include_body: + yield self.flush() + else: assert self.request.method == "HEAD" - self.set_header("Content-Length", content_length) def compute_etag(self): """Sets the ``Etag`` header based on static url version. @@ -2258,10 +2342,13 @@ class StaticFileHandler(RequestHandler): def get_content_size(self): """Retrieve the total size of the resource at the given path. - This method may be overridden by subclasses. It will only - be called if a partial result is requested from `get_content` + This method may be overridden by subclasses. .. versionadded:: 3.1 + + .. versionchanged:: 4.0 + This method is now always called, instead of only when + partial results are requested. """ stat_result = self._stat() return stat_result[stat.ST_SIZE] @@ -2383,7 +2470,7 @@ class FallbackHandler(RequestHandler): """A `RequestHandler` that wraps another HTTP server callback. The fallback is a callable object that accepts an - `~.httpserver.HTTPRequest`, such as an `Application` or + `~.httputil.HTTPServerRequest`, such as an `Application` or `tornado.wsgi.WSGIContainer`. This is most useful to use both Tornado ``RequestHandlers`` and WSGI in the same server. Typical usage:: @@ -2407,7 +2494,7 @@ class OutputTransform(object): """A transform modifies the result of an HTTP request (e.g., GZip encoding) A new transform instance is created for every request. See the - ChunkedTransferEncoding example below if you want to implement a + GZipContentEncoding example below if you want to implement a new Transform. """ def __init__(self, request): @@ -2424,16 +2511,24 @@ class GZipContentEncoding(OutputTransform): """Applies the gzip content encoding to the response. See http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.11 + + .. versionchanged:: 4.0 + Now compresses all mime types beginning with ``text/``, instead + of just a whitelist. (the whitelist is still used for certain + non-text mime types). """ - CONTENT_TYPES = set([ - "text/plain", "text/html", "text/css", "text/xml", "application/javascript", - "application/x-javascript", "application/xml", "application/atom+xml", - "text/javascript", "application/json", "application/xhtml+xml"]) + # Whitelist of compressible mime types (in addition to any types + # beginning with "text/"). + CONTENT_TYPES = set(["application/javascript", "application/x-javascript", + "application/xml", "application/atom+xml", + "application/json", "application/xhtml+xml"]) MIN_LENGTH = 5 def __init__(self, request): - self._gzipping = request.supports_http_1_1() and \ - "gzip" in request.headers.get("Accept-Encoding", "") + self._gzipping = "gzip" in request.headers.get("Accept-Encoding", "") + + def _compressible_type(self, ctype): + return ctype.startswith('text/') or ctype in self.CONTENT_TYPES def transform_first_chunk(self, status_code, headers, chunk, finishing): if 'Vary' in headers: @@ -2442,7 +2537,7 @@ class GZipContentEncoding(OutputTransform): headers['Vary'] = b'Accept-Encoding' if self._gzipping: ctype = _unicode(headers.get("Content-Type", "")).split(";")[0] - self._gzipping = (ctype in self.CONTENT_TYPES) and \ + self._gzipping = self._compressible_type(ctype) and \ (not finishing or len(chunk) >= self.MIN_LENGTH) and \ (finishing or "Content-Length" not in headers) and \ ("Content-Encoding" not in headers) @@ -2468,42 +2563,16 @@ class GZipContentEncoding(OutputTransform): return chunk -class ChunkedTransferEncoding(OutputTransform): - """Applies the chunked transfer encoding to the response. - - See http://www.w3.org/Protocols/rfc2616/rfc2616-sec3.html#sec3.6.1 - """ - def __init__(self, request): - self._chunking = request.supports_http_1_1() - - def transform_first_chunk(self, status_code, headers, chunk, finishing): - # 304 responses have no body (not even a zero-length body), and so - # should not have either Content-Length or Transfer-Encoding headers. - if self._chunking and status_code != 304: - # No need to chunk the output if a Content-Length is specified - if "Content-Length" in headers or "Transfer-Encoding" in headers: - self._chunking = False - else: - headers["Transfer-Encoding"] = "chunked" - chunk = self.transform_chunk(chunk, finishing) - return status_code, headers, chunk - - def transform_chunk(self, block, finishing): - if self._chunking: - # Don't write out empty chunks because that means END-OF-STREAM - # with chunked encoding - if block: - block = utf8("%x" % len(block)) + b"\r\n" + block + b"\r\n" - if finishing: - block += b"0\r\n\r\n" - return block - - def authenticated(method): """Decorate methods with this to require that the user be logged in. If the user is not logged in, they will be redirected to the configured `login url `. + + If you configure a login url with a query parameter, Tornado will + assume you know what you're doing and use it as-is. If not, it + will add a `next` parameter so the login page knows where to send + you once you're logged in. """ @functools.wraps(method) def wrapper(self, *args, **kwargs): @@ -2810,7 +2879,8 @@ def create_signed_value(secret, name, value, version=None, clock=None): # A leading version number in decimal with no leading zeros, followed by a pipe. _signed_value_version_re = re.compile(br"^([1-9][0-9]*)\|(.*)$") -def decode_signed_value(secret, name, value, max_age_days=31, clock=None,min_version=None): + +def decode_signed_value(secret, name, value, max_age_days=31, clock=None, min_version=None): if clock is None: clock = time.time if min_version is None: @@ -2850,6 +2920,7 @@ def decode_signed_value(secret, name, value, max_age_days=31, clock=None,min_ver else: return None + def _decode_signed_value_v1(secret, name, value, max_age_days, clock): parts = utf8(value).split(b"|") if len(parts) != 3: @@ -2886,9 +2957,9 @@ def _decode_signed_value_v2(secret, name, value, max_age_days, clock): field_value = rest[:n] # In python 3, indexing bytes returns small integers; we must # use a slice to get a byte string as in python 2. - if rest[n:n+1] != b'|': + if rest[n:n + 1] != b'|': raise ValueError("malformed v2 signed value field") - rest = rest[n+1:] + rest = rest[n + 1:] return field_value, rest rest = value[2:] # remove version number try: @@ -2921,7 +2992,20 @@ def _create_signature_v1(secret, *parts): hash.update(utf8(part)) return utf8(hash.hexdigest()) + def _create_signature_v2(secret, s): hash = hmac.new(utf8(secret), digestmod=hashlib.sha256) hash.update(utf8(s)) return utf8(hash.hexdigest()) + + +def _unquote_or_none(s): + """None-safe wrapper around url_unescape to handle unamteched optional + groups correctly. + + Note that args are passed as bytes so the handler can decide what + encoding to use. + """ + if s is None: + return s + return escape.url_unescape(s, encoding=None, plus=False) diff --git a/tornado/websocket.py b/tornado/websocket.py index ff78552c..19196b88 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -31,15 +31,25 @@ import tornado.escape import tornado.web from tornado.concurrent import TracebackFuture -from tornado.escape import utf8, native_str +from tornado.escape import utf8, native_str, to_unicode from tornado import httpclient, httputil from tornado.ioloop import IOLoop from tornado.iostream import StreamClosedError from tornado.log import gen_log, app_log -from tornado.netutil import Resolver from tornado import simple_httpclient +from tornado.tcpclient import TCPClient from tornado.util import bytes_type, unicode_type, _websocket_mask +try: + from urllib.parse import urlparse # py2 +except ImportError: + from urlparse import urlparse # py3 + +try: + xrange # py2 +except NameError: + xrange = range # py3 + class WebSocketError(Exception): pass @@ -102,28 +112,20 @@ class WebSocketHandler(tornado.web.RequestHandler): def __init__(self, application, request, **kwargs): tornado.web.RequestHandler.__init__(self, application, request, **kwargs) - self.stream = request.connection.stream self.ws_connection = None + self.close_code = None + self.close_reason = None + self.stream = None - def _execute(self, transforms, *args, **kwargs): + @tornado.web.asynchronous + def get(self, *args, **kwargs): self.open_args = args self.open_kwargs = kwargs - # Websocket only supports GET method - if self.request.method != 'GET': - self.stream.write(tornado.escape.utf8( - "HTTP/1.1 405 Method Not Allowed\r\n\r\n" - )) - self.stream.close() - return - # Upgrade header should be present and should be equal to WebSocket if self.request.headers.get("Upgrade", "").lower() != 'websocket': - self.stream.write(tornado.escape.utf8( - "HTTP/1.1 400 Bad Request\r\n\r\n" - "Can \"Upgrade\" only to \"WebSocket\"." - )) - self.stream.close() + self.set_status(400) + self.finish("Can \"Upgrade\" only to \"WebSocket\".") return # Connection header should be upgrade. Some proxy servers/load balancers @@ -131,16 +133,31 @@ class WebSocketHandler(tornado.web.RequestHandler): headers = self.request.headers connection = map(lambda s: s.strip().lower(), headers.get("Connection", "").split(",")) if 'upgrade' not in connection: - self.stream.write(tornado.escape.utf8( - "HTTP/1.1 400 Bad Request\r\n\r\n" - "\"Connection\" must be \"Upgrade\"." - )) - self.stream.close() + self.set_status(400) + self.finish("\"Connection\" must be \"Upgrade\".") return + # Handle WebSocket Origin naming convention differences # The difference between version 8 and 13 is that in 8 the # client sends a "Sec-Websocket-Origin" header and in 13 it's # simply "Origin". + if "Origin" in self.request.headers: + origin = self.request.headers.get("Origin") + else: + origin = self.request.headers.get("Sec-Websocket-Origin", None) + + + # If there was an origin header, check to make sure it matches + # according to check_origin. When the origin is None, we assume it + # did not come from a browser and that it can be passed on. + if origin is not None and not self.check_origin(origin): + self.set_status(403) + self.finish("Cross origin websockets not allowed") + return + + self.stream = self.request.connection.detach() + self.stream.set_close_callback(self.on_connection_close) + if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"): self.ws_connection = WebSocketProtocol13(self) self.ws_connection.accept_connection() @@ -154,6 +171,7 @@ class WebSocketHandler(tornado.web.RequestHandler): "Sec-WebSocket-Version: 8\r\n\r\n")) self.stream.close() + def write_message(self, message, binary=False): """Sends the given message to the client of this Web Socket. @@ -214,18 +232,70 @@ class WebSocketHandler(tornado.web.RequestHandler): pass def on_close(self): - """Invoked when the WebSocket is closed.""" + """Invoked when the WebSocket is closed. + + If the connection was closed cleanly and a status code or reason + phrase was supplied, these values will be available as the attributes + ``self.close_code`` and ``self.close_reason``. + + .. versionchanged:: 4.0 + + Added ``close_code`` and ``close_reason`` attributes. + """ pass - def close(self): + def close(self, code=None, reason=None): """Closes this Web Socket. Once the close handshake is successful the socket will be closed. + + ``code`` may be a numeric status code, taken from the values + defined in `RFC 6455 section 7.4.1 + `_. + ``reason`` may be a textual message about why the connection is + closing. These values are made available to the client, but are + not otherwise interpreted by the websocket protocol. + + The ``code`` and ``reason`` arguments are ignored in the "draft76" + protocol version. + + .. versionchanged:: 4.0 + + Added the ``code`` and ``reason`` arguments. """ if self.ws_connection: - self.ws_connection.close() + self.ws_connection.close(code, reason) self.ws_connection = None + def check_origin(self, origin): + """Override to enable support for allowing alternate origins. + + The ``origin`` argument is the value of the ``Origin`` HTTP + header, the url responsible for initiating this request. This + method is not called for clients that do not send this header; + such requests are always allowed (because all browsers that + implement WebSockets support this header, and non-browser + clients do not have the same cross-site security concerns). + + Should return True to accept the request or False to reject it. + By default, rejects all requests with an origin on a host other + than this one. + + This is a security protection against cross site scripting attacks on + browsers, since WebSockets are allowed to bypass the usual same-origin + policies and don't use CORS headers. + + .. versionadded:: 4.0 + """ + parsed_origin = urlparse(origin) + origin = parsed_origin.netloc + origin = origin.lower() + + host = self.request.headers.get("Host") + + # Check to see that origin matches host directly, including ports + return origin == host + def allow_draft76(self): """Override to enable support for the older "draft76" protocol. @@ -269,17 +339,6 @@ class WebSocketHandler(tornado.web.RequestHandler): """ return "wss" if self.request.protocol == "https" else "ws" - def async_callback(self, callback, *args, **kwargs): - """Obsolete - catches exceptions from the wrapped function. - - This function is normally unncecessary thanks to - `tornado.stack_context`. - """ - return self.ws_connection.async_callback(callback, *args, **kwargs) - - def _not_supported(self, *args, **kwargs): - raise Exception("Method not supported for Web Sockets") - def on_connection_close(self): if self.ws_connection: self.ws_connection.on_connection_close() @@ -287,9 +346,17 @@ class WebSocketHandler(tornado.web.RequestHandler): self.on_close() +def _wrap_method(method): + def _disallow_for_websocket(self, *args, **kwargs): + if self.stream is None: + method(self, *args, **kwargs) + else: + raise RuntimeError("Method not supported for Web Sockets") + return _disallow_for_websocket for method in ["write", "redirect", "set_header", "send_error", "set_cookie", "set_status", "flush", "finish"]: - setattr(WebSocketHandler, method, WebSocketHandler._not_supported) + setattr(WebSocketHandler, method, + _wrap_method(getattr(WebSocketHandler, method))) class WebSocketProtocol(object): @@ -302,23 +369,17 @@ class WebSocketProtocol(object): self.client_terminated = False self.server_terminated = False - def async_callback(self, callback, *args, **kwargs): - """Wrap callbacks with this if they are used on asynchronous requests. + def _run_callback(self, callback, *args, **kwargs): + """Runs the given callback with exception handling. - Catches exceptions properly and closes this WebSocket if an exception - is uncaught. + On error, aborts the websocket connection and returns False. """ - if args or kwargs: - callback = functools.partial(callback, *args, **kwargs) - - def wrapper(*args, **kwargs): - try: - return callback(*args, **kwargs) - except Exception: - app_log.error("Uncaught exception in %s", - self.request.path, exc_info=True) - self._abort() - return wrapper + try: + callback(*args, **kwargs) + except Exception: + app_log.error("Uncaught exception in %s", + self.request.path, exc_info=True) + self._abort() def on_connection_close(self): self._abort() @@ -409,7 +470,8 @@ class WebSocketProtocol76(WebSocketProtocol): def _write_response(self, challenge): self.stream.write(challenge) - self.async_callback(self.handler.open)(*self.handler.open_args, **self.handler.open_kwargs) + self._run_callback(self.handler.open, *self.handler.open_args, + **self.handler.open_kwargs) self._receive_message() def _handle_websocket_headers(self): @@ -457,8 +519,8 @@ class WebSocketProtocol76(WebSocketProtocol): def _on_end_delimiter(self, frame): if not self.client_terminated: - self.async_callback(self.handler.on_message)( - frame[:-1].decode("utf-8", "replace")) + self._run_callback(self.handler.on_message, + frame[:-1].decode("utf-8", "replace")) if not self.client_terminated: self._receive_message() @@ -483,7 +545,7 @@ class WebSocketProtocol76(WebSocketProtocol): """Send ping frame.""" raise ValueError("Ping messages not supported by this version of websockets") - def close(self): + def close(self, code=None, reason=None): """Closes the WebSocket connection.""" if not self.server_terminated: if not self.stream.closed(): @@ -568,7 +630,8 @@ class WebSocketProtocol13(WebSocketProtocol): "%s" "\r\n" % (self._challenge_response(), subprotocol_header))) - self.async_callback(self.handler.open)(*self.handler.open_args, **self.handler.open_kwargs) + self._run_callback(self.handler.open, *self.handler.open_args, + **self.handler.open_kwargs) self._receive_frame() def _write_frame(self, fin, opcode, data): @@ -726,28 +789,40 @@ class WebSocketProtocol13(WebSocketProtocol): except UnicodeDecodeError: self._abort() return - self.async_callback(self.handler.on_message)(decoded) + self._run_callback(self.handler.on_message, decoded) elif opcode == 0x2: # Binary data - self.async_callback(self.handler.on_message)(data) + self._run_callback(self.handler.on_message, decoded) elif opcode == 0x8: # Close self.client_terminated = True + if len(data) >= 2: + self.handler.close_code = struct.unpack('>H', data[:2])[0] + if len(data) > 2: + self.handler.close_reason = to_unicode(data[2:]) self.close() elif opcode == 0x9: # Ping self._write_frame(True, 0xA, data) elif opcode == 0xA: # Pong - self.async_callback(self.handler.on_pong)(data) + self._run_callback(self.handler.on_pong, data) else: self._abort() - def close(self): + def close(self, code=None, reason=None): """Closes the WebSocket connection.""" if not self.server_terminated: if not self.stream.closed(): - self._write_frame(True, 0x8, b"") + if code is None and reason is not None: + code = 1000 # "normal closure" status code + if code is None: + close_data = b'' + else: + close_data = struct.pack('>H', code) + if reason is not None: + close_data += utf8(reason) + self._write_frame(True, 0x8, close_data) self.server_terminated = True if self.client_terminated: if self._waiting is not None: @@ -783,18 +858,25 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): 'Sec-WebSocket-Version': '13', }) - self.resolver = Resolver(io_loop=io_loop) + self.tcp_client = TCPClient(io_loop=io_loop) super(WebSocketClientConnection, self).__init__( io_loop, None, request, lambda: None, self._on_http_response, - 104857600, self.resolver) + 104857600, self.tcp_client, 65536) - def close(self): + def close(self, code=None, reason=None): """Closes the websocket connection. + ``code`` and ``reason`` are documented under + `WebSocketHandler.close`. + .. versionadded:: 3.2 + + .. versionchanged:: 4.0 + + Added the ``code`` and ``reason`` arguments. """ if self.protocol is not None: - self.protocol.close() + self.protocol.close(code, reason) self.protocol = None def _on_close(self): @@ -810,8 +892,12 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): self.connect_future.set_exception(WebSocketError( "Non-websocket response")) - def _handle_1xx(self, code): - assert code == 101 + def headers_received(self, start_line, headers): + if start_line.code != 101: + return super(WebSocketClientConnection, self).headers_received( + start_line, headers) + + self.headers = headers assert self.headers['Upgrade'].lower() == 'websocket' assert self.headers['Connection'].lower() == 'upgrade' accept = WebSocketProtocol13.compute_accept_value(self.key) @@ -824,6 +910,9 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): self.io_loop.remove_timeout(self._timeout) self._timeout = None + self.stream = self.connection.detach() + self.stream.set_close_callback(self._on_close) + self.connect_future.set_result(self) def write_message(self, message, binary=False): diff --git a/tornado/wsgi.py b/tornado/wsgi.py index 98dd8064..6e115e12 100644 --- a/tornado/wsgi.py +++ b/tornado/wsgi.py @@ -20,9 +20,9 @@ WSGI is the Python standard for web servers, and allows for interoperability between Tornado and other Python web frameworks and servers. This module provides WSGI support in two ways: -* `WSGIApplication` is a version of `tornado.web.Application` that can run - inside a WSGI server. This is useful for running a Tornado app on another - HTTP server, such as Google App Engine. See the `WSGIApplication` class +* `WSGIAdapter` converts a `tornado.web.Application` to the WSGI application + interface. This is useful for running a Tornado app on another + HTTP server, such as Google App Engine. See the `WSGIAdapter` class documentation for limitations that apply. * `WSGIContainer` lets you run other WSGI applications and frameworks on the Tornado HTTP server. For example, with this class you can mix Django @@ -32,15 +32,14 @@ provides WSGI support in two ways: from __future__ import absolute_import, division, print_function, with_statement import sys -import time -import copy import tornado +from tornado.concurrent import Future from tornado import escape from tornado import httputil from tornado.log import access_log from tornado import web -from tornado.escape import native_str, parse_qs_bytes +from tornado.escape import native_str from tornado.util import bytes_type, unicode_type try: @@ -48,11 +47,6 @@ try: except ImportError: from cStringIO import StringIO as BytesIO # python 2 -try: - import Cookie # py2 -except ImportError: - import http.cookies as Cookie # py3 - try: import urllib.parse as urllib_parse # py3 except ImportError: @@ -83,11 +77,84 @@ else: class WSGIApplication(web.Application): """A WSGI equivalent of `tornado.web.Application`. - `WSGIApplication` is very similar to `tornado.web.Application`, - except no asynchronous methods are supported (since WSGI does not - support non-blocking requests properly). If you call - ``self.flush()`` or other asynchronous methods in your request - handlers running in a `WSGIApplication`, we throw an exception. + .. deprecated:: 4.0 + + Use a regular `.Application` and wrap it in `WSGIAdapter` instead. + """ + def __call__(self, environ, start_response): + return WSGIAdapter(self)(environ, start_response) + + +# WSGI has no facilities for flow control, so just return an already-done +# Future when the interface requires it. +_dummy_future = Future() +_dummy_future.set_result(None) + + +class _WSGIConnection(httputil.HTTPConnection): + def __init__(self, method, start_response, context): + self.method = method + self.start_response = start_response + self.context = context + self._write_buffer = [] + self._finished = False + self._expected_content_remaining = None + self._error = None + + def set_close_callback(self, callback): + # WSGI has no facility for detecting a closed connection mid-request, + # so we can simply ignore the callback. + pass + + def write_headers(self, start_line, headers, chunk=None, callback=None): + if self.method == 'HEAD': + self._expected_content_remaining = 0 + elif 'Content-Length' in headers: + self._expected_content_remaining = int(headers['Content-Length']) + else: + self._expected_content_remaining = None + self.start_response( + '%s %s' % (start_line.code, start_line.reason), + [(native_str(k), native_str(v)) for (k, v) in headers.get_all()]) + if chunk is not None: + self.write(chunk, callback) + elif callback is not None: + callback() + return _dummy_future + + def write(self, chunk, callback=None): + if self._expected_content_remaining is not None: + self._expected_content_remaining -= len(chunk) + if self._expected_content_remaining < 0: + self._error = httputil.HTTPOutputError( + "Tried to write more data than Content-Length") + raise self._error + self._write_buffer.append(chunk) + if callback is not None: + callback() + return _dummy_future + + def finish(self): + if (self._expected_content_remaining is not None and + self._expected_content_remaining != 0): + self._error = httputil.HTTPOutputError( + "Tried to write %d bytes less than Content-Length" % + self._expected_content_remaining) + raise self._error + self._finished = True + + +class _WSGIRequestContext(object): + def __init__(self, remote_ip, protocol): + self.remote_ip = remote_ip + self.protocol = protocol + + def __str__(self): + return self.remote_ip + + +class WSGIAdapter(object): + """Converts a `tornado.web.Application` instance into a WSGI application. Example usage:: @@ -100,121 +167,83 @@ class WSGIApplication(web.Application): self.write("Hello, world") if __name__ == "__main__": - application = tornado.wsgi.WSGIApplication([ + application = tornado.web.Application([ (r"/", MainHandler), ]) - server = wsgiref.simple_server.make_server('', 8888, application) + wsgi_app = tornado.wsgi.WSGIAdapter(application) + server = wsgiref.simple_server.make_server('', 8888, wsgi_app) server.serve_forever() See the `appengine demo - `_ + `_ for an example of using this module to run a Tornado app on Google App Engine. - WSGI applications use the same `.RequestHandler` class, but not - ``@asynchronous`` methods or ``flush()``. This means that it is - not possible to use `.AsyncHTTPClient`, or the `tornado.auth` or - `tornado.websocket` modules. + In WSGI mode asynchronous methods are not supported. This means + that it is not possible to use `.AsyncHTTPClient`, or the + `tornado.auth` or `tornado.websocket` modules. + + .. versionadded:: 4.0 """ - def __init__(self, handlers=None, default_host="", **settings): - web.Application.__init__(self, handlers, default_host, transforms=[], - wsgi=True, **settings) + def __init__(self, application): + if isinstance(application, WSGIApplication): + self.application = lambda request: web.Application.__call__( + application, request) + else: + self.application = application def __call__(self, environ, start_response): - handler = web.Application.__call__(self, HTTPRequest(environ)) - assert handler._finished - reason = handler._reason - status = str(handler._status_code) + " " + reason - headers = list(handler._headers.get_all()) - if hasattr(handler, "_new_cookie"): - for cookie in handler._new_cookie.values(): - headers.append(("Set-Cookie", cookie.OutputString(None))) - start_response(status, - [(native_str(k), native_str(v)) for (k, v) in headers]) - return handler._write_buffer - - -class HTTPRequest(object): - """Mimics `tornado.httpserver.HTTPRequest` for WSGI applications.""" - def __init__(self, environ): - """Parses the given WSGI environment to construct the request.""" - self.method = environ["REQUEST_METHOD"] - self.path = urllib_parse.quote(from_wsgi_str(environ.get("SCRIPT_NAME", ""))) - self.path += urllib_parse.quote(from_wsgi_str(environ.get("PATH_INFO", ""))) - self.uri = self.path - self.arguments = {} - self.query_arguments = {} - self.body_arguments = {} - self.query = environ.get("QUERY_STRING", "") - if self.query: - self.uri += "?" + self.query - self.arguments = parse_qs_bytes(native_str(self.query), - keep_blank_values=True) - self.query_arguments = copy.deepcopy(self.arguments) - self.version = "HTTP/1.1" - self.headers = httputil.HTTPHeaders() + method = environ["REQUEST_METHOD"] + uri = urllib_parse.quote(from_wsgi_str(environ.get("SCRIPT_NAME", ""))) + uri += urllib_parse.quote(from_wsgi_str(environ.get("PATH_INFO", ""))) + if environ.get("QUERY_STRING"): + uri += "?" + environ["QUERY_STRING"] + headers = httputil.HTTPHeaders() if environ.get("CONTENT_TYPE"): - self.headers["Content-Type"] = environ["CONTENT_TYPE"] + headers["Content-Type"] = environ["CONTENT_TYPE"] if environ.get("CONTENT_LENGTH"): - self.headers["Content-Length"] = environ["CONTENT_LENGTH"] + headers["Content-Length"] = environ["CONTENT_LENGTH"] for key in environ: if key.startswith("HTTP_"): - self.headers[key[5:].replace("_", "-")] = environ[key] - if self.headers.get("Content-Length"): - self.body = environ["wsgi.input"].read( - int(self.headers["Content-Length"])) + headers[key[5:].replace("_", "-")] = environ[key] + if headers.get("Content-Length"): + body = environ["wsgi.input"].read( + int(headers["Content-Length"])) else: - self.body = "" - self.protocol = environ["wsgi.url_scheme"] - self.remote_ip = environ.get("REMOTE_ADDR", "") + body = "" + protocol = environ["wsgi.url_scheme"] + remote_ip = environ.get("REMOTE_ADDR", "") if environ.get("HTTP_HOST"): - self.host = environ["HTTP_HOST"] + host = environ["HTTP_HOST"] else: - self.host = environ["SERVER_NAME"] - - # Parse request body - self.files = {} - httputil.parse_body_arguments(self.headers.get("Content-Type", ""), - self.body, self.body_arguments, self.files) - - for k, v in self.body_arguments.items(): - self.arguments.setdefault(k, []).extend(v) - - self._start_time = time.time() - self._finish_time = None - - def supports_http_1_1(self): - """Returns True if this request supports HTTP/1.1 semantics""" - return self.version == "HTTP/1.1" - - @property - def cookies(self): - """A dictionary of Cookie.Morsel objects.""" - if not hasattr(self, "_cookies"): - self._cookies = Cookie.SimpleCookie() - if "Cookie" in self.headers: - try: - self._cookies.load( - native_str(self.headers["Cookie"])) - except Exception: - self._cookies = None - return self._cookies - - def full_url(self): - """Reconstructs the full URL for this request.""" - return self.protocol + "://" + self.host + self.uri - - def request_time(self): - """Returns the amount of time it took for this request to execute.""" - if self._finish_time is None: - return time.time() - self._start_time - else: - return self._finish_time - self._start_time + host = environ["SERVER_NAME"] + connection = _WSGIConnection(method, start_response, + _WSGIRequestContext(remote_ip, protocol)) + request = httputil.HTTPServerRequest( + method, uri, "HTTP/1.1", headers=headers, body=body, + host=host, connection=connection) + request._parse_body() + self.application(request) + if connection._error: + raise connection._error + if not connection._finished: + raise Exception("request did not finish synchronously") + return connection._write_buffer class WSGIContainer(object): r"""Makes a WSGI-compatible function runnable on Tornado's HTTP server. + .. warning:: + + WSGI is a *synchronous* interface, while Tornado's concurrency model + is based on single-threaded asynchronous execution. This means that + running a WSGI app with Tornado's `WSGIContainer` is *less scalable* + than running the same app in a multi-threaded WSGI server like + ``gunicorn`` or ``uwsgi``. Use `WSGIContainer` only when there are + benefits to combining Tornado and WSGI in the same process that + outweigh the reduced scalability. + Wrap a WSGI function in a `WSGIContainer` and pass it to `.HTTPServer` to run it. For example:: @@ -281,7 +310,7 @@ class WSGIContainer(object): @staticmethod def environ(request): - """Converts a `tornado.httpserver.HTTPRequest` to a WSGI environment. + """Converts a `tornado.httputil.HTTPServerRequest` to a WSGI environment. """ hostport = request.host.split(":") if len(hostport) == 2: @@ -327,3 +356,6 @@ class WSGIContainer(object): summary = request.method + " " + request.uri + " (" + \ request.remote_ip + ")" log_method("%d %s %.2fms", status_code, summary, request_time) + + +HTTPRequest = httputil.HTTPServerRequest