1
0
mirror of https://github.com/moparisthebest/SickRage synced 2025-01-07 03:48:02 -05:00

Update Tornado webserver to 4.1dev1 from 4.0b1 and add the certifi lib dependency.

This commit is contained in:
JackDandy 2014-10-14 05:24:01 +01:00
parent 0333a9aa3e
commit 2d18f5b8ab
53 changed files with 6628 additions and 4202 deletions

1
lib/certifi/__init__.py Normal file
View File

@ -0,0 +1 @@
from .core import where

2
lib/certifi/__main__.py Normal file
View File

@ -0,0 +1,2 @@
from certifi import where
print(where())

5134
lib/certifi/cacert.pem Normal file

File diff suppressed because it is too large Load Diff

19
lib/certifi/core.py Normal file
View File

@ -0,0 +1,19 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
certifi.py
~~~~~~~~~~
This module returns the installation location of cacert.pem.
"""
import os
def where():
f = os.path.split(__file__)[0]
return os.path.join(f, 'cacert.pem')
if __name__ == '__main__':
print(where())

View File

@ -25,5 +25,5 @@ from __future__ import absolute_import, division, print_function, with_statement
# is zero for an official release, positive for a development branch, # is zero for an official release, positive for a development branch,
# or negative for a release candidate or beta (after the base version # or negative for a release candidate or beta (after the base version
# number has been incremented) # number has been incremented)
version = "4.0b1" version = "4.1.dev1"
version_info = (4, 0, 0, -99) version_info = (4, 1, 0, -100)

View File

@ -76,7 +76,7 @@ from tornado import escape
from tornado.httputil import url_concat from tornado.httputil import url_concat
from tornado.log import gen_log from tornado.log import gen_log
from tornado.stack_context import ExceptionStackContext from tornado.stack_context import ExceptionStackContext
from tornado.util import bytes_type, u, unicode_type, ArgReplacer from tornado.util import u, unicode_type, ArgReplacer
try: try:
import urlparse # py2 import urlparse # py2
@ -333,7 +333,7 @@ class OAuthMixin(object):
The ``callback_uri`` may be omitted if you have previously The ``callback_uri`` may be omitted if you have previously
registered a callback URI with the third-party service. For registered a callback URI with the third-party service. For
some sevices (including Friendfeed), you must use a some services (including Friendfeed), you must use a
previously-registered callback URI and cannot specify a previously-registered callback URI and cannot specify a
callback via this method. callback via this method.
@ -1112,7 +1112,7 @@ class FacebookMixin(object):
args["cancel_url"] = urlparse.urljoin( args["cancel_url"] = urlparse.urljoin(
self.request.full_url(), cancel_uri) self.request.full_url(), cancel_uri)
if extended_permissions: if extended_permissions:
if isinstance(extended_permissions, (unicode_type, bytes_type)): if isinstance(extended_permissions, (unicode_type, bytes)):
extended_permissions = [extended_permissions] extended_permissions = [extended_permissions]
args["req_perms"] = ",".join(extended_permissions) args["req_perms"] = ",".join(extended_permissions)
self.redirect("http://www.facebook.com/login.php?" + self.redirect("http://www.facebook.com/login.php?" +

File diff suppressed because it is too large Load Diff

View File

@ -29,6 +29,7 @@ import sys
from tornado.stack_context import ExceptionStackContext, wrap from tornado.stack_context import ExceptionStackContext, wrap
from tornado.util import raise_exc_info, ArgReplacer from tornado.util import raise_exc_info, ArgReplacer
from tornado.log import app_log
try: try:
from concurrent import futures from concurrent import futures
@ -173,8 +174,11 @@ class Future(object):
def _set_done(self): def _set_done(self):
self._done = True self._done = True
for cb in self._callbacks: for cb in self._callbacks:
# TODO: error handling try:
cb(self) cb(self)
except Exception:
app_log.exception('exception calling callback %r for %r',
cb, self)
self._callbacks = None self._callbacks = None
TracebackFuture = Future TracebackFuture = Future

View File

@ -19,10 +19,12 @@
from __future__ import absolute_import, division, print_function, with_statement from __future__ import absolute_import, division, print_function, with_statement
import collections import collections
import functools
import logging import logging
import pycurl import pycurl
import threading import threading
import time import time
from io import BytesIO
from tornado import httputil from tornado import httputil
from tornado import ioloop from tornado import ioloop
@ -31,12 +33,6 @@ from tornado import stack_context
from tornado.escape import utf8, native_str from tornado.escape import utf8, native_str
from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main
from tornado.util import bytes_type
try:
from io import BytesIO # py3
except ImportError:
from cStringIO import StringIO as BytesIO # py2
class CurlAsyncHTTPClient(AsyncHTTPClient): class CurlAsyncHTTPClient(AsyncHTTPClient):
@ -45,7 +41,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
self._multi = pycurl.CurlMulti() self._multi = pycurl.CurlMulti()
self._multi.setopt(pycurl.M_TIMERFUNCTION, self._set_timeout) self._multi.setopt(pycurl.M_TIMERFUNCTION, self._set_timeout)
self._multi.setopt(pycurl.M_SOCKETFUNCTION, self._handle_socket) self._multi.setopt(pycurl.M_SOCKETFUNCTION, self._handle_socket)
self._curls = [_curl_create() for i in range(max_clients)] self._curls = [self._curl_create() for i in range(max_clients)]
self._free_list = self._curls[:] self._free_list = self._curls[:]
self._requests = collections.deque() self._requests = collections.deque()
self._fds = {} self._fds = {}
@ -211,8 +207,8 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
"callback": callback, "callback": callback,
"curl_start_time": time.time(), "curl_start_time": time.time(),
} }
_curl_setup_request(curl, request, curl.info["buffer"], self._curl_setup_request(curl, request, curl.info["buffer"],
curl.info["headers"]) curl.info["headers"])
self._multi.add_handle(curl) self._multi.add_handle(curl)
if not started: if not started:
@ -259,6 +255,206 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
def handle_callback_exception(self, callback): def handle_callback_exception(self, callback):
self.io_loop.handle_callback_exception(callback) self.io_loop.handle_callback_exception(callback)
def _curl_create(self):
curl = pycurl.Curl()
if gen_log.isEnabledFor(logging.DEBUG):
curl.setopt(pycurl.VERBOSE, 1)
curl.setopt(pycurl.DEBUGFUNCTION, self._curl_debug)
return curl
def _curl_setup_request(self, curl, request, buffer, headers):
curl.setopt(pycurl.URL, native_str(request.url))
# libcurl's magic "Expect: 100-continue" behavior causes delays
# with servers that don't support it (which include, among others,
# Google's OpenID endpoint). Additionally, this behavior has
# a bug in conjunction with the curl_multi_socket_action API
# (https://sourceforge.net/tracker/?func=detail&atid=100976&aid=3039744&group_id=976),
# which increases the delays. It's more trouble than it's worth,
# so just turn off the feature (yes, setting Expect: to an empty
# value is the official way to disable this)
if "Expect" not in request.headers:
request.headers["Expect"] = ""
# libcurl adds Pragma: no-cache by default; disable that too
if "Pragma" not in request.headers:
request.headers["Pragma"] = ""
curl.setopt(pycurl.HTTPHEADER,
["%s: %s" % (native_str(k), native_str(v))
for k, v in request.headers.get_all()])
curl.setopt(pycurl.HEADERFUNCTION,
functools.partial(self._curl_header_callback,
headers, request.header_callback))
if request.streaming_callback:
write_function = lambda chunk: self.io_loop.add_callback(
request.streaming_callback, chunk)
else:
write_function = buffer.write
if bytes is str: # py2
curl.setopt(pycurl.WRITEFUNCTION, write_function)
else: # py3
# Upstream pycurl doesn't support py3, but ubuntu 12.10 includes
# a fork/port. That version has a bug in which it passes unicode
# strings instead of bytes to the WRITEFUNCTION. This means that
# if you use a WRITEFUNCTION (which tornado always does), you cannot
# download arbitrary binary data. This needs to be fixed in the
# ported pycurl package, but in the meantime this lambda will
# make it work for downloading (utf8) text.
curl.setopt(pycurl.WRITEFUNCTION, lambda s: write_function(utf8(s)))
curl.setopt(pycurl.FOLLOWLOCATION, request.follow_redirects)
curl.setopt(pycurl.MAXREDIRS, request.max_redirects)
curl.setopt(pycurl.CONNECTTIMEOUT_MS, int(1000 * request.connect_timeout))
curl.setopt(pycurl.TIMEOUT_MS, int(1000 * request.request_timeout))
if request.user_agent:
curl.setopt(pycurl.USERAGENT, native_str(request.user_agent))
else:
curl.setopt(pycurl.USERAGENT, "Mozilla/5.0 (compatible; pycurl)")
if request.network_interface:
curl.setopt(pycurl.INTERFACE, request.network_interface)
if request.decompress_response:
curl.setopt(pycurl.ENCODING, "gzip,deflate")
else:
curl.setopt(pycurl.ENCODING, "none")
if request.proxy_host and request.proxy_port:
curl.setopt(pycurl.PROXY, request.proxy_host)
curl.setopt(pycurl.PROXYPORT, request.proxy_port)
if request.proxy_username:
credentials = '%s:%s' % (request.proxy_username,
request.proxy_password)
curl.setopt(pycurl.PROXYUSERPWD, credentials)
else:
curl.setopt(pycurl.PROXY, '')
curl.unsetopt(pycurl.PROXYUSERPWD)
if request.validate_cert:
curl.setopt(pycurl.SSL_VERIFYPEER, 1)
curl.setopt(pycurl.SSL_VERIFYHOST, 2)
else:
curl.setopt(pycurl.SSL_VERIFYPEER, 0)
curl.setopt(pycurl.SSL_VERIFYHOST, 0)
if request.ca_certs is not None:
curl.setopt(pycurl.CAINFO, request.ca_certs)
else:
# There is no way to restore pycurl.CAINFO to its default value
# (Using unsetopt makes it reject all certificates).
# I don't see any way to read the default value from python so it
# can be restored later. We'll have to just leave CAINFO untouched
# if no ca_certs file was specified, and require that if any
# request uses a custom ca_certs file, they all must.
pass
if request.allow_ipv6 is False:
# Curl behaves reasonably when DNS resolution gives an ipv6 address
# that we can't reach, so allow ipv6 unless the user asks to disable.
curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V4)
else:
curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_WHATEVER)
# Set the request method through curl's irritating interface which makes
# up names for almost every single method
curl_options = {
"GET": pycurl.HTTPGET,
"POST": pycurl.POST,
"PUT": pycurl.UPLOAD,
"HEAD": pycurl.NOBODY,
}
custom_methods = set(["DELETE", "OPTIONS", "PATCH"])
for o in curl_options.values():
curl.setopt(o, False)
if request.method in curl_options:
curl.unsetopt(pycurl.CUSTOMREQUEST)
curl.setopt(curl_options[request.method], True)
elif request.allow_nonstandard_methods or request.method in custom_methods:
curl.setopt(pycurl.CUSTOMREQUEST, request.method)
else:
raise KeyError('unknown method ' + request.method)
# Handle curl's cryptic options for every individual HTTP method
if request.method == "GET":
if request.body is not None:
raise ValueError('Body must be None for GET request')
elif request.method in ("POST", "PUT") or request.body:
if request.body is None:
raise ValueError(
'Body must not be None for "%s" request'
% request.method)
request_buffer = BytesIO(utf8(request.body))
def ioctl(cmd):
if cmd == curl.IOCMD_RESTARTREAD:
request_buffer.seek(0)
curl.setopt(pycurl.READFUNCTION, request_buffer.read)
curl.setopt(pycurl.IOCTLFUNCTION, ioctl)
if request.method == "POST":
curl.setopt(pycurl.POSTFIELDSIZE, len(request.body))
else:
curl.setopt(pycurl.UPLOAD, True)
curl.setopt(pycurl.INFILESIZE, len(request.body))
if request.auth_username is not None:
userpwd = "%s:%s" % (request.auth_username, request.auth_password or '')
if request.auth_mode is None or request.auth_mode == "basic":
curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_BASIC)
elif request.auth_mode == "digest":
curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_DIGEST)
else:
raise ValueError("Unsupported auth_mode %s" % request.auth_mode)
curl.setopt(pycurl.USERPWD, native_str(userpwd))
gen_log.debug("%s %s (username: %r)", request.method, request.url,
request.auth_username)
else:
curl.unsetopt(pycurl.USERPWD)
gen_log.debug("%s %s", request.method, request.url)
if request.client_cert is not None:
curl.setopt(pycurl.SSLCERT, request.client_cert)
if request.client_key is not None:
curl.setopt(pycurl.SSLKEY, request.client_key)
if threading.activeCount() > 1:
# libcurl/pycurl is not thread-safe by default. When multiple threads
# are used, signals should be disabled. This has the side effect
# of disabling DNS timeouts in some environments (when libcurl is
# not linked against ares), so we don't do it when there is only one
# thread. Applications that use many short-lived threads may need
# to set NOSIGNAL manually in a prepare_curl_callback since
# there may not be any other threads running at the time we call
# threading.activeCount.
curl.setopt(pycurl.NOSIGNAL, 1)
if request.prepare_curl_callback is not None:
request.prepare_curl_callback(curl)
def _curl_header_callback(self, headers, header_callback, header_line):
header_line = native_str(header_line)
if header_callback is not None:
self.io_loop.add_callback(header_callback, header_line)
# header_line as returned by curl includes the end-of-line characters.
header_line = header_line.strip()
if header_line.startswith("HTTP/"):
headers.clear()
try:
(__, __, reason) = httputil.parse_response_start_line(header_line)
header_line = "X-Http-Reason: %s" % reason
except httputil.HTTPInputError:
return
if not header_line:
return
headers.parse_line(header_line)
def _curl_debug(self, debug_type, debug_msg):
debug_types = ('I', '<', '>', '<', '>')
if debug_type == 0:
gen_log.debug('%s', debug_msg.strip())
elif debug_type in (1, 2):
for line in debug_msg.splitlines():
gen_log.debug('%s %s', debug_types[debug_type], line)
elif debug_type == 4:
gen_log.debug('%s %r', debug_types[debug_type], debug_msg)
class CurlError(HTTPError): class CurlError(HTTPError):
def __init__(self, errno, message): def __init__(self, errno, message):
@ -266,212 +462,6 @@ class CurlError(HTTPError):
self.errno = errno self.errno = errno
def _curl_create():
curl = pycurl.Curl()
if gen_log.isEnabledFor(logging.DEBUG):
curl.setopt(pycurl.VERBOSE, 1)
curl.setopt(pycurl.DEBUGFUNCTION, _curl_debug)
return curl
def _curl_setup_request(curl, request, buffer, headers):
curl.setopt(pycurl.URL, native_str(request.url))
# libcurl's magic "Expect: 100-continue" behavior causes delays
# with servers that don't support it (which include, among others,
# Google's OpenID endpoint). Additionally, this behavior has
# a bug in conjunction with the curl_multi_socket_action API
# (https://sourceforge.net/tracker/?func=detail&atid=100976&aid=3039744&group_id=976),
# which increases the delays. It's more trouble than it's worth,
# so just turn off the feature (yes, setting Expect: to an empty
# value is the official way to disable this)
if "Expect" not in request.headers:
request.headers["Expect"] = ""
# libcurl adds Pragma: no-cache by default; disable that too
if "Pragma" not in request.headers:
request.headers["Pragma"] = ""
# Request headers may be either a regular dict or HTTPHeaders object
if isinstance(request.headers, httputil.HTTPHeaders):
curl.setopt(pycurl.HTTPHEADER,
[native_str("%s: %s" % i) for i in request.headers.get_all()])
else:
curl.setopt(pycurl.HTTPHEADER,
[native_str("%s: %s" % i) for i in request.headers.items()])
if request.header_callback:
curl.setopt(pycurl.HEADERFUNCTION,
lambda line: request.header_callback(native_str(line)))
else:
curl.setopt(pycurl.HEADERFUNCTION,
lambda line: _curl_header_callback(headers,
native_str(line)))
if request.streaming_callback:
write_function = request.streaming_callback
else:
write_function = buffer.write
if bytes_type is str: # py2
curl.setopt(pycurl.WRITEFUNCTION, write_function)
else: # py3
# Upstream pycurl doesn't support py3, but ubuntu 12.10 includes
# a fork/port. That version has a bug in which it passes unicode
# strings instead of bytes to the WRITEFUNCTION. This means that
# if you use a WRITEFUNCTION (which tornado always does), you cannot
# download arbitrary binary data. This needs to be fixed in the
# ported pycurl package, but in the meantime this lambda will
# make it work for downloading (utf8) text.
curl.setopt(pycurl.WRITEFUNCTION, lambda s: write_function(utf8(s)))
curl.setopt(pycurl.FOLLOWLOCATION, request.follow_redirects)
curl.setopt(pycurl.MAXREDIRS, request.max_redirects)
curl.setopt(pycurl.CONNECTTIMEOUT_MS, int(1000 * request.connect_timeout))
curl.setopt(pycurl.TIMEOUT_MS, int(1000 * request.request_timeout))
if request.user_agent:
curl.setopt(pycurl.USERAGENT, native_str(request.user_agent))
else:
curl.setopt(pycurl.USERAGENT, "Mozilla/5.0 (compatible; pycurl)")
if request.network_interface:
curl.setopt(pycurl.INTERFACE, request.network_interface)
if request.use_gzip:
curl.setopt(pycurl.ENCODING, "gzip,deflate")
else:
curl.setopt(pycurl.ENCODING, "none")
if request.proxy_host and request.proxy_port:
curl.setopt(pycurl.PROXY, request.proxy_host)
curl.setopt(pycurl.PROXYPORT, request.proxy_port)
if request.proxy_username:
credentials = '%s:%s' % (request.proxy_username,
request.proxy_password)
curl.setopt(pycurl.PROXYUSERPWD, credentials)
else:
curl.setopt(pycurl.PROXY, '')
curl.unsetopt(pycurl.PROXYUSERPWD)
if request.validate_cert:
curl.setopt(pycurl.SSL_VERIFYPEER, 1)
curl.setopt(pycurl.SSL_VERIFYHOST, 2)
else:
curl.setopt(pycurl.SSL_VERIFYPEER, 0)
curl.setopt(pycurl.SSL_VERIFYHOST, 0)
if request.ca_certs is not None:
curl.setopt(pycurl.CAINFO, request.ca_certs)
else:
# There is no way to restore pycurl.CAINFO to its default value
# (Using unsetopt makes it reject all certificates).
# I don't see any way to read the default value from python so it
# can be restored later. We'll have to just leave CAINFO untouched
# if no ca_certs file was specified, and require that if any
# request uses a custom ca_certs file, they all must.
pass
if request.allow_ipv6 is False:
# Curl behaves reasonably when DNS resolution gives an ipv6 address
# that we can't reach, so allow ipv6 unless the user asks to disable.
curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V4)
else:
curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_WHATEVER)
# Set the request method through curl's irritating interface which makes
# up names for almost every single method
curl_options = {
"GET": pycurl.HTTPGET,
"POST": pycurl.POST,
"PUT": pycurl.UPLOAD,
"HEAD": pycurl.NOBODY,
}
custom_methods = set(["DELETE", "OPTIONS", "PATCH"])
for o in curl_options.values():
curl.setopt(o, False)
if request.method in curl_options:
curl.unsetopt(pycurl.CUSTOMREQUEST)
curl.setopt(curl_options[request.method], True)
elif request.allow_nonstandard_methods or request.method in custom_methods:
curl.setopt(pycurl.CUSTOMREQUEST, request.method)
else:
raise KeyError('unknown method ' + request.method)
# Handle curl's cryptic options for every individual HTTP method
if request.method in ("POST", "PUT"):
if request.body is None:
raise AssertionError(
'Body must not be empty for "%s" request'
% request.method)
request_buffer = BytesIO(utf8(request.body))
curl.setopt(pycurl.READFUNCTION, request_buffer.read)
if request.method == "POST":
def ioctl(cmd):
if cmd == curl.IOCMD_RESTARTREAD:
request_buffer.seek(0)
curl.setopt(pycurl.IOCTLFUNCTION, ioctl)
curl.setopt(pycurl.POSTFIELDSIZE, len(request.body))
else:
curl.setopt(pycurl.INFILESIZE, len(request.body))
elif request.method == "GET":
if request.body is not None:
raise AssertionError('Body must be empty for GET request')
if request.auth_username is not None:
userpwd = "%s:%s" % (request.auth_username, request.auth_password or '')
if request.auth_mode is None or request.auth_mode == "basic":
curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_BASIC)
elif request.auth_mode == "digest":
curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_DIGEST)
else:
raise ValueError("Unsupported auth_mode %s" % request.auth_mode)
curl.setopt(pycurl.USERPWD, native_str(userpwd))
gen_log.debug("%s %s (username: %r)", request.method, request.url,
request.auth_username)
else:
curl.unsetopt(pycurl.USERPWD)
gen_log.debug("%s %s", request.method, request.url)
if request.client_cert is not None:
curl.setopt(pycurl.SSLCERT, request.client_cert)
if request.client_key is not None:
curl.setopt(pycurl.SSLKEY, request.client_key)
if threading.activeCount() > 1:
# libcurl/pycurl is not thread-safe by default. When multiple threads
# are used, signals should be disabled. This has the side effect
# of disabling DNS timeouts in some environments (when libcurl is
# not linked against ares), so we don't do it when there is only one
# thread. Applications that use many short-lived threads may need
# to set NOSIGNAL manually in a prepare_curl_callback since
# there may not be any other threads running at the time we call
# threading.activeCount.
curl.setopt(pycurl.NOSIGNAL, 1)
if request.prepare_curl_callback is not None:
request.prepare_curl_callback(curl)
def _curl_header_callback(headers, header_line):
# header_line as returned by curl includes the end-of-line characters.
header_line = header_line.strip()
if header_line.startswith("HTTP/"):
headers.clear()
try:
(__, __, reason) = httputil.parse_response_start_line(header_line)
header_line = "X-Http-Reason: %s" % reason
except httputil.HTTPInputError:
return
if not header_line:
return
headers.parse_line(header_line)
def _curl_debug(debug_type, debug_msg):
debug_types = ('I', '<', '>', '<', '>')
if debug_type == 0:
gen_log.debug('%s', debug_msg.strip())
elif debug_type in (1, 2):
for line in debug_msg.splitlines():
gen_log.debug('%s %s', debug_types[debug_type], line)
elif debug_type == 4:
gen_log.debug('%s %r', debug_types[debug_type], debug_msg)
if __name__ == "__main__": if __name__ == "__main__":
AsyncHTTPClient.configure(CurlAsyncHTTPClient) AsyncHTTPClient.configure(CurlAsyncHTTPClient)
main() main()

View File

@ -25,7 +25,7 @@ from __future__ import absolute_import, division, print_function, with_statement
import re import re
import sys import sys
from tornado.util import bytes_type, unicode_type, basestring_type, u from tornado.util import unicode_type, basestring_type, u
try: try:
from urllib.parse import parse_qs as _parse_qs # py3 from urllib.parse import parse_qs as _parse_qs # py3
@ -187,7 +187,7 @@ else:
return encoded return encoded
_UTF8_TYPES = (bytes_type, type(None)) _UTF8_TYPES = (bytes, type(None))
def utf8(value): def utf8(value):
@ -215,7 +215,7 @@ def to_unicode(value):
""" """
if isinstance(value, _TO_UNICODE_TYPES): if isinstance(value, _TO_UNICODE_TYPES):
return value return value
if not isinstance(value, bytes_type): if not isinstance(value, bytes):
raise TypeError( raise TypeError(
"Expected bytes, unicode, or None; got %r" % type(value) "Expected bytes, unicode, or None; got %r" % type(value)
) )
@ -246,7 +246,7 @@ def to_basestring(value):
""" """
if isinstance(value, _BASESTRING_TYPES): if isinstance(value, _BASESTRING_TYPES):
return value return value
if not isinstance(value, bytes_type): if not isinstance(value, bytes):
raise TypeError( raise TypeError(
"Expected bytes, unicode, or None; got %r" % type(value) "Expected bytes, unicode, or None; got %r" % type(value)
) )
@ -264,7 +264,7 @@ def recursive_unicode(obj):
return list(recursive_unicode(i) for i in obj) return list(recursive_unicode(i) for i in obj)
elif isinstance(obj, tuple): elif isinstance(obj, tuple):
return tuple(recursive_unicode(i) for i in obj) return tuple(recursive_unicode(i) for i in obj)
elif isinstance(obj, bytes_type): elif isinstance(obj, bytes):
return to_unicode(obj) return to_unicode(obj)
else: else:
return obj return obj

View File

@ -29,16 +29,7 @@ could be written with ``gen`` as::
Most asynchronous functions in Tornado return a `.Future`; Most asynchronous functions in Tornado return a `.Future`;
yielding this object returns its `~.Future.result`. yielding this object returns its `~.Future.result`.
For functions that do not return ``Futures``, `Task` works with any You can also yield a list or dict of ``Futures``, which will be
function that takes a ``callback`` keyword argument (most Tornado functions
can be used in either style, although the ``Future`` style is preferred
since it is both shorter and provides better exception handling)::
@gen.coroutine
def get(self):
yield gen.Task(AsyncHTTPClient().fetch, "http://example.com")
You can also yield a list or dict of ``Futures`` and/or ``Tasks``, which will be
started at the same time and run in parallel; a list or dict of results will started at the same time and run in parallel; a list or dict of results will
be returned when they are all finished:: be returned when they are all finished::
@ -54,30 +45,6 @@ be returned when they are all finished::
.. versionchanged:: 3.2 .. versionchanged:: 3.2
Dict support added. Dict support added.
For more complicated interfaces, `Task` can be split into two parts:
`Callback` and `Wait`::
class GenAsyncHandler2(RequestHandler):
@gen.coroutine
def get(self):
http_client = AsyncHTTPClient()
http_client.fetch("http://example.com",
callback=(yield gen.Callback("key")))
response = yield gen.Wait("key")
do_something_with_response(response)
self.render("template.html")
The ``key`` argument to `Callback` and `Wait` allows for multiple
asynchronous operations to be started at different times and proceed
in parallel: yield several callbacks with different keys, then wait
for them once all the async operations have started.
The result of a `Wait` or `Task` yield expression depends on how the callback
was run. If it was called with no arguments, the result is ``None``. If
it was called with one argument, the result is that argument. If it was
called with more than one argument or any keyword arguments, the result
is an `Arguments` object, which is a named tuple ``(args, kwargs)``.
""" """
from __future__ import absolute_import, division, print_function, with_statement from __future__ import absolute_import, division, print_function, with_statement
@ -142,7 +109,10 @@ def engine(func):
raise ReturnValueIgnoredError( raise ReturnValueIgnoredError(
"@gen.engine functions cannot return values: %r" % "@gen.engine functions cannot return values: %r" %
(future.result(),)) (future.result(),))
future.add_done_callback(final_callback) # The engine interface doesn't give us any way to return
# errors but to raise them into the stack context.
# Save the stack context here to use when the Future has resolved.
future.add_done_callback(stack_context.wrap(final_callback))
return wrapper return wrapper
@ -169,6 +139,17 @@ def coroutine(func, replace_callback=True):
From the caller's perspective, ``@gen.coroutine`` is similar to From the caller's perspective, ``@gen.coroutine`` is similar to
the combination of ``@return_future`` and ``@gen.engine``. the combination of ``@return_future`` and ``@gen.engine``.
.. warning::
When exceptions occur inside a coroutine, the exception
information will be stored in the `.Future` object. You must
examine the result of the `.Future` object, or the exception
may go unnoticed by your code. This means yielding the function
if called from another coroutine, using something like
`.IOLoop.run_sync` for top-level calls, or passing the `.Future`
to `.IOLoop.add_future`.
""" """
return _make_coroutine_wrapper(func, replace_callback=True) return _make_coroutine_wrapper(func, replace_callback=True)
@ -218,7 +199,18 @@ def _make_coroutine_wrapper(func, replace_callback):
future.set_exc_info(sys.exc_info()) future.set_exc_info(sys.exc_info())
else: else:
Runner(result, future, yielded) Runner(result, future, yielded)
return future try:
return future
finally:
# Subtle memory optimization: if next() raised an exception,
# the future's exc_info contains a traceback which
# includes this stack frame. This creates a cycle,
# which will be collected at the next full GC but has
# been shown to greatly increase memory usage of
# benchmarks (relative to the refcount-based scheme
# used in the absence of cycles). We can avoid the
# cycle by clearing the local variable after we return it.
future = None
future.set_result(result) future.set_result(result)
return future return future
return wrapper return wrapper
@ -252,8 +244,8 @@ class Return(Exception):
class YieldPoint(object): class YieldPoint(object):
"""Base class for objects that may be yielded from the generator. """Base class for objects that may be yielded from the generator.
Applications do not normally need to use this class, but it may be .. deprecated:: 4.0
subclassed to provide additional yielding behavior. Use `Futures <.Future>` instead.
""" """
def start(self, runner): def start(self, runner):
"""Called by the runner after the generator has yielded. """Called by the runner after the generator has yielded.
@ -289,6 +281,9 @@ class Callback(YieldPoint):
The callback may be called with zero or one arguments; if an argument The callback may be called with zero or one arguments; if an argument
is given it will be returned by `Wait`. is given it will be returned by `Wait`.
.. deprecated:: 4.0
Use `Futures <.Future>` instead.
""" """
def __init__(self, key): def __init__(self, key):
self.key = key self.key = key
@ -305,7 +300,11 @@ class Callback(YieldPoint):
class Wait(YieldPoint): class Wait(YieldPoint):
"""Returns the argument passed to the result of a previous `Callback`.""" """Returns the argument passed to the result of a previous `Callback`.
.. deprecated:: 4.0
Use `Futures <.Future>` instead.
"""
def __init__(self, key): def __init__(self, key):
self.key = key self.key = key
@ -326,6 +325,9 @@ class WaitAll(YieldPoint):
a list of results in the same order. a list of results in the same order.
`WaitAll` is equivalent to yielding a list of `Wait` objects. `WaitAll` is equivalent to yielding a list of `Wait` objects.
.. deprecated:: 4.0
Use `Futures <.Future>` instead.
""" """
def __init__(self, keys): def __init__(self, keys):
self.keys = keys self.keys = keys
@ -341,20 +343,12 @@ class WaitAll(YieldPoint):
def Task(func, *args, **kwargs): def Task(func, *args, **kwargs):
"""Runs a single asynchronous operation. """Adapts a callback-based asynchronous function for use in coroutines.
Takes a function (and optional additional arguments) and runs it with Takes a function (and optional additional arguments) and runs it with
those arguments plus a ``callback`` keyword argument. The argument passed those arguments plus a ``callback`` keyword argument. The argument passed
to the callback is returned as the result of the yield expression. to the callback is returned as the result of the yield expression.
A `Task` is equivalent to a `Callback`/`Wait` pair (with a unique
key generated automatically)::
result = yield gen.Task(func, args)
func(args, callback=(yield gen.Callback(key)))
result = yield gen.Wait(key)
.. versionchanged:: 4.0 .. versionchanged:: 4.0
``gen.Task`` is now a function that returns a `.Future`, instead of ``gen.Task`` is now a function that returns a `.Future`, instead of
a subclass of `YieldPoint`. It still behaves the same way when a subclass of `YieldPoint`. It still behaves the same way when

View File

@ -21,6 +21,8 @@
from __future__ import absolute_import, division, print_function, with_statement from __future__ import absolute_import, division, print_function, with_statement
import re
from tornado.concurrent import Future from tornado.concurrent import Future
from tornado.escape import native_str, utf8 from tornado.escape import native_str, utf8
from tornado import gen from tornado import gen
@ -56,7 +58,7 @@ class HTTP1ConnectionParameters(object):
""" """
def __init__(self, no_keep_alive=False, chunk_size=None, def __init__(self, no_keep_alive=False, chunk_size=None,
max_header_size=None, header_timeout=None, max_body_size=None, max_header_size=None, header_timeout=None, max_body_size=None,
body_timeout=None, use_gzip=False): body_timeout=None, decompress=False):
""" """
:arg bool no_keep_alive: If true, always close the connection after :arg bool no_keep_alive: If true, always close the connection after
one request. one request.
@ -65,7 +67,8 @@ class HTTP1ConnectionParameters(object):
:arg float header_timeout: how long to wait for all headers (seconds) :arg float header_timeout: how long to wait for all headers (seconds)
:arg int max_body_size: maximum amount of data for body :arg int max_body_size: maximum amount of data for body
:arg float body_timeout: how long to wait while reading body (seconds) :arg float body_timeout: how long to wait while reading body (seconds)
:arg bool use_gzip: if true, decode incoming ``Content-Encoding: gzip`` :arg bool decompress: if true, decode incoming
``Content-Encoding: gzip``
""" """
self.no_keep_alive = no_keep_alive self.no_keep_alive = no_keep_alive
self.chunk_size = chunk_size or 65536 self.chunk_size = chunk_size or 65536
@ -73,7 +76,7 @@ class HTTP1ConnectionParameters(object):
self.header_timeout = header_timeout self.header_timeout = header_timeout
self.max_body_size = max_body_size self.max_body_size = max_body_size
self.body_timeout = body_timeout self.body_timeout = body_timeout
self.use_gzip = use_gzip self.decompress = decompress
class HTTP1Connection(httputil.HTTPConnection): class HTTP1Connection(httputil.HTTPConnection):
@ -141,7 +144,7 @@ class HTTP1Connection(httputil.HTTPConnection):
Returns a `.Future` that resolves to None after the full response has Returns a `.Future` that resolves to None after the full response has
been read. been read.
""" """
if self.params.use_gzip: if self.params.decompress:
delegate = _GzipMessageDelegate(delegate, self.params.chunk_size) delegate = _GzipMessageDelegate(delegate, self.params.chunk_size)
return self._read_message(delegate) return self._read_message(delegate)
@ -190,8 +193,17 @@ class HTTP1Connection(httputil.HTTPConnection):
skip_body = True skip_body = True
code = start_line.code code = start_line.code
if code == 304: if code == 304:
# 304 responses may include the content-length header
# but do not actually have a body.
# http://tools.ietf.org/html/rfc7230#section-3.3
skip_body = True skip_body = True
if code >= 100 and code < 200: if code >= 100 and code < 200:
# 1xx responses should never indicate the presence of
# a body.
if ('Content-Length' in headers or
'Transfer-Encoding' in headers):
raise httputil.HTTPInputError(
"Response code %d cannot have body" % code)
# TODO: client delegates will get headers_received twice # TODO: client delegates will get headers_received twice
# in the case of a 100-continue. Document or change? # in the case of a 100-continue. Document or change?
yield self._read_message(delegate) yield self._read_message(delegate)
@ -200,7 +212,8 @@ class HTTP1Connection(httputil.HTTPConnection):
not self._write_finished): not self._write_finished):
self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n") self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n")
if not skip_body: if not skip_body:
body_future = self._read_body(headers, delegate) body_future = self._read_body(
start_line.code if self.is_client else 0, headers, delegate)
if body_future is not None: if body_future is not None:
if self._body_timeout is None: if self._body_timeout is None:
yield body_future yield body_future
@ -293,6 +306,8 @@ class HTTP1Connection(httputil.HTTPConnection):
self._clear_callbacks() self._clear_callbacks()
stream = self.stream stream = self.stream
self.stream = None self.stream = None
if not self._finish_future.done():
self._finish_future.set_result(None)
return stream return stream
def set_body_timeout(self, timeout): def set_body_timeout(self, timeout):
@ -454,6 +469,7 @@ class HTTP1Connection(httputil.HTTPConnection):
if start_line.version == "HTTP/1.1": if start_line.version == "HTTP/1.1":
return connection_header != "close" return connection_header != "close"
elif ("Content-Length" in headers elif ("Content-Length" in headers
or headers.get("Transfer-Encoding", "").lower() == "chunked"
or start_line.method in ("HEAD", "GET")): or start_line.method in ("HEAD", "GET")):
return connection_header == "keep-alive" return connection_header == "keep-alive"
return False return False
@ -470,7 +486,11 @@ class HTTP1Connection(httputil.HTTPConnection):
self._finish_future.set_result(None) self._finish_future.set_result(None)
def _parse_headers(self, data): def _parse_headers(self, data):
data = native_str(data.decode('latin1')) # The lstrip removes newlines that some implementations sometimes
# insert between messages of a reused connection. Per RFC 7230,
# we SHOULD ignore at least one empty line before the request.
# http://tools.ietf.org/html/rfc7230#section-3.5
data = native_str(data.decode('latin1')).lstrip("\r\n")
eol = data.find("\r\n") eol = data.find("\r\n")
start_line = data[:eol] start_line = data[:eol]
try: try:
@ -481,12 +501,36 @@ class HTTP1Connection(httputil.HTTPConnection):
data[eol:100]) data[eol:100])
return start_line, headers return start_line, headers
def _read_body(self, headers, delegate): def _read_body(self, code, headers, delegate):
content_length = headers.get("Content-Length") if "Content-Length" in headers:
if content_length: if "," in headers["Content-Length"]:
content_length = int(content_length) # Proxies sometimes cause Content-Length headers to get
# duplicated. If all the values are identical then we can
# use them but if they differ it's an error.
pieces = re.split(r',\s*', headers["Content-Length"])
if any(i != pieces[0] for i in pieces):
raise httputil.HTTPInputError(
"Multiple unequal Content-Lengths: %r" %
headers["Content-Length"])
headers["Content-Length"] = pieces[0]
content_length = int(headers["Content-Length"])
if content_length > self._max_body_size: if content_length > self._max_body_size:
raise httputil.HTTPInputError("Content-Length too long") raise httputil.HTTPInputError("Content-Length too long")
else:
content_length = None
if code == 204:
# This response code is not allowed to have a non-empty body,
# and has an implicit length of zero instead of read-until-close.
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3
if ("Transfer-Encoding" in headers or
content_length not in (None, 0)):
raise httputil.HTTPInputError(
"Response with code %d should not have body" % code)
content_length = 0
if content_length is not None:
return self._read_fixed_body(content_length, delegate) return self._read_fixed_body(content_length, delegate)
if headers.get("Transfer-Encoding") == "chunked": if headers.get("Transfer-Encoding") == "chunked":
return self._read_chunked_body(delegate) return self._read_chunked_body(delegate)
@ -581,6 +625,9 @@ class _GzipMessageDelegate(httputil.HTTPMessageDelegate):
self._delegate.data_received(tail) self._delegate.data_received(tail)
return self._delegate.finish() return self._delegate.finish()
def on_connection_close(self):
return self._delegate.on_connection_close()
class HTTP1ServerConnection(object): class HTTP1ServerConnection(object):
"""An HTTP/1.x server.""" """An HTTP/1.x server."""

View File

@ -33,6 +33,9 @@ information, see
http://curl.haxx.se/libcurl/c/curl_easy_setopt.html#CURLOPTCONNECTTIMEOUTMS http://curl.haxx.se/libcurl/c/curl_easy_setopt.html#CURLOPTCONNECTTIMEOUTMS
and comments in curl_httpclient.py). and comments in curl_httpclient.py).
To select ``curl_httpclient``, call `AsyncHTTPClient.configure` at startup::
AsyncHTTPClient.configure("tornado.curl_httpclient.CurlAsyncHTTPClient")
""" """
from __future__ import absolute_import, division, print_function, with_statement from __future__ import absolute_import, division, print_function, with_statement
@ -60,7 +63,12 @@ class HTTPClient(object):
response = http_client.fetch("http://www.google.com/") response = http_client.fetch("http://www.google.com/")
print response.body print response.body
except httpclient.HTTPError as e: except httpclient.HTTPError as e:
print "Error:", e # HTTPError is raised for non-200 responses; the response
# can be found in e.response.
print("Error: " + str(e))
except Exception as e:
# Other errors are possible, such as IOError.
print("Error: " + str(e))
http_client.close() http_client.close()
""" """
def __init__(self, async_client_class=None, **kwargs): def __init__(self, async_client_class=None, **kwargs):
@ -279,7 +287,7 @@ class HTTPRequest(object):
request_timeout=20.0, request_timeout=20.0,
follow_redirects=True, follow_redirects=True,
max_redirects=5, max_redirects=5,
use_gzip=True, decompress_response=True,
proxy_password='', proxy_password='',
allow_nonstandard_methods=False, allow_nonstandard_methods=False,
validate_cert=True) validate_cert=True)
@ -296,7 +304,7 @@ class HTTPRequest(object):
validate_cert=None, ca_certs=None, validate_cert=None, ca_certs=None,
allow_ipv6=None, allow_ipv6=None,
client_key=None, client_cert=None, body_producer=None, client_key=None, client_cert=None, body_producer=None,
expect_100_continue=False): expect_100_continue=False, decompress_response=None):
r"""All parameters except ``url`` are optional. r"""All parameters except ``url`` are optional.
:arg string url: URL to fetch :arg string url: URL to fetch
@ -330,7 +338,11 @@ class HTTPRequest(object):
or return the 3xx response? or return the 3xx response?
:arg int max_redirects: Limit for ``follow_redirects`` :arg int max_redirects: Limit for ``follow_redirects``
:arg string user_agent: String to send as ``User-Agent`` header :arg string user_agent: String to send as ``User-Agent`` header
:arg bool use_gzip: Request gzip encoding from the server :arg bool decompress_response: Request a compressed response from
the server and decompress it after downloading. Default is True.
New in Tornado 4.0.
:arg bool use_gzip: Deprecated alias for ``decompress_response``
since Tornado 4.0.
:arg string network_interface: Network interface to use for request. :arg string network_interface: Network interface to use for request.
``curl_httpclient`` only; see note below. ``curl_httpclient`` only; see note below.
:arg callable streaming_callback: If set, ``streaming_callback`` will :arg callable streaming_callback: If set, ``streaming_callback`` will
@ -373,7 +385,6 @@ class HTTPRequest(object):
before sending the request body. Only supported with before sending the request body. Only supported with
simple_httpclient. simple_httpclient.
.. note:: .. note::
When using ``curl_httpclient`` certain options may be When using ``curl_httpclient`` certain options may be
@ -414,7 +425,10 @@ class HTTPRequest(object):
self.follow_redirects = follow_redirects self.follow_redirects = follow_redirects
self.max_redirects = max_redirects self.max_redirects = max_redirects
self.user_agent = user_agent self.user_agent = user_agent
self.use_gzip = use_gzip if decompress_response is not None:
self.decompress_response = decompress_response
else:
self.decompress_response = use_gzip
self.network_interface = network_interface self.network_interface = network_interface
self.streaming_callback = streaming_callback self.streaming_callback = streaming_callback
self.header_callback = header_callback self.header_callback = header_callback

View File

@ -50,6 +50,7 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
import tornado.httpserver import tornado.httpserver
import tornado.ioloop import tornado.ioloop
from tornado import httputil
def handle_request(request): def handle_request(request):
message = "You requested %s\n" % request.uri message = "You requested %s\n" % request.uri
@ -129,13 +130,14 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
way other than `tornado.netutil.bind_sockets`. way other than `tornado.netutil.bind_sockets`.
.. versionchanged:: 4.0 .. versionchanged:: 4.0
Added ``gzip``, ``chunk_size``, ``max_header_size``, Added ``decompress_request``, ``chunk_size``, ``max_header_size``,
``idle_connection_timeout``, ``body_timeout``, ``max_body_size`` ``idle_connection_timeout``, ``body_timeout``, ``max_body_size``
arguments. Added support for `.HTTPServerConnectionDelegate` arguments. Added support for `.HTTPServerConnectionDelegate`
instances as ``request_callback``. instances as ``request_callback``.
""" """
def __init__(self, request_callback, no_keep_alive=False, io_loop=None, def __init__(self, request_callback, no_keep_alive=False, io_loop=None,
xheaders=False, ssl_options=None, protocol=None, gzip=False, xheaders=False, ssl_options=None, protocol=None,
decompress_request=False,
chunk_size=None, max_header_size=None, chunk_size=None, max_header_size=None,
idle_connection_timeout=None, body_timeout=None, idle_connection_timeout=None, body_timeout=None,
max_body_size=None, max_buffer_size=None): max_body_size=None, max_buffer_size=None):
@ -144,7 +146,7 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
self.xheaders = xheaders self.xheaders = xheaders
self.protocol = protocol self.protocol = protocol
self.conn_params = HTTP1ConnectionParameters( self.conn_params = HTTP1ConnectionParameters(
use_gzip=gzip, decompress=decompress_request,
chunk_size=chunk_size, chunk_size=chunk_size,
max_header_size=max_header_size, max_header_size=max_header_size,
header_timeout=idle_connection_timeout or 3600, header_timeout=idle_connection_timeout or 3600,

View File

@ -33,7 +33,7 @@ import time
from tornado.escape import native_str, parse_qs_bytes, utf8 from tornado.escape import native_str, parse_qs_bytes, utf8
from tornado.log import gen_log from tornado.log import gen_log
from tornado.util import ObjectDict, bytes_type from tornado.util import ObjectDict
try: try:
import Cookie # py2 import Cookie # py2
@ -335,7 +335,7 @@ class HTTPServerRequest(object):
# set remote IP and protocol # set remote IP and protocol
context = getattr(connection, 'context', None) context = getattr(connection, 'context', None)
self.remote_ip = getattr(context, 'remote_ip') self.remote_ip = getattr(context, 'remote_ip', None)
self.protocol = getattr(context, 'protocol', "http") self.protocol = getattr(context, 'protocol', "http")
self.host = host or self.headers.get("Host") or "127.0.0.1" self.host = host or self.headers.get("Host") or "127.0.0.1"
@ -379,7 +379,7 @@ class HTTPServerRequest(object):
Use ``request.connection`` and the `.HTTPConnection` methods Use ``request.connection`` and the `.HTTPConnection` methods
to write the response. to write the response.
""" """
assert isinstance(chunk, bytes_type) assert isinstance(chunk, bytes)
self.connection.write(chunk, callback=callback) self.connection.write(chunk, callback=callback)
def finish(self): def finish(self):
@ -562,11 +562,18 @@ class HTTPConnection(object):
def url_concat(url, args): def url_concat(url, args):
"""Concatenate url and argument dictionary regardless of whether """Concatenate url and arguments regardless of whether
url has existing query parameters. url has existing query parameters.
``args`` may be either a dictionary or a list of key-value pairs
(the latter allows for multiple values with the same key.
>>> url_concat("http://example.com/foo", dict(c="d"))
'http://example.com/foo?c=d'
>>> url_concat("http://example.com/foo?a=b", dict(c="d")) >>> url_concat("http://example.com/foo?a=b", dict(c="d"))
'http://example.com/foo?a=b&c=d' 'http://example.com/foo?a=b&c=d'
>>> url_concat("http://example.com/foo?a=b", [("c", "d"), ("c", "d2")])
'http://example.com/foo?a=b&c=d&c=d2'
""" """
if not args: if not args:
return url return url
@ -803,6 +810,8 @@ def parse_response_start_line(line):
# _parseparam and _parse_header are copied and modified from python2.7's cgi.py # _parseparam and _parse_header are copied and modified from python2.7's cgi.py
# The original 2.7 version of this code did not correctly support some # The original 2.7 version of this code did not correctly support some
# combinations of semicolons and double quotes. # combinations of semicolons and double quotes.
# It has also been modified to support valueless parameters as seen in
# websocket extension negotiations.
def _parseparam(s): def _parseparam(s):
@ -836,9 +845,31 @@ def _parse_header(line):
value = value[1:-1] value = value[1:-1]
value = value.replace('\\\\', '\\').replace('\\"', '"') value = value.replace('\\\\', '\\').replace('\\"', '"')
pdict[name] = value pdict[name] = value
else:
pdict[p] = None
return key, pdict return key, pdict
def _encode_header(key, pdict):
"""Inverse of _parse_header.
>>> _encode_header('permessage-deflate',
... {'client_max_window_bits': 15, 'client_no_context_takeover': None})
'permessage-deflate; client_max_window_bits=15; client_no_context_takeover'
"""
if not pdict:
return key
out = [key]
# Sort the parameters just to make it easy to test.
for k, v in sorted(pdict.items()):
if v is None:
out.append(k)
else:
# TODO: quote if necessary.
out.append('%s=%s' % (k, v))
return '; '.join(out)
def doctests(): def doctests():
import doctest import doctest
return doctest.DocTestSuite() return doctest.DocTestSuite()

View File

@ -197,7 +197,7 @@ class IOLoop(Configurable):
An `IOLoop` automatically becomes current for its thread An `IOLoop` automatically becomes current for its thread
when it is started, but it is sometimes useful to call when it is started, but it is sometimes useful to call
`make_current` explictly before starting the `IOLoop`, `make_current` explicitly before starting the `IOLoop`,
so that code run at startup time can find the right so that code run at startup time can find the right
instance. instance.
""" """
@ -477,7 +477,7 @@ class IOLoop(Configurable):
.. versionadded:: 4.0 .. versionadded:: 4.0
""" """
self.call_at(self.time() + delay, callback, *args, **kwargs) return self.call_at(self.time() + delay, callback, *args, **kwargs)
def call_at(self, when, callback, *args, **kwargs): def call_at(self, when, callback, *args, **kwargs):
"""Runs the ``callback`` at the absolute time designated by ``when``. """Runs the ``callback`` at the absolute time designated by ``when``.
@ -493,7 +493,7 @@ class IOLoop(Configurable):
.. versionadded:: 4.0 .. versionadded:: 4.0
""" """
self.add_timeout(when, callback, *args, **kwargs) return self.add_timeout(when, callback, *args, **kwargs)
def remove_timeout(self, timeout): def remove_timeout(self, timeout):
"""Cancels a pending timeout. """Cancels a pending timeout.
@ -724,7 +724,7 @@ class PollIOLoop(IOLoop):
# #
# If someone has already set a wakeup fd, we don't want to # If someone has already set a wakeup fd, we don't want to
# disturb it. This is an issue for twisted, which does its # disturb it. This is an issue for twisted, which does its
# SIGCHILD processing in response to its own wakeup fd being # SIGCHLD processing in response to its own wakeup fd being
# written to. As long as the wakeup fd is registered on the IOLoop, # written to. As long as the wakeup fd is registered on the IOLoop,
# the loop will still wake up and everything should work. # the loop will still wake up and everything should work.
old_wakeup_fd = None old_wakeup_fd = None
@ -754,17 +754,18 @@ class PollIOLoop(IOLoop):
# Do not run anything until we have determined which ones # Do not run anything until we have determined which ones
# are ready, so timeouts that call add_timeout cannot # are ready, so timeouts that call add_timeout cannot
# schedule anything in this iteration. # schedule anything in this iteration.
due_timeouts = []
if self._timeouts: if self._timeouts:
now = self.time() now = self.time()
while self._timeouts: while self._timeouts:
if self._timeouts[0].callback is None: if self._timeouts[0].callback is None:
# the timeout was cancelled # The timeout was cancelled. Note that the
# cancellation check is repeated below for timeouts
# that are cancelled by another timeout or callback.
heapq.heappop(self._timeouts) heapq.heappop(self._timeouts)
self._cancellations -= 1 self._cancellations -= 1
elif self._timeouts[0].deadline <= now: elif self._timeouts[0].deadline <= now:
timeout = heapq.heappop(self._timeouts) due_timeouts.append(heapq.heappop(self._timeouts))
callbacks.append(timeout.callback)
del timeout
else: else:
break break
if (self._cancellations > 512 if (self._cancellations > 512
@ -778,9 +779,12 @@ class PollIOLoop(IOLoop):
for callback in callbacks: for callback in callbacks:
self._run_callback(callback) self._run_callback(callback)
for timeout in due_timeouts:
if timeout.callback is not None:
self._run_callback(timeout.callback)
# Closures may be holding on to a lot of memory, so allow # Closures may be holding on to a lot of memory, so allow
# them to be freed before we go into our poll wait. # them to be freed before we go into our poll wait.
callbacks = callback = None callbacks = callback = due_timeouts = timeout = None
if self._callbacks: if self._callbacks:
# If any callbacks or timeouts called add_callback, # If any callbacks or timeouts called add_callback,
@ -969,10 +973,11 @@ class PeriodicCallback(object):
if not self._running: if not self._running:
return return
try: try:
self.callback() return self.callback()
except Exception: except Exception:
self.io_loop.handle_callback_exception(self.callback) self.io_loop.handle_callback_exception(self.callback)
self._schedule_next() finally:
self._schedule_next()
def _schedule_next(self): def _schedule_next(self):
if self._running: if self._running:

View File

@ -39,7 +39,7 @@ from tornado import ioloop
from tornado.log import gen_log, app_log from tornado.log import gen_log, app_log
from tornado.netutil import ssl_wrap_socket, ssl_match_hostname, SSLCertificateError from tornado.netutil import ssl_wrap_socket, ssl_match_hostname, SSLCertificateError
from tornado import stack_context from tornado import stack_context
from tornado.util import bytes_type, errno_from_exception from tornado.util import errno_from_exception
try: try:
from tornado.platform.posix import _set_nonblocking from tornado.platform.posix import _set_nonblocking
@ -324,7 +324,7 @@ class BaseIOStream(object):
.. versionchanged:: 4.0 .. versionchanged:: 4.0
Now returns a `.Future` if no callback is given. Now returns a `.Future` if no callback is given.
""" """
assert isinstance(data, bytes_type) assert isinstance(data, bytes)
self._check_closed() self._check_closed()
# We use bool(_write_buffer) as a proxy for write_buffer_size>0, # We use bool(_write_buffer) as a proxy for write_buffer_size>0,
# so never put empty strings in the buffer. # so never put empty strings in the buffer.
@ -505,7 +505,7 @@ class BaseIOStream(object):
def wrapper(): def wrapper():
self._pending_callbacks -= 1 self._pending_callbacks -= 1
try: try:
callback(*args) return callback(*args)
except Exception: except Exception:
app_log.error("Uncaught exception, closing connection.", app_log.error("Uncaught exception, closing connection.",
exc_info=True) exc_info=True)
@ -517,7 +517,8 @@ class BaseIOStream(object):
# Re-raise the exception so that IOLoop.handle_callback_exception # Re-raise the exception so that IOLoop.handle_callback_exception
# can see it and log the error # can see it and log the error
raise raise
self._maybe_add_error_listener() finally:
self._maybe_add_error_listener()
# We schedule callbacks to be run on the next IOLoop iteration # We schedule callbacks to be run on the next IOLoop iteration
# rather than running them directly for several reasons: # rather than running them directly for several reasons:
# * Prevents unbounded stack growth when a callback calls an # * Prevents unbounded stack growth when a callback calls an
@ -553,7 +554,7 @@ class BaseIOStream(object):
# Pretend to have a pending callback so that an EOF in # Pretend to have a pending callback so that an EOF in
# _read_to_buffer doesn't trigger an immediate close # _read_to_buffer doesn't trigger an immediate close
# callback. At the end of this method we'll either # callback. At the end of this method we'll either
# estabilsh a real pending callback via # establish a real pending callback via
# _read_from_buffer or run the close callback. # _read_from_buffer or run the close callback.
# #
# We need two try statements here so that # We need two try statements here so that
@ -992,6 +993,11 @@ class IOStream(BaseIOStream):
""" """
self._connecting = True self._connecting = True
if callback is not None:
self._connect_callback = stack_context.wrap(callback)
future = None
else:
future = self._connect_future = TracebackFuture()
try: try:
self.socket.connect(address) self.socket.connect(address)
except socket.error as e: except socket.error as e:
@ -1007,12 +1013,7 @@ class IOStream(BaseIOStream):
gen_log.warning("Connect error on fd %s: %s", gen_log.warning("Connect error on fd %s: %s",
self.socket.fileno(), e) self.socket.fileno(), e)
self.close(exc_info=True) self.close(exc_info=True)
return return future
if callback is not None:
self._connect_callback = stack_context.wrap(callback)
future = None
else:
future = self._connect_future = TracebackFuture()
self._add_io_state(self.io_loop.WRITE) self._add_io_state(self.io_loop.WRITE)
return future return future
@ -1184,8 +1185,14 @@ class SSLIOStream(IOStream):
return self.close(exc_info=True) return self.close(exc_info=True)
raise raise
except socket.error as err: except socket.error as err:
if err.args[0] in _ERRNO_CONNRESET: # Some port scans (e.g. nmap in -sT mode) have been known
# to cause do_handshake to raise EBADF, so make that error
# quiet as well.
# https://groups.google.com/forum/?fromgroups#!topic/python-tornado/ApucKJat1_0
if (err.args[0] in _ERRNO_CONNRESET or
err.args[0] == errno.EBADF):
return self.close(exc_info=True) return self.close(exc_info=True)
raise
except AttributeError: except AttributeError:
# On Linux, if the connection was reset before the call to # On Linux, if the connection was reset before the call to
# wrap_socket, do_handshake will fail with an # wrap_socket, do_handshake will fail with an

View File

@ -179,7 +179,7 @@ class LogFormatter(logging.Formatter):
def enable_pretty_logging(options=None, logger=None): def enable_pretty_logging(options=None, logger=None):
"""Turns on formatted logging output as configured. """Turns on formatted logging output as configured.
This is called automaticaly by `tornado.options.parse_command_line` This is called automatically by `tornado.options.parse_command_line`
and `tornado.options.parse_config_file`. and `tornado.options.parse_config_file`.
""" """
if options is None: if options is None:

View File

@ -35,6 +35,11 @@ except ImportError:
# ssl is not available on Google App Engine # ssl is not available on Google App Engine
ssl = None ssl = None
try:
xrange # py2
except NameError:
xrange = range # py3
if hasattr(ssl, 'match_hostname') and hasattr(ssl, 'CertificateError'): # python 3.2+ if hasattr(ssl, 'match_hostname') and hasattr(ssl, 'CertificateError'): # python 3.2+
ssl_match_hostname = ssl.match_hostname ssl_match_hostname = ssl.match_hostname
SSLCertificateError = ssl.CertificateError SSLCertificateError = ssl.CertificateError
@ -60,8 +65,11 @@ _ERRNO_WOULDBLOCK = (errno.EWOULDBLOCK, errno.EAGAIN)
if hasattr(errno, "WSAEWOULDBLOCK"): if hasattr(errno, "WSAEWOULDBLOCK"):
_ERRNO_WOULDBLOCK += (errno.WSAEWOULDBLOCK,) _ERRNO_WOULDBLOCK += (errno.WSAEWOULDBLOCK,)
# Default backlog used when calling sock.listen()
_DEFAULT_BACKLOG = 128
def bind_sockets(port, address=None, family=socket.AF_UNSPEC, backlog=128, flags=None): def bind_sockets(port, address=None, family=socket.AF_UNSPEC,
backlog=_DEFAULT_BACKLOG, flags=None):
"""Creates listening sockets bound to the given port and address. """Creates listening sockets bound to the given port and address.
Returns a list of socket objects (multiple sockets are returned if Returns a list of socket objects (multiple sockets are returned if
@ -141,7 +149,7 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC, backlog=128, flags
return sockets return sockets
if hasattr(socket, 'AF_UNIX'): if hasattr(socket, 'AF_UNIX'):
def bind_unix_socket(file, mode=0o600, backlog=128): def bind_unix_socket(file, mode=0o600, backlog=_DEFAULT_BACKLOG):
"""Creates a listening unix socket. """Creates a listening unix socket.
If a socket with the given name already exists, it will be deleted. If a socket with the given name already exists, it will be deleted.
@ -184,7 +192,18 @@ def add_accept_handler(sock, callback, io_loop=None):
io_loop = IOLoop.current() io_loop = IOLoop.current()
def accept_handler(fd, events): def accept_handler(fd, events):
while True: # More connections may come in while we're handling callbacks;
# to prevent starvation of other tasks we must limit the number
# of connections we accept at a time. Ideally we would accept
# up to the number of connections that were waiting when we
# entered this method, but this information is not available
# (and rearranging this method to call accept() as many times
# as possible before running any callbacks would have adverse
# effects on load balancing in multiprocess configurations).
# Instead, we use the (default) listen backlog as a rough
# heuristic for the number of connections we can reasonably
# accept at once.
for i in xrange(_DEFAULT_BACKLOG):
try: try:
connection, address = sock.accept() connection, address = sock.accept()
except socket.error as e: except socket.error as e:

View File

@ -79,7 +79,7 @@ import sys
import os import os
import textwrap import textwrap
from tornado.escape import _unicode from tornado.escape import _unicode, native_str
from tornado.log import define_logging_options from tornado.log import define_logging_options
from tornado import stack_context from tornado import stack_context
from tornado.util import basestring_type, exec_in from tornado.util import basestring_type, exec_in
@ -271,10 +271,14 @@ class OptionParser(object):
If ``final`` is ``False``, parse callbacks will not be run. If ``final`` is ``False``, parse callbacks will not be run.
This is useful for applications that wish to combine configurations This is useful for applications that wish to combine configurations
from multiple sources. from multiple sources.
.. versionchanged:: 4.1
Config files are now always interpreted as utf-8 instead of
the system default encoding.
""" """
config = {} config = {}
with open(path) as f: with open(path, 'rb') as f:
exec_in(f.read(), config, config) exec_in(native_str(f.read()), config, config)
for name in config: for name in config:
if name in self._options: if name in self._options:
self._options[name].set(config[name]) self._options[name].set(config[name])

View File

@ -10,12 +10,10 @@ unfinished callbacks on the event loop that fail when it resumes)
""" """
from __future__ import absolute_import, division, print_function, with_statement from __future__ import absolute_import, division, print_function, with_statement
import datetime
import functools import functools
from tornado.ioloop import IOLoop from tornado.ioloop import IOLoop
from tornado import stack_context from tornado import stack_context
from tornado.util import timedelta_to_seconds
try: try:
# Import the real asyncio module for py33+ first. Older versions of the # Import the real asyncio module for py33+ first. Older versions of the

View File

@ -141,7 +141,7 @@ class TornadoDelayedCall(object):
class TornadoReactor(PosixReactorBase): class TornadoReactor(PosixReactorBase):
"""Twisted reactor built on the Tornado IOLoop. """Twisted reactor built on the Tornado IOLoop.
Since it is intented to be used in applications where the top-level Since it is intended to be used in applications where the top-level
event loop is ``io_loop.start()`` rather than ``reactor.run()``, event loop is ``io_loop.start()`` rather than ``reactor.run()``,
it is implemented a little differently than other Twisted reactors. it is implemented a little differently than other Twisted reactors.
We override `mainLoop` instead of `doIteration` and must implement We override `mainLoop` instead of `doIteration` and must implement

View File

@ -39,7 +39,7 @@ from tornado.util import errno_from_exception
try: try:
import multiprocessing import multiprocessing
except ImportError: except ImportError:
# Multiprocessing is not availble on Google App Engine. # Multiprocessing is not available on Google App Engine.
multiprocessing = None multiprocessing = None
try: try:
@ -240,7 +240,7 @@ class Subprocess(object):
The callback takes one argument, the return code of the process. The callback takes one argument, the return code of the process.
This method uses a ``SIGCHILD`` handler, which is a global setting This method uses a ``SIGCHLD`` handler, which is a global setting
and may conflict if you have other libraries trying to handle the and may conflict if you have other libraries trying to handle the
same signal. If you are using more than one ``IOLoop`` it may same signal. If you are using more than one ``IOLoop`` it may
be necessary to call `Subprocess.initialize` first to designate be necessary to call `Subprocess.initialize` first to designate
@ -257,7 +257,7 @@ class Subprocess(object):
@classmethod @classmethod
def initialize(cls, io_loop=None): def initialize(cls, io_loop=None):
"""Initializes the ``SIGCHILD`` handler. """Initializes the ``SIGCHLD`` handler.
The signal handler is run on an `.IOLoop` to avoid locking issues. The signal handler is run on an `.IOLoop` to avoid locking issues.
Note that the `.IOLoop` used for signal handling need not be the Note that the `.IOLoop` used for signal handling need not be the
@ -275,7 +275,7 @@ class Subprocess(object):
@classmethod @classmethod
def uninitialize(cls): def uninitialize(cls):
"""Removes the ``SIGCHILD`` handler.""" """Removes the ``SIGCHLD`` handler."""
if not cls._initialized: if not cls._initialized:
return return
signal.signal(signal.SIGCHLD, cls._old_sigchld) signal.signal(signal.SIGCHLD, cls._old_sigchld)

View File

@ -1,77 +0,0 @@
'''
Created on May 31, 2014
@author: Fenriswolf
'''
import logging
import inspect
import re
from collections import OrderedDict
from tornado.web import Application, RequestHandler, HTTPError
app_routing_log = logging.getLogger("tornado.application.routing")
class RoutingApplication(Application):
def __init__(self, handlers=None, default_host="", transforms=None, wsgi=False, **settings):
Application.__init__(self, handlers, default_host, transforms, wsgi, **settings)
self.handler_map = OrderedDict()
def expose(self, rule='', methods = ['GET'], kwargs=None, name=None):
"""
A decorator that is used to register a given URL rule.
"""
def decorator(func, *args, **kwargs):
func_name = func.__name__
frm = inspect.stack()[1]
class_name = frm[3]
module_name = frm[0].f_back.f_globals["__name__"]
full_class_name = module_name + '.' + class_name
for method in methods:
func_rule = rule if rule else None
if not func_rule:
if func_name == 'index':
func_rule = class_name
else:
func_rule = class_name + '/' + func_name
func_rule = r'/%s(.*)(/?)' % func_rule
if full_class_name not in self.handler_map:
self.handler_map.setdefault(full_class_name, {})[method] = [(func_rule, func_name)]
else:
self.handler_map[full_class_name][method] += [(func_rule, func_name)]
app_routing_log.info("register %s %s to %s.%s" % (method, func_rule, full_class_name, func_name))
return func
return decorator
def setRouteHandlers(self):
handlers = [(rule[0], full_class_name)
for full_class_name, methods in self.handler_map.items()
for rules in methods.values()
for rule in rules]
self.add_handlers(".*$", handlers)
class RequestRoutingHandler(RequestHandler):
def _get_func_name(self):
full_class_name = self.__module__ + '.' + self.__class__.__name__
rules = self.application.handler_map.get(full_class_name, {}).get(self.request.method, [])
for rule, func_name in rules:
if not rule or not func_name:
continue
match = re.match(rule, self.request.path)
if match:
return func_name
raise HTTPError(404, "")
def _execute_method(self):
if not self._finished:
func_name = self._get_func_name()
method = getattr(self, func_name)
self._when_complete(method(*self.path_args, **self.path_kwargs),
self._execute_finish)

View File

@ -19,11 +19,8 @@ import functools
import re import re
import socket import socket
import sys import sys
from io import BytesIO
try:
from io import BytesIO # python 3
except ImportError:
from cStringIO import StringIO as BytesIO # python 2
try: try:
import urlparse # py2 import urlparse # py2
@ -37,7 +34,7 @@ except ImportError:
ssl = None ssl = None
try: try:
import certifi import lib.certifi
except ImportError: except ImportError:
certifi = None certifi = None
@ -222,6 +219,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
stack_context.wrap(self._on_timeout)) stack_context.wrap(self._on_timeout))
self.tcp_client.connect(host, port, af=af, self.tcp_client.connect(host, port, af=af,
ssl_options=ssl_options, ssl_options=ssl_options,
max_buffer_size=self.max_buffer_size,
callback=self._on_connect) callback=self._on_connect)
def _get_ssl_options(self, scheme): def _get_ssl_options(self, scheme):
@ -277,7 +275,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
stream.close() stream.close()
return return
self.stream = stream self.stream = stream
self.stream.set_close_callback(self._on_close) self.stream.set_close_callback(self.on_connection_close)
self._remove_timeout() self._remove_timeout()
if self.final_callback is None: if self.final_callback is None:
return return
@ -316,18 +314,18 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
if self.request.user_agent: if self.request.user_agent:
self.request.headers["User-Agent"] = self.request.user_agent self.request.headers["User-Agent"] = self.request.user_agent
if not self.request.allow_nonstandard_methods: if not self.request.allow_nonstandard_methods:
if self.request.method in ("POST", "PATCH", "PUT"): # Some HTTP methods nearly always have bodies while others
if (self.request.body is None and # almost never do. Fail in this case unless the user has
self.request.body_producer is None): # opted out of sanity checks with allow_nonstandard_methods.
raise AssertionError( body_expected = self.request.method in ("POST", "PATCH", "PUT")
'Body must not be empty for "%s" request' body_present = (self.request.body is not None or
% self.request.method) self.request.body_producer is not None)
else: if ((body_expected and not body_present) or
if (self.request.body is not None or (body_present and not body_expected)):
self.request.body_producer is not None): raise ValueError(
raise AssertionError( 'Body must %sbe None for method %s (unelss '
'Body must be empty for "%s" request' 'allow_nonstandard_methods is true)' %
% self.request.method) ('not ' if body_expected else '', self.request.method))
if self.request.expect_100_continue: if self.request.expect_100_continue:
self.request.headers["Expect"] = "100-continue" self.request.headers["Expect"] = "100-continue"
if self.request.body is not None: if self.request.body is not None:
@ -338,7 +336,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
if (self.request.method == "POST" and if (self.request.method == "POST" and
"Content-Type" not in self.request.headers): "Content-Type" not in self.request.headers):
self.request.headers["Content-Type"] = "application/x-www-form-urlencoded" self.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
if self.request.use_gzip: if self.request.decompress_response:
self.request.headers["Accept-Encoding"] = "gzip" self.request.headers["Accept-Encoding"] = "gzip"
req_path = ((self.parsed.path or '/') + req_path = ((self.parsed.path or '/') +
(('?' + self.parsed.query) if self.parsed.query else '')) (('?' + self.parsed.query) if self.parsed.query else ''))
@ -348,7 +346,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
HTTP1ConnectionParameters( HTTP1ConnectionParameters(
no_keep_alive=True, no_keep_alive=True,
max_header_size=self.max_header_size, max_header_size=self.max_header_size,
use_gzip=self.request.use_gzip), decompress=self.request.decompress_response),
self._sockaddr) self._sockaddr)
start_line = httputil.RequestStartLine(self.request.method, start_line = httputil.RequestStartLine(self.request.method,
req_path, 'HTTP/1.1') req_path, 'HTTP/1.1')
@ -418,12 +416,15 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
# pass it along, unless it's just the stream being closed. # pass it along, unless it's just the stream being closed.
return isinstance(value, StreamClosedError) return isinstance(value, StreamClosedError)
def _on_close(self): def on_connection_close(self):
if self.final_callback is not None: if self.final_callback is not None:
message = "Connection closed" message = "Connection closed"
if self.stream.error: if self.stream.error:
raise self.stream.error raise self.stream.error
raise HTTPError(599, message) try:
raise HTTPError(599, message)
except HTTPError:
self._handle_exception(*sys.exc_info())
def headers_received(self, first_line, headers): def headers_received(self, first_line, headers):
if self.request.expect_100_continue and first_line.code == 100: if self.request.expect_100_continue and first_line.code == 100:
@ -433,20 +434,6 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
self.code = first_line.code self.code = first_line.code
self.reason = first_line.reason self.reason = first_line.reason
if "Content-Length" in self.headers:
if "," in self.headers["Content-Length"]:
# Proxies sometimes cause Content-Length headers to get
# duplicated. If all the values are identical then we can
# use them but if they differ it's an error.
pieces = re.split(r',\s*', self.headers["Content-Length"])
if any(i != pieces[0] for i in pieces):
raise ValueError("Multiple unequal Content-Lengths: %r" %
self.headers["Content-Length"])
self.headers["Content-Length"] = pieces[0]
content_length = int(self.headers["Content-Length"])
else:
content_length = None
if self.request.header_callback is not None: if self.request.header_callback is not None:
# Reassemble the start line. # Reassemble the start line.
self.request.header_callback('%s %s %s\r\n' % first_line) self.request.header_callback('%s %s %s\r\n' % first_line)
@ -454,14 +441,6 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
self.request.header_callback("%s: %s\r\n" % (k, v)) self.request.header_callback("%s: %s\r\n" % (k, v))
self.request.header_callback('\r\n') self.request.header_callback('\r\n')
if 100 <= self.code < 200 or self.code == 204:
# These response codes never have bodies
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3
if ("Transfer-Encoding" in self.headers or
content_length not in (None, 0)):
raise ValueError("Response with code %d should not have body" %
self.code)
def finish(self): def finish(self):
data = b''.join(self.chunks) data = b''.join(self.chunks)
self._remove_timeout() self._remove_timeout()

View File

@ -41,13 +41,13 @@ Example usage::
sys.exit(1) sys.exit(1)
with StackContext(die_on_error): with StackContext(die_on_error):
# Any exception thrown here *or in callback and its desendents* # Any exception thrown here *or in callback and its descendants*
# will cause the process to exit instead of spinning endlessly # will cause the process to exit instead of spinning endlessly
# in the ioloop. # in the ioloop.
http_client.fetch(url, callback) http_client.fetch(url, callback)
ioloop.start() ioloop.start()
Most applications shouln't have to work with `StackContext` directly. Most applications shouldn't have to work with `StackContext` directly.
Here are a few rules of thumb for when it's necessary: Here are a few rules of thumb for when it's necessary:
* If you're writing an asynchronous library that doesn't rely on a * If you're writing an asynchronous library that doesn't rely on a

View File

@ -163,7 +163,7 @@ class TCPClient(object):
functools.partial(self._create_stream, max_buffer_size)) functools.partial(self._create_stream, max_buffer_size))
af, addr, stream = yield connector.start() af, addr, stream = yield connector.start()
# TODO: For better performance we could cache the (af, addr) # TODO: For better performance we could cache the (af, addr)
# information here and re-use it on sbusequent connections to # information here and re-use it on subsequent connections to
# the same host. (http://tools.ietf.org/html/rfc6555#section-4.2) # the same host. (http://tools.ietf.org/html/rfc6555#section-4.2)
if ssl_options is not None: if ssl_options is not None:
stream = yield stream.start_tls(False, ssl_options=ssl_options, stream = yield stream.start_tls(False, ssl_options=ssl_options,

View File

@ -199,7 +199,7 @@ import threading
from tornado import escape from tornado import escape
from tornado.log import app_log from tornado.log import app_log
from tornado.util import bytes_type, ObjectDict, exec_in, unicode_type from tornado.util import ObjectDict, exec_in, unicode_type
try: try:
from cStringIO import StringIO # py2 from cStringIO import StringIO # py2
@ -261,7 +261,7 @@ class Template(object):
"linkify": escape.linkify, "linkify": escape.linkify,
"datetime": datetime, "datetime": datetime,
"_tt_utf8": escape.utf8, # for internal use "_tt_utf8": escape.utf8, # for internal use
"_tt_string_types": (unicode_type, bytes_type), "_tt_string_types": (unicode_type, bytes),
# __name__ and __loader__ allow the traceback mechanism to find # __name__ and __loader__ allow the traceback mechanism to find
# the generated source code. # the generated source code.
"__name__": self.name.replace('.', '_'), "__name__": self.name.replace('.', '_'),

View File

@ -1,4 +1,4 @@
Test coverage is almost non-existent, but it's a start. Be sure to Test coverage is almost non-existent, but it's a start. Be sure to
set PYTHONPATH apprioriately (generally to the root directory of your set PYTHONPATH appropriately (generally to the root directory of your
tornado checkout) when running tests to make sure you're getting the tornado checkout) when running tests to make sure you're getting the
version of the tornado package that you expect. version of the tornado package that you expect.

View File

@ -8,7 +8,8 @@ from tornado.stack_context import ExceptionStackContext
from tornado.testing import AsyncHTTPTestCase from tornado.testing import AsyncHTTPTestCase
from tornado.test import httpclient_test from tornado.test import httpclient_test
from tornado.test.util import unittest from tornado.test.util import unittest
from tornado.web import Application, RequestHandler from tornado.web import Application, RequestHandler, URLSpec
try: try:
import pycurl import pycurl

View File

@ -4,8 +4,8 @@
from __future__ import absolute_import, division, print_function, with_statement from __future__ import absolute_import, division, print_function, with_statement
import tornado.escape import tornado.escape
from tornado.escape import utf8, xhtml_escape, xhtml_unescape, url_escape, url_unescape, to_unicode, json_decode, json_encode from tornado.escape import utf8, xhtml_escape, xhtml_unescape, url_escape, url_unescape, to_unicode, json_decode, json_encode, squeeze, recursive_unicode
from tornado.util import u, unicode_type, bytes_type from tornado.util import u, unicode_type
from tornado.test.util import unittest from tornado.test.util import unittest
linkify_tests = [ linkify_tests = [
@ -212,6 +212,22 @@ class EscapeTestCase(unittest.TestCase):
# convert automatically if they are utf8; on python 3 byte strings # convert automatically if they are utf8; on python 3 byte strings
# are not allowed. # are not allowed.
self.assertEqual(json_decode(json_encode(u("\u00e9"))), u("\u00e9")) self.assertEqual(json_decode(json_encode(u("\u00e9"))), u("\u00e9"))
if bytes_type is str: if bytes is str:
self.assertEqual(json_decode(json_encode(utf8(u("\u00e9")))), u("\u00e9")) self.assertEqual(json_decode(json_encode(utf8(u("\u00e9")))), u("\u00e9"))
self.assertRaises(UnicodeDecodeError, json_encode, b"\xe9") self.assertRaises(UnicodeDecodeError, json_encode, b"\xe9")
def test_squeeze(self):
self.assertEqual(squeeze(u('sequences of whitespace chars'))
, u('sequences of whitespace chars'))
def test_recursive_unicode(self):
tests = {
'dict': {b"foo": b"bar"},
'list': [b"foo", b"bar"],
'tuple': (b"foo", b"bar"),
'bytes': b"foo"
}
self.assertEqual(recursive_unicode(tests['dict']), {u("foo"): u("bar")})
self.assertEqual(recursive_unicode(tests['list']), [u("foo"), u("bar")])
self.assertEqual(recursive_unicode(tests['tuple']), (u("foo"), u("bar")))
self.assertEqual(recursive_unicode(tests['bytes']), u("foo"))

View File

@ -8,6 +8,8 @@ from contextlib import closing
import functools import functools
import sys import sys
import threading import threading
import datetime
from io import BytesIO
from tornado.escape import utf8 from tornado.escape import utf8
from tornado.httpclient import HTTPRequest, HTTPResponse, _RequestProxy, HTTPError, HTTPClient from tornado.httpclient import HTTPRequest, HTTPResponse, _RequestProxy, HTTPError, HTTPClient
@ -19,13 +21,9 @@ from tornado import netutil
from tornado.stack_context import ExceptionStackContext, NullContext from tornado.stack_context import ExceptionStackContext, NullContext
from tornado.testing import AsyncHTTPTestCase, bind_unused_port, gen_test, ExpectLog from tornado.testing import AsyncHTTPTestCase, bind_unused_port, gen_test, ExpectLog
from tornado.test.util import unittest, skipOnTravis from tornado.test.util import unittest, skipOnTravis
from tornado.util import u, bytes_type from tornado.util import u
from tornado.web import Application, RequestHandler, url from tornado.web import Application, RequestHandler, url
from tornado.httputil import format_timestamp, HTTPHeaders
try:
from io import BytesIO # python 3
except ImportError:
from cStringIO import StringIO as BytesIO
class HelloWorldHandler(RequestHandler): class HelloWorldHandler(RequestHandler):
@ -41,6 +39,18 @@ class PostHandler(RequestHandler):
self.get_argument("arg1"), self.get_argument("arg2"))) self.get_argument("arg1"), self.get_argument("arg2")))
class PutHandler(RequestHandler):
def put(self):
self.write("Put body: ")
self.write(self.request.body)
class RedirectHandler(RequestHandler):
def prepare(self):
self.redirect(self.get_argument("url"),
status=int(self.get_argument("status", "302")))
class ChunkHandler(RequestHandler): class ChunkHandler(RequestHandler):
def get(self): def get(self):
self.write("asdf") self.write("asdf")
@ -83,6 +93,13 @@ class ContentLength304Handler(RequestHandler):
pass pass
class PatchHandler(RequestHandler):
def patch(self):
"Return the request payload - so we can check it is being kept"
self.write(self.request.body)
class AllMethodsHandler(RequestHandler): class AllMethodsHandler(RequestHandler):
SUPPORTED_METHODS = RequestHandler.SUPPORTED_METHODS + ('OTHER',) SUPPORTED_METHODS = RequestHandler.SUPPORTED_METHODS + ('OTHER',)
@ -101,6 +118,8 @@ class HTTPClientCommonTestCase(AsyncHTTPTestCase):
return Application([ return Application([
url("/hello", HelloWorldHandler), url("/hello", HelloWorldHandler),
url("/post", PostHandler), url("/post", PostHandler),
url("/put", PutHandler),
url("/redirect", RedirectHandler),
url("/chunk", ChunkHandler), url("/chunk", ChunkHandler),
url("/auth", AuthHandler), url("/auth", AuthHandler),
url("/countdown/([0-9]+)", CountdownHandler, name="countdown"), url("/countdown/([0-9]+)", CountdownHandler, name="countdown"),
@ -108,8 +127,15 @@ class HTTPClientCommonTestCase(AsyncHTTPTestCase):
url("/user_agent", UserAgentHandler), url("/user_agent", UserAgentHandler),
url("/304_with_content_length", ContentLength304Handler), url("/304_with_content_length", ContentLength304Handler),
url("/all_methods", AllMethodsHandler), url("/all_methods", AllMethodsHandler),
url('/patch', PatchHandler),
], gzip=True) ], gzip=True)
def test_patch_receives_payload(self):
body = b"some patch data"
response = self.fetch("/patch", method='PATCH', body=body)
self.assertEqual(response.code, 200)
self.assertEqual(response.body, body)
@skipOnTravis @skipOnTravis
def test_hello_world(self): def test_hello_world(self):
response = self.fetch("/hello") response = self.fetch("/hello")
@ -263,7 +289,7 @@ Transfer-Encoding: chunked
def test_types(self): def test_types(self):
response = self.fetch("/hello") response = self.fetch("/hello")
self.assertEqual(type(response.body), bytes_type) self.assertEqual(type(response.body), bytes)
self.assertEqual(type(response.headers["Content-Type"]), str) self.assertEqual(type(response.headers["Content-Type"]), str)
self.assertEqual(type(response.code), int) self.assertEqual(type(response.code), int)
self.assertEqual(type(response.effective_url), str) self.assertEqual(type(response.effective_url), str)
@ -314,10 +340,27 @@ Transfer-Encoding: chunked
# Construct a new instance of the configured client class # Construct a new instance of the configured client class
client = self.http_client.__class__(self.io_loop, force_instance=True, client = self.http_client.__class__(self.io_loop, force_instance=True,
defaults=defaults) defaults=defaults)
client.fetch(self.get_url('/user_agent'), callback=self.stop) try:
response = self.wait() client.fetch(self.get_url('/user_agent'), callback=self.stop)
self.assertEqual(response.body, b'TestDefaultUserAgent') response = self.wait()
client.close() self.assertEqual(response.body, b'TestDefaultUserAgent')
finally:
client.close()
def test_header_types(self):
# Header values may be passed as character or utf8 byte strings,
# in a plain dictionary or an HTTPHeaders object.
# Keys must always be the native str type.
# All combinations should have the same results on the wire.
for value in [u("MyUserAgent"), b"MyUserAgent"]:
for container in [dict, HTTPHeaders]:
headers = container()
headers['User-Agent'] = value
resp = self.fetch('/user_agent', headers=headers)
self.assertEqual(
resp.body, b"MyUserAgent",
"response=%r, value=%r, container=%r" %
(resp.body, value, container))
def test_304_with_content_length(self): def test_304_with_content_length(self):
# According to the spec 304 responses SHOULD NOT include # According to the spec 304 responses SHOULD NOT include
@ -388,17 +431,39 @@ Transfer-Encoding: chunked
self.assertEqual(response.body, b'OTHER') self.assertEqual(response.body, b'OTHER')
@gen_test @gen_test
def test_body(self): def test_body_sanity_checks(self):
hello_url = self.get_url('/hello') hello_url = self.get_url('/hello')
with self.assertRaises(AssertionError) as context: with self.assertRaises(ValueError) as context:
yield self.http_client.fetch(hello_url, body='data') yield self.http_client.fetch(hello_url, body='data')
self.assertTrue('must be empty' in str(context.exception)) self.assertTrue('must be None' in str(context.exception))
with self.assertRaises(AssertionError) as context: with self.assertRaises(ValueError) as context:
yield self.http_client.fetch(hello_url, method='POST') yield self.http_client.fetch(hello_url, method='POST')
self.assertTrue('must not be empty' in str(context.exception)) self.assertTrue('must not be None' in str(context.exception))
# This test causes odd failures with the combination of
# curl_httpclient (at least with the version of libcurl available
# on ubuntu 12.04), TwistedIOLoop, and epoll. For POST (but not PUT),
# curl decides the response came back too soon and closes the connection
# to start again. It does this *before* telling the socket callback to
# unregister the FD. Some IOLoop implementations have special kernel
# integration to discover this immediately. Tornado's IOLoops
# ignore errors on remove_handler to accommodate this behavior, but
# Twisted's reactor does not. The removeReader call fails and so
# do all future removeAll calls (which our tests do at cleanup).
#
#def test_post_307(self):
# response = self.fetch("/redirect?status=307&url=/post",
# method="POST", body=b"arg1=foo&arg2=bar")
# self.assertEqual(response.body, b"Post arg1: foo, arg2: bar")
def test_put_307(self):
response = self.fetch("/redirect?status=307&url=/put",
method="PUT", body=b"hello")
response.rethrow()
self.assertEqual(response.body, b"Put body: hello")
class RequestProxyTest(unittest.TestCase): class RequestProxyTest(unittest.TestCase):
@ -515,3 +580,9 @@ class HTTPRequestTestCase(unittest.TestCase):
request = HTTPRequest('http://example.com') request = HTTPRequest('http://example.com')
request.body = 'foo' request.body = 'foo'
self.assertEqual(request.body, utf8('foo')) self.assertEqual(request.body, utf8('foo'))
def test_if_modified_since(self):
http_date = datetime.datetime.utcnow()
request = HTTPRequest('http://example.com', if_modified_since=http_date)
self.assertEqual(request.headers,
{'If-Modified-Since': format_timestamp(http_date)})

View File

@ -9,12 +9,12 @@ from tornado.http1connection import HTTP1Connection
from tornado.httpserver import HTTPServer from tornado.httpserver import HTTPServer
from tornado.httputil import HTTPHeaders, HTTPMessageDelegate, HTTPServerConnectionDelegate, ResponseStartLine from tornado.httputil import HTTPHeaders, HTTPMessageDelegate, HTTPServerConnectionDelegate, ResponseStartLine
from tornado.iostream import IOStream from tornado.iostream import IOStream
from tornado.log import gen_log, app_log from tornado.log import gen_log
from tornado.netutil import ssl_options_to_context from tornado.netutil import ssl_options_to_context
from tornado.simple_httpclient import SimpleAsyncHTTPClient from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog, gen_test from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog, gen_test
from tornado.test.util import unittest, skipOnTravis from tornado.test.util import unittest, skipOnTravis
from tornado.util import u, bytes_type from tornado.util import u
from tornado.web import Application, RequestHandler, asynchronous, stream_request_body from tornado.web import Application, RequestHandler, asynchronous, stream_request_body
from contextlib import closing from contextlib import closing
import datetime import datetime
@ -25,11 +25,7 @@ import socket
import ssl import ssl
import sys import sys
import tempfile import tempfile
from io import BytesIO
try:
from io import BytesIO # python 3
except ImportError:
from cStringIO import StringIO as BytesIO # python 2
def read_stream_body(stream, callback): def read_stream_body(stream, callback):
@ -297,10 +293,10 @@ class TypeCheckHandler(RequestHandler):
# secure cookies # secure cookies
self.check_type('arg_key', list(self.request.arguments.keys())[0], str) self.check_type('arg_key', list(self.request.arguments.keys())[0], str)
self.check_type('arg_value', list(self.request.arguments.values())[0][0], bytes_type) self.check_type('arg_value', list(self.request.arguments.values())[0][0], bytes)
def post(self): def post(self):
self.check_type('body', self.request.body, bytes_type) self.check_type('body', self.request.body, bytes)
self.write(self.errors) self.write(self.errors)
def get(self): def get(self):
@ -358,7 +354,7 @@ class HTTPServerTest(AsyncHTTPTestCase):
# if the data is not utf8. On python 2 parse_qs will work, # if the data is not utf8. On python 2 parse_qs will work,
# but then the recursive_unicode call in EchoHandler will # but then the recursive_unicode call in EchoHandler will
# fail. # fail.
if str is bytes_type: if str is bytes:
return return
with ExpectLog(gen_log, 'Invalid x-www-form-urlencoded body'): with ExpectLog(gen_log, 'Invalid x-www-form-urlencoded body'):
response = self.fetch( response = self.fetch(
@ -586,6 +582,8 @@ class KeepAliveTest(AsyncHTTPTestCase):
class HelloHandler(RequestHandler): class HelloHandler(RequestHandler):
def get(self): def get(self):
self.finish('Hello world') self.finish('Hello world')
def post(self):
self.finish('Hello world')
class LargeHandler(RequestHandler): class LargeHandler(RequestHandler):
def get(self): def get(self):
@ -687,6 +685,17 @@ class KeepAliveTest(AsyncHTTPTestCase):
self.assertEqual(self.headers['Connection'], 'Keep-Alive') self.assertEqual(self.headers['Connection'], 'Keep-Alive')
self.close() self.close()
def test_http10_keepalive_extra_crlf(self):
self.http_version = b'HTTP/1.0'
self.connect()
self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n\r\n')
self.read_response()
self.assertEqual(self.headers['Connection'], 'Keep-Alive')
self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
self.read_response()
self.assertEqual(self.headers['Connection'], 'Keep-Alive')
self.close()
def test_pipelined_requests(self): def test_pipelined_requests(self):
self.connect() self.connect()
self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n') self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
@ -715,6 +724,19 @@ class KeepAliveTest(AsyncHTTPTestCase):
self.read_headers() self.read_headers()
self.close() self.close()
def test_keepalive_chunked(self):
self.http_version = b'HTTP/1.0'
self.connect()
self.stream.write(b'POST / HTTP/1.0\r\nConnection: keep-alive\r\n'
b'Transfer-Encoding: chunked\r\n'
b'\r\n0\r\n')
self.read_response()
self.assertEqual(self.headers['Connection'], 'Keep-Alive')
self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
self.read_response()
self.assertEqual(self.headers['Connection'], 'Keep-Alive')
self.close()
class GzipBaseTest(object): class GzipBaseTest(object):
def get_app(self): def get_app(self):
@ -736,7 +758,7 @@ class GzipBaseTest(object):
class GzipTest(GzipBaseTest, AsyncHTTPTestCase): class GzipTest(GzipBaseTest, AsyncHTTPTestCase):
def get_httpserver_options(self): def get_httpserver_options(self):
return dict(gzip=True) return dict(decompress_request=True)
def test_gzip(self): def test_gzip(self):
response = self.post_gzip('foo=bar') response = self.post_gzip('foo=bar')
@ -764,7 +786,7 @@ class StreamingChunkSizeTest(AsyncHTTPTestCase):
return SimpleAsyncHTTPClient(io_loop=self.io_loop) return SimpleAsyncHTTPClient(io_loop=self.io_loop)
def get_httpserver_options(self): def get_httpserver_options(self):
return dict(chunk_size=self.CHUNK_SIZE, gzip=True) return dict(chunk_size=self.CHUNK_SIZE, decompress_request=True)
class MessageDelegate(HTTPMessageDelegate): class MessageDelegate(HTTPMessageDelegate):
def __init__(self, connection): def __init__(self, connection):

View File

@ -2,7 +2,7 @@
from __future__ import absolute_import, division, print_function, with_statement from __future__ import absolute_import, division, print_function, with_statement
from tornado.httputil import url_concat, parse_multipart_form_data, HTTPHeaders, format_timestamp from tornado.httputil import url_concat, parse_multipart_form_data, HTTPHeaders, format_timestamp, HTTPServerRequest, parse_request_start_line
from tornado.escape import utf8 from tornado.escape import utf8
from tornado.log import gen_log from tornado.log import gen_log
from tornado.testing import ExpectLog from tornado.testing import ExpectLog
@ -253,3 +253,26 @@ class FormatTimestampTest(unittest.TestCase):
def test_datetime(self): def test_datetime(self):
self.check(datetime.datetime.utcfromtimestamp(self.TIMESTAMP)) self.check(datetime.datetime.utcfromtimestamp(self.TIMESTAMP))
# HTTPServerRequest is mainly tested incidentally to the server itself,
# but this tests the parts of the class that can be tested in isolation.
class HTTPServerRequestTest(unittest.TestCase):
def test_default_constructor(self):
# All parameters are formally optional, but uri is required
# (and has been for some time). This test ensures that no
# more required parameters slip in.
HTTPServerRequest(uri='/')
class ParseRequestStartLineTest(unittest.TestCase):
METHOD = "GET"
PATH = "/foo"
VERSION = "HTTP/1.1"
def test_parse_request_start_line(self):
start_line = " ".join([self.METHOD, self.PATH, self.VERSION])
parsed_start_line = parse_request_start_line(start_line)
self.assertEqual(parsed_start_line.method, self.METHOD)
self.assertEqual(parsed_start_line.path, self.PATH)
self.assertEqual(parsed_start_line.version, self.VERSION)

View File

@ -173,6 +173,25 @@ class TestIOLoop(AsyncTestCase):
self.io_loop.add_callback(lambda: self.io_loop.add_callback(self.stop)) self.io_loop.add_callback(lambda: self.io_loop.add_callback(self.stop))
self.wait() self.wait()
def test_remove_timeout_from_timeout(self):
calls = [False, False]
# Schedule several callbacks and wait for them all to come due at once.
# t2 should be cancelled by t1, even though it is already scheduled to
# be run before the ioloop even looks at it.
now = self.io_loop.time()
def t1():
calls[0] = True
self.io_loop.remove_timeout(t2_handle)
self.io_loop.add_timeout(now + 0.01, t1)
def t2():
calls[1] = True
t2_handle = self.io_loop.add_timeout(now + 0.02, t2)
self.io_loop.add_timeout(now + 0.03, self.stop)
time.sleep(0.03)
self.wait()
self.assertEqual(calls, [True, False])
def test_timeout_with_arguments(self): def test_timeout_with_arguments(self):
# This tests that all the timeout methods pass through *args correctly. # This tests that all the timeout methods pass through *args correctly.
results = [] results = []
@ -185,6 +204,23 @@ class TestIOLoop(AsyncTestCase):
self.wait() self.wait()
self.assertEqual(results, [1, 2, 3, 4]) self.assertEqual(results, [1, 2, 3, 4])
def test_add_timeout_return(self):
# All the timeout methods return non-None handles that can be
# passed to remove_timeout.
handle = self.io_loop.add_timeout(self.io_loop.time(), lambda: None)
self.assertFalse(handle is None)
self.io_loop.remove_timeout(handle)
def test_call_at_return(self):
handle = self.io_loop.call_at(self.io_loop.time(), lambda: None)
self.assertFalse(handle is None)
self.io_loop.remove_timeout(handle)
def test_call_later_return(self):
handle = self.io_loop.call_later(0, lambda: None)
self.assertFalse(handle is None)
self.io_loop.remove_timeout(handle)
def test_close_file_object(self): def test_close_file_object(self):
"""When a file object is used instead of a numeric file descriptor, """When a file object is used instead of a numeric file descriptor,
the object should be closed (by IOLoop.close(all_fds=True), the object should be closed (by IOLoop.close(all_fds=True),
@ -298,6 +334,33 @@ class TestIOLoop(AsyncTestCase):
with ExpectLog(app_log, "Exception in callback"): with ExpectLog(app_log, "Exception in callback"):
self.wait() self.wait()
@skipIfNonUnix
def test_remove_handler_from_handler(self):
# Create two sockets with simultaneous read events.
client, server = socket.socketpair()
try:
client.send(b'abc')
server.send(b'abc')
# After reading from one fd, remove the other from the IOLoop.
chunks = []
def handle_read(fd, events):
chunks.append(fd.recv(1024))
if fd is client:
self.io_loop.remove_handler(server)
else:
self.io_loop.remove_handler(client)
self.io_loop.add_handler(client, handle_read, self.io_loop.READ)
self.io_loop.add_handler(server, handle_read, self.io_loop.READ)
self.io_loop.call_later(0.01, self.stop)
self.wait()
# Only one fd was read; the other was cleanly removed.
self.assertEqual(chunks, [b'abc'])
finally:
client.close()
server.close()
# Deliberately not a subclass of AsyncTestCase so the IOLoop isn't # Deliberately not a subclass of AsyncTestCase so the IOLoop isn't
# automatically set as current. # automatically set as current.

View File

@ -10,7 +10,7 @@ from tornado.stack_context import NullContext
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog, gen_test from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog, gen_test
from tornado.test.util import unittest, skipIfNonUnix from tornado.test.util import unittest, skipIfNonUnix
from tornado.web import RequestHandler, Application from tornado.web import RequestHandler, Application
import certifi import lib.certifi
import errno import errno
import logging import logging
import os import os
@ -511,7 +511,7 @@ class TestIOStreamMixin(object):
server, client = self.make_iostream_pair() server, client = self.make_iostream_pair()
server.set_close_callback(self.stop) server.set_close_callback(self.stop)
try: try:
# Start a read that will be fullfilled asynchronously. # Start a read that will be fulfilled asynchronously.
server.read_bytes(1, lambda data: None) server.read_bytes(1, lambda data: None)
client.write(b'a') client.write(b'a')
# Stub out read_from_fd to make it fail. # Stub out read_from_fd to make it fail.

View File

@ -57,3 +57,43 @@ class EnglishTest(unittest.TestCase):
date = datetime.datetime(2013, 4, 28, 18, 35) date = datetime.datetime(2013, 4, 28, 18, 35)
self.assertEqual(locale.format_date(date, full_format=True), self.assertEqual(locale.format_date(date, full_format=True),
'April 28, 2013 at 6:35 pm') 'April 28, 2013 at 6:35 pm')
self.assertEqual(locale.format_date(datetime.datetime.utcnow() - datetime.timedelta(seconds=2), full_format=False),
'2 seconds ago')
self.assertEqual(locale.format_date(datetime.datetime.utcnow() - datetime.timedelta(minutes=2), full_format=False),
'2 minutes ago')
self.assertEqual(locale.format_date(datetime.datetime.utcnow() - datetime.timedelta(hours=2), full_format=False),
'2 hours ago')
now = datetime.datetime.utcnow()
self.assertEqual(locale.format_date(now - datetime.timedelta(days=1), full_format=False, shorter=True),
'yesterday')
date = now - datetime.timedelta(days=2)
self.assertEqual(locale.format_date(date, full_format=False, shorter=True),
locale._weekdays[date.weekday()])
date = now - datetime.timedelta(days=300)
self.assertEqual(locale.format_date(date, full_format=False, shorter=True),
'%s %d' % (locale._months[date.month - 1], date.day))
date = now - datetime.timedelta(days=500)
self.assertEqual(locale.format_date(date, full_format=False, shorter=True),
'%s %d, %d' % (locale._months[date.month - 1], date.day, date.year))
def test_friendly_number(self):
locale = tornado.locale.get('en_US')
self.assertEqual(locale.friendly_number(1000000), '1,000,000')
def test_list(self):
locale = tornado.locale.get('en_US')
self.assertEqual(locale.list([]), '')
self.assertEqual(locale.list(['A']), 'A')
self.assertEqual(locale.list(['A', 'B']), 'A and B')
self.assertEqual(locale.list(['A', 'B', 'C']), 'A, B and C')
def test_format_day(self):
locale = tornado.locale.get('en_US')
date = datetime.datetime(2013, 4, 28, 18, 35)
self.assertEqual(locale.format_day(date=date, dow=True), 'Sunday, April 28')
self.assertEqual(locale.format_day(date=date, dow=False), 'April 28')

View File

@ -29,7 +29,7 @@ from tornado.escape import utf8
from tornado.log import LogFormatter, define_logging_options, enable_pretty_logging from tornado.log import LogFormatter, define_logging_options, enable_pretty_logging
from tornado.options import OptionParser from tornado.options import OptionParser
from tornado.test.util import unittest from tornado.test.util import unittest
from tornado.util import u, bytes_type, basestring_type from tornado.util import u, basestring_type
@contextlib.contextmanager @contextlib.contextmanager
@ -95,8 +95,9 @@ class LogFormatterTest(unittest.TestCase):
self.assertEqual(self.get_output(), utf8(repr(b"\xe9"))) self.assertEqual(self.get_output(), utf8(repr(b"\xe9")))
def test_utf8_logging(self): def test_utf8_logging(self):
self.logger.error(u("\u00e9").encode("utf8")) with ignore_bytes_warning():
if issubclass(bytes_type, basestring_type): self.logger.error(u("\u00e9").encode("utf8"))
if issubclass(bytes, basestring_type):
# on python 2, utf8 byte strings (and by extension ascii byte # on python 2, utf8 byte strings (and by extension ascii byte
# strings) are passed through as-is. # strings) are passed through as-is.
self.assertEqual(self.get_output(), utf8(u("\u00e9"))) self.assertEqual(self.get_output(), utf8(u("\u00e9")))

View File

@ -34,15 +34,6 @@ else:
class _ResolverTestMixin(object): class _ResolverTestMixin(object):
def skipOnCares(self):
# Some DNS-hijacking ISPs (e.g. Time Warner) return non-empty results
# with an NXDOMAIN status code. Most resolvers treat this as an error;
# C-ares returns the results, making the "bad_host" tests unreliable.
# C-ares will try to resolve even malformed names, such as the
# name with spaces used in this test.
if self.resolver.__class__.__name__ == 'CaresResolver':
self.skipTest("CaresResolver doesn't recognize fake NXDOMAIN")
def test_localhost(self): def test_localhost(self):
self.resolver.resolve('localhost', 80, callback=self.stop) self.resolver.resolve('localhost', 80, callback=self.stop)
result = self.wait() result = self.wait()
@ -55,8 +46,11 @@ class _ResolverTestMixin(object):
self.assertIn((socket.AF_INET, ('127.0.0.1', 80)), self.assertIn((socket.AF_INET, ('127.0.0.1', 80)),
addrinfo) addrinfo)
# It is impossible to quickly and consistently generate an error in name
# resolution, so test this case separately, using mocks as needed.
class _ResolverErrorTestMixin(object):
def test_bad_host(self): def test_bad_host(self):
self.skipOnCares()
def handler(exc_typ, exc_val, exc_tb): def handler(exc_typ, exc_val, exc_tb):
self.stop(exc_val) self.stop(exc_val)
return True # Halt propagation. return True # Halt propagation.
@ -69,11 +63,13 @@ class _ResolverTestMixin(object):
@gen_test @gen_test
def test_future_interface_bad_host(self): def test_future_interface_bad_host(self):
self.skipOnCares()
with self.assertRaises(Exception): with self.assertRaises(Exception):
yield self.resolver.resolve('an invalid domain', 80, yield self.resolver.resolve('an invalid domain', 80,
socket.AF_UNSPEC) socket.AF_UNSPEC)
def _failing_getaddrinfo(*args):
"""Dummy implementation of getaddrinfo for use in mocks"""
raise socket.gaierror("mock: lookup failed")
@skipIfNoNetwork @skipIfNoNetwork
class BlockingResolverTest(AsyncTestCase, _ResolverTestMixin): class BlockingResolverTest(AsyncTestCase, _ResolverTestMixin):
@ -82,6 +78,21 @@ class BlockingResolverTest(AsyncTestCase, _ResolverTestMixin):
self.resolver = BlockingResolver(io_loop=self.io_loop) self.resolver = BlockingResolver(io_loop=self.io_loop)
# getaddrinfo-based tests need mocking to reliably generate errors;
# some configurations are slow to produce errors and take longer than
# our default timeout.
class BlockingResolverErrorTest(AsyncTestCase, _ResolverErrorTestMixin):
def setUp(self):
super(BlockingResolverErrorTest, self).setUp()
self.resolver = BlockingResolver(io_loop=self.io_loop)
self.real_getaddrinfo = socket.getaddrinfo
socket.getaddrinfo = _failing_getaddrinfo
def tearDown(self):
socket.getaddrinfo = self.real_getaddrinfo
super(BlockingResolverErrorTest, self).tearDown()
@skipIfNoNetwork @skipIfNoNetwork
@unittest.skipIf(futures is None, "futures module not present") @unittest.skipIf(futures is None, "futures module not present")
class ThreadedResolverTest(AsyncTestCase, _ResolverTestMixin): class ThreadedResolverTest(AsyncTestCase, _ResolverTestMixin):
@ -94,6 +105,18 @@ class ThreadedResolverTest(AsyncTestCase, _ResolverTestMixin):
super(ThreadedResolverTest, self).tearDown() super(ThreadedResolverTest, self).tearDown()
class ThreadedResolverErrorTest(AsyncTestCase, _ResolverErrorTestMixin):
def setUp(self):
super(ThreadedResolverErrorTest, self).setUp()
self.resolver = BlockingResolver(io_loop=self.io_loop)
self.real_getaddrinfo = socket.getaddrinfo
socket.getaddrinfo = _failing_getaddrinfo
def tearDown(self):
socket.getaddrinfo = self.real_getaddrinfo
super(ThreadedResolverErrorTest, self).tearDown()
@skipIfNoNetwork @skipIfNoNetwork
@unittest.skipIf(futures is None, "futures module not present") @unittest.skipIf(futures is None, "futures module not present")
@unittest.skipIf(sys.platform == 'win32', "preexec_fn not available on win32") @unittest.skipIf(sys.platform == 'win32', "preexec_fn not available on win32")
@ -121,6 +144,12 @@ class ThreadedResolverImportTest(unittest.TestCase):
self.fail("import timed out") self.fail("import timed out")
# We do not test errors with CaresResolver:
# Some DNS-hijacking ISPs (e.g. Time Warner) return non-empty results
# with an NXDOMAIN status code. Most resolvers treat this as an error;
# C-ares returns the results, making the "bad_host" tests unreliable.
# C-ares will try to resolve even malformed names, such as the
# name with spaces used in this test.
@skipIfNoNetwork @skipIfNoNetwork
@unittest.skipIf(pycares is None, "pycares module not present") @unittest.skipIf(pycares is None, "pycares module not present")
class CaresResolverTest(AsyncTestCase, _ResolverTestMixin): class CaresResolverTest(AsyncTestCase, _ResolverTestMixin):
@ -129,10 +158,13 @@ class CaresResolverTest(AsyncTestCase, _ResolverTestMixin):
self.resolver = CaresResolver(io_loop=self.io_loop) self.resolver = CaresResolver(io_loop=self.io_loop)
# TwistedResolver produces consistent errors in our test cases so we
# can test the regular and error cases in the same class.
@skipIfNoNetwork @skipIfNoNetwork
@unittest.skipIf(twisted is None, "twisted module not present") @unittest.skipIf(twisted is None, "twisted module not present")
@unittest.skipIf(getattr(twisted, '__version__', '0.0') < "12.1", "old version of twisted") @unittest.skipIf(getattr(twisted, '__version__', '0.0') < "12.1", "old version of twisted")
class TwistedResolverTest(AsyncTestCase, _ResolverTestMixin): class TwistedResolverTest(AsyncTestCase, _ResolverTestMixin,
_ResolverErrorTestMixin):
def setUp(self): def setUp(self):
super(TwistedResolverTest, self).setUp() super(TwistedResolverTest, self).setUp()
self.resolver = TwistedResolver(io_loop=self.io_loop) self.resolver = TwistedResolver(io_loop=self.io_loop)

View File

@ -1,2 +1,3 @@
port=443 port=443
port=443 port=443
username='李康'

View File

@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function, with_statement from __future__ import absolute_import, division, print_function, with_statement
import datetime import datetime
@ -32,9 +33,11 @@ class OptionsTest(unittest.TestCase):
def test_parse_config_file(self): def test_parse_config_file(self):
options = OptionParser() options = OptionParser()
options.define("port", default=80) options.define("port", default=80)
options.define("username", default='foo')
options.parse_config_file(os.path.join(os.path.dirname(__file__), options.parse_config_file(os.path.join(os.path.dirname(__file__),
"options_test.cfg")) "options_test.cfg"))
self.assertEquals(options.port, 443) self.assertEquals(options.port, 443)
self.assertEqual(options.username, "李康")
def test_parse_callbacks(self): def test_parse_callbacks(self):
options = OptionParser() options = OptionParser()

View File

@ -14,7 +14,7 @@ from tornado import gen
from tornado.httpclient import AsyncHTTPClient from tornado.httpclient import AsyncHTTPClient
from tornado.httputil import HTTPHeaders from tornado.httputil import HTTPHeaders
from tornado.ioloop import IOLoop from tornado.ioloop import IOLoop
from tornado.log import gen_log, app_log from tornado.log import gen_log
from tornado.netutil import Resolver, bind_sockets from tornado.netutil import Resolver, bind_sockets
from tornado.simple_httpclient import SimpleAsyncHTTPClient, _default_ca_certs from tornado.simple_httpclient import SimpleAsyncHTTPClient, _default_ca_certs
from tornado.test.httpclient_test import ChunkHandler, CountdownHandler, HelloWorldHandler from tornado.test.httpclient_test import ChunkHandler, CountdownHandler, HelloWorldHandler
@ -294,10 +294,13 @@ class SimpleHTTPClientTestMixin(object):
self.assertEqual(response.code, 204) self.assertEqual(response.code, 204)
# 204 status doesn't need a content-length, but tornado will # 204 status doesn't need a content-length, but tornado will
# add a zero content-length anyway. # add a zero content-length anyway.
#
# A test without a content-length header is included below
# in HTTP204NoContentTestCase.
self.assertEqual(response.headers["Content-length"], "0") self.assertEqual(response.headers["Content-length"], "0")
# 204 status with non-zero content length is malformed # 204 status with non-zero content length is malformed
with ExpectLog(app_log, "Uncaught exception"): with ExpectLog(gen_log, "Malformed HTTP message"):
response = self.fetch("/no_content?error=1") response = self.fetch("/no_content?error=1")
self.assertEqual(response.code, 599) self.assertEqual(response.code, 599)
@ -476,6 +479,27 @@ class HTTP100ContinueTestCase(AsyncHTTPTestCase):
self.assertEqual(res.body, b'A') self.assertEqual(res.body, b'A')
class HTTP204NoContentTestCase(AsyncHTTPTestCase):
def respond_204(self, request):
# A 204 response never has a body, even if doesn't have a content-length
# (which would otherwise mean read-until-close). Tornado always
# sends a content-length, so we simulate here a server that sends
# no content length and does not close the connection.
#
# Tests of a 204 response with a Content-Length header are included
# in SimpleHTTPClientTestMixin.
request.connection.stream.write(
b"HTTP/1.1 204 No content\r\n\r\n")
def get_app(self):
return self.respond_204
def test_204_no_content(self):
resp = self.fetch('/')
self.assertEqual(resp.code, 204)
self.assertEqual(resp.body, b'')
class HostnameMappingTestCase(AsyncHTTPTestCase): class HostnameMappingTestCase(AsyncHTTPTestCase):
def setUp(self): def setUp(self):
super(HostnameMappingTestCase, self).setUp() super(HostnameMappingTestCase, self).setUp()

View File

@ -72,7 +72,9 @@ class TCPClientTest(AsyncTestCase):
super(TCPClientTest, self).tearDown() super(TCPClientTest, self).tearDown()
def skipIfLocalhostV4(self): def skipIfLocalhostV4(self):
Resolver().resolve('localhost', 0, callback=self.stop) # The port used here doesn't matter, but some systems require it
# to be non-zero if we do not also pass AI_PASSIVE.
Resolver().resolve('localhost', 80, callback=self.stop)
addrinfo = self.wait() addrinfo = self.wait()
families = set(addr[0] for addr in addrinfo) families = set(addr[0] for addr in addrinfo)
if socket.AF_INET6 not in families: if socket.AF_INET6 not in families:

View File

@ -7,7 +7,7 @@ import traceback
from tornado.escape import utf8, native_str, to_unicode from tornado.escape import utf8, native_str, to_unicode
from tornado.template import Template, DictLoader, ParseError, Loader from tornado.template import Template, DictLoader, ParseError, Loader
from tornado.test.util import unittest from tornado.test.util import unittest
from tornado.util import u, bytes_type, ObjectDict, unicode_type from tornado.util import u, ObjectDict, unicode_type
class TemplateTest(unittest.TestCase): class TemplateTest(unittest.TestCase):
@ -374,7 +374,7 @@ raw: {% raw name %}""",
"{% autoescape py_escape %}s = {{ name }}\n"}) "{% autoescape py_escape %}s = {{ name }}\n"})
def py_escape(s): def py_escape(s):
self.assertEqual(type(s), bytes_type) self.assertEqual(type(s), bytes)
return repr(native_str(s)) return repr(native_str(s))
def render(template, name): def render(template, name):

View File

@ -3,7 +3,8 @@
from __future__ import absolute_import, division, print_function, with_statement from __future__ import absolute_import, division, print_function, with_statement
from tornado import gen, ioloop from tornado import gen, ioloop
from tornado.testing import AsyncTestCase, gen_test from tornado.log import app_log
from tornado.testing import AsyncTestCase, gen_test, ExpectLog
from tornado.test.util import unittest from tornado.test.util import unittest
import contextlib import contextlib
@ -13,7 +14,7 @@ import traceback
@contextlib.contextmanager @contextlib.contextmanager
def set_environ(name, value): def set_environ(name, value):
old_value = os.environ.get('name') old_value = os.environ.get(name)
os.environ[name] = value os.environ[name] = value
try: try:
@ -62,6 +63,17 @@ class AsyncTestCaseTest(AsyncTestCase):
self.io_loop.add_timeout(self.io_loop.time() + 0.03, self.stop) self.io_loop.add_timeout(self.io_loop.time() + 0.03, self.stop)
self.wait(timeout=0.15) self.wait(timeout=0.15)
def test_multiple_errors(self):
def fail(message):
raise Exception(message)
self.io_loop.add_callback(lambda: fail("error one"))
self.io_loop.add_callback(lambda: fail("error two"))
# The first error gets raised; the second gets logged.
with ExpectLog(app_log, "multiple unhandled exceptions"):
with self.assertRaises(Exception) as cm:
self.wait()
self.assertEqual(str(cm.exception), "error one")
class AsyncTestCaseWrapperTest(unittest.TestCase): class AsyncTestCaseWrapperTest(unittest.TestCase):
def test_undecorated_generator(self): def test_undecorated_generator(self):

View File

@ -1,9 +1,10 @@
# coding: utf-8 # coding: utf-8
from __future__ import absolute_import, division, print_function, with_statement from __future__ import absolute_import, division, print_function, with_statement
import sys import sys
import datetime
from tornado.escape import utf8 from tornado.escape import utf8
from tornado.util import raise_exc_info, Configurable, u, exec_in, ArgReplacer from tornado.util import raise_exc_info, Configurable, u, exec_in, ArgReplacer, timedelta_to_seconds
from tornado.test.util import unittest from tornado.test.util import unittest
try: try:
@ -170,3 +171,9 @@ class ArgReplacerTest(unittest.TestCase):
self.assertEqual(self.replacer.get_old_value(args, kwargs), 'old') self.assertEqual(self.replacer.get_old_value(args, kwargs), 'old')
self.assertEqual(self.replacer.replace('new', args, kwargs), self.assertEqual(self.replacer.replace('new', args, kwargs),
('old', (1,), dict(y=2, callback='new', z=3))) ('old', (1,), dict(y=2, callback='new', z=3)))
class TimedeltaToSecondsTest(unittest.TestCase):
def test_timedelta_to_seconds(self):
time_delta = datetime.timedelta(hours=1)
self.assertEqual(timedelta_to_seconds(time_delta), 3600.0)

View File

@ -4,13 +4,14 @@ from tornado import gen
from tornado.escape import json_decode, utf8, to_unicode, recursive_unicode, native_str, to_basestring from tornado.escape import json_decode, utf8, to_unicode, recursive_unicode, native_str, to_basestring
from tornado.httputil import format_timestamp from tornado.httputil import format_timestamp
from tornado.iostream import IOStream from tornado.iostream import IOStream
from tornado import locale
from tornado.log import app_log, gen_log from tornado.log import app_log, gen_log
from tornado.simple_httpclient import SimpleAsyncHTTPClient from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.template import DictLoader from tornado.template import DictLoader
from tornado.testing import AsyncHTTPTestCase, ExpectLog, gen_test from tornado.testing import AsyncHTTPTestCase, ExpectLog, gen_test
from tornado.test.util import unittest from tornado.test.util import unittest
from tornado.util import u, bytes_type, ObjectDict, unicode_type from tornado.util import u, ObjectDict, unicode_type, timedelta_to_seconds
from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body, Finish from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body, Finish, removeslash, addslash, RedirectHandler as WebRedirectHandler
import binascii import binascii
import contextlib import contextlib
@ -21,7 +22,6 @@ import logging
import os import os
import re import re
import socket import socket
import sys
try: try:
import urllib.parse as urllib_parse # py3 import urllib.parse as urllib_parse # py3
@ -163,11 +163,21 @@ class CookieTest(WebTestCase):
# Attributes from the first call are not carried over. # Attributes from the first call are not carried over.
self.set_cookie("a", "e") self.set_cookie("a", "e")
class SetCookieMaxAgeHandler(RequestHandler):
def get(self):
self.set_cookie("foo", "bar", max_age=10)
class SetCookieExpiresDaysHandler(RequestHandler):
def get(self):
self.set_cookie("foo", "bar", expires_days=10)
return [("/set", SetCookieHandler), return [("/set", SetCookieHandler),
("/get", GetCookieHandler), ("/get", GetCookieHandler),
("/set_domain", SetCookieDomainHandler), ("/set_domain", SetCookieDomainHandler),
("/special_char", SetCookieSpecialCharHandler), ("/special_char", SetCookieSpecialCharHandler),
("/set_overwrite", SetCookieOverwriteHandler), ("/set_overwrite", SetCookieOverwriteHandler),
("/set_max_age", SetCookieMaxAgeHandler),
("/set_expires_days", SetCookieExpiresDaysHandler),
] ]
def test_set_cookie(self): def test_set_cookie(self):
@ -222,6 +232,23 @@ class CookieTest(WebTestCase):
self.assertEqual(sorted(headers), self.assertEqual(sorted(headers),
["a=e; Path=/", "c=d; Domain=example.com; Path=/"]) ["a=e; Path=/", "c=d; Domain=example.com; Path=/"])
def test_set_cookie_max_age(self):
response = self.fetch("/set_max_age")
headers = response.headers.get_list("Set-Cookie")
self.assertEqual(sorted(headers),
["foo=bar; Max-Age=10; Path=/"])
def test_set_cookie_expires_days(self):
response = self.fetch("/set_expires_days")
header = response.headers.get("Set-Cookie")
match = re.match("foo=bar; expires=(?P<expires>.+); Path=/", header)
self.assertIsNotNone(match)
expires = datetime.datetime.utcnow() + datetime.timedelta(days=10)
header_expires = datetime.datetime(
*email.utils.parsedate(match.groupdict()["expires"])[:6])
self.assertTrue(abs(timedelta_to_seconds(expires - header_expires)) < 10)
class AuthRedirectRequestHandler(RequestHandler): class AuthRedirectRequestHandler(RequestHandler):
def initialize(self, login_url): def initialize(self, login_url):
@ -302,7 +329,7 @@ class EchoHandler(RequestHandler):
if type(key) != str: if type(key) != str:
raise Exception("incorrect type for key: %r" % type(key)) raise Exception("incorrect type for key: %r" % type(key))
for value in self.request.arguments[key]: for value in self.request.arguments[key]:
if type(value) != bytes_type: if type(value) != bytes:
raise Exception("incorrect type for value: %r" % raise Exception("incorrect type for value: %r" %
type(value)) type(value))
for value in self.get_arguments(key): for value in self.get_arguments(key):
@ -370,10 +397,10 @@ class TypeCheckHandler(RequestHandler):
if list(self.cookies.keys()) != ['asdf']: if list(self.cookies.keys()) != ['asdf']:
raise Exception("unexpected values for cookie keys: %r" % raise Exception("unexpected values for cookie keys: %r" %
self.cookies.keys()) self.cookies.keys())
self.check_type('get_secure_cookie', self.get_secure_cookie('asdf'), bytes_type) self.check_type('get_secure_cookie', self.get_secure_cookie('asdf'), bytes)
self.check_type('get_cookie', self.get_cookie('asdf'), str) self.check_type('get_cookie', self.get_cookie('asdf'), str)
self.check_type('xsrf_token', self.xsrf_token, bytes_type) self.check_type('xsrf_token', self.xsrf_token, bytes)
self.check_type('xsrf_form_html', self.xsrf_form_html(), str) self.check_type('xsrf_form_html', self.xsrf_form_html(), str)
self.check_type('reverse_url', self.reverse_url('typecheck', 'foo'), str) self.check_type('reverse_url', self.reverse_url('typecheck', 'foo'), str)
@ -399,7 +426,7 @@ class TypeCheckHandler(RequestHandler):
class DecodeArgHandler(RequestHandler): class DecodeArgHandler(RequestHandler):
def decode_argument(self, value, name=None): def decode_argument(self, value, name=None):
if type(value) != bytes_type: if type(value) != bytes:
raise Exception("unexpected type for value: %r" % type(value)) raise Exception("unexpected type for value: %r" % type(value))
# use self.request.arguments directly to avoid recursion # use self.request.arguments directly to avoid recursion
if 'encoding' in self.request.arguments: if 'encoding' in self.request.arguments:
@ -409,7 +436,7 @@ class DecodeArgHandler(RequestHandler):
def get(self, arg): def get(self, arg):
def describe(s): def describe(s):
if type(s) == bytes_type: if type(s) == bytes:
return ["bytes", native_str(binascii.b2a_hex(s))] return ["bytes", native_str(binascii.b2a_hex(s))]
elif type(s) == unicode_type: elif type(s) == unicode_type:
return ["unicode", s] return ["unicode", s]
@ -550,6 +577,8 @@ class WSGISafeWebTest(WebTestCase):
url("/optional_path/(.+)?", OptionalPathHandler), url("/optional_path/(.+)?", OptionalPathHandler),
url("/multi_header", MultiHeaderHandler), url("/multi_header", MultiHeaderHandler),
url("/redirect", RedirectHandler), url("/redirect", RedirectHandler),
url("/web_redirect_permanent", WebRedirectHandler, {"url": "/web_redirect_newpath"}),
url("/web_redirect", WebRedirectHandler, {"url": "/web_redirect_newpath", "permanent": False}),
url("/header_injection", HeaderInjectionHandler), url("/header_injection", HeaderInjectionHandler),
url("/get_argument", GetArgumentHandler), url("/get_argument", GetArgumentHandler),
url("/get_arguments", GetArgumentsHandler), url("/get_arguments", GetArgumentsHandler),
@ -675,6 +704,14 @@ js_embed()
response = self.fetch("/redirect?status=307", follow_redirects=False) response = self.fetch("/redirect?status=307", follow_redirects=False)
self.assertEqual(response.code, 307) self.assertEqual(response.code, 307)
def test_web_redirect(self):
response = self.fetch("/web_redirect_permanent", follow_redirects=False)
self.assertEqual(response.code, 301)
self.assertEqual(response.headers['Location'], '/web_redirect_newpath')
response = self.fetch("/web_redirect", follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertEqual(response.headers['Location'], '/web_redirect_newpath')
def test_header_injection(self): def test_header_injection(self):
response = self.fetch("/header_injection") response = self.fetch("/header_injection")
self.assertEqual(response.body, b"ok") self.assertEqual(response.body, b"ok")
@ -1348,7 +1385,9 @@ class GzipTestCase(SimpleHandlerTestCase):
self.write('hello world') self.write('hello world')
def get_app_kwargs(self): def get_app_kwargs(self):
return dict(gzip=True) return dict(
gzip=True,
static_path=os.path.join(os.path.dirname(__file__), 'static'))
def test_gzip(self): def test_gzip(self):
response = self.fetch('/') response = self.fetch('/')
@ -1361,6 +1400,17 @@ class GzipTestCase(SimpleHandlerTestCase):
'gzip') 'gzip')
self.assertEqual(response.headers['Vary'], 'Accept-Encoding') self.assertEqual(response.headers['Vary'], 'Accept-Encoding')
def test_gzip_static(self):
# The streaming responses in StaticFileHandler have subtle
# interactions with the gzip output so test this case separately.
response = self.fetch('/robots.txt')
self.assertEqual(
response.headers.get(
'Content-Encoding',
response.headers.get('X-Consumed-Content-Encoding')),
'gzip')
self.assertEqual(response.headers['Vary'], 'Accept-Encoding')
def test_gzip_not_requested(self): def test_gzip_not_requested(self):
response = self.fetch('/', use_gzip=False) response = self.fetch('/', use_gzip=False)
self.assertNotIn('Content-Encoding', response.headers) self.assertNotIn('Content-Encoding', response.headers)
@ -1554,19 +1604,26 @@ class MultipleExceptionTest(SimpleHandlerTestCase):
@wsgi_safe @wsgi_safe
class SetCurrentUserTest(SimpleHandlerTestCase): class SetLazyPropertiesTest(SimpleHandlerTestCase):
class Handler(RequestHandler): class Handler(RequestHandler):
def prepare(self): def prepare(self):
self.current_user = 'Ben' self.current_user = 'Ben'
self.locale = locale.get('en_US')
def get_user_locale(self):
raise NotImplementedError()
def get_current_user(self):
raise NotImplementedError()
def get(self): def get(self):
self.write('Hello %s' % self.current_user) self.write('Hello %s (%s)' % (self.current_user, self.locale.code))
def test_set_current_user(self): def test_set_properties(self):
# Ensure that current_user can be assigned to normally for apps # Ensure that current_user can be assigned to normally for apps
# that want to forgo the lazy get_current_user property # that want to forgo the lazy get_current_user property
response = self.fetch('/') response = self.fetch('/')
self.assertEqual(response.body, b'Hello Ben') self.assertEqual(response.body, b'Hello Ben (en_US)')
@wsgi_safe @wsgi_safe
@ -2193,6 +2250,20 @@ class XSRFTest(SimpleHandlerTestCase):
headers=self.cookie_headers()) headers=self.cookie_headers())
self.assertEqual(response.code, 403) self.assertEqual(response.code, 403)
def test_xsrf_success_short_token(self):
response = self.fetch(
"/", method="POST",
body=urllib_parse.urlencode(dict(_xsrf='deadbeef')),
headers=self.cookie_headers(token='deadbeef'))
self.assertEqual(response.code, 200)
def test_xsrf_success_non_hex_token(self):
response = self.fetch(
"/", method="POST",
body=urllib_parse.urlencode(dict(_xsrf='xoxo')),
headers=self.cookie_headers(token='xoxo'))
self.assertEqual(response.code, 200)
def test_xsrf_success_post_body(self): def test_xsrf_success_post_body(self):
response = self.fetch( response = self.fetch(
"/", method="POST", "/", method="POST",
@ -2299,3 +2370,38 @@ class FinishExceptionTest(SimpleHandlerTestCase):
self.assertEqual('Basic realm="something"', self.assertEqual('Basic realm="something"',
response.headers.get('WWW-Authenticate')) response.headers.get('WWW-Authenticate'))
self.assertEqual(b'authentication required', response.body) self.assertEqual(b'authentication required', response.body)
class DecoratorTest(WebTestCase):
def get_handlers(self):
class RemoveSlashHandler(RequestHandler):
@removeslash
def get(self):
pass
class AddSlashHandler(RequestHandler):
@addslash
def get(self):
pass
return [("/removeslash/", RemoveSlashHandler),
("/addslash", AddSlashHandler),
]
def test_removeslash(self):
response = self.fetch("/removeslash/", follow_redirects=False)
self.assertEqual(response.code, 301)
self.assertEqual(response.headers['Location'], "/removeslash")
response = self.fetch("/removeslash/?foo=bar", follow_redirects=False)
self.assertEqual(response.code, 301)
self.assertEqual(response.headers['Location'], "/removeslash?foo=bar")
def test_addslash(self):
response = self.fetch("/addslash", follow_redirects=False)
self.assertEqual(response.code, 301)
self.assertEqual(response.headers['Location'], "/addslash/")
response = self.fetch("/addslash?foo=bar", follow_redirects=False)
self.assertEqual(response.code, 301)
self.assertEqual(response.headers['Location'], "/addslash/?foo=bar")

View File

@ -3,11 +3,13 @@ from __future__ import absolute_import, division, print_function, with_statement
import traceback import traceback
from tornado.concurrent import Future from tornado.concurrent import Future
from tornado import gen
from tornado.httpclient import HTTPError, HTTPRequest from tornado.httpclient import HTTPError, HTTPRequest
from tornado.log import gen_log from tornado.log import gen_log, app_log
from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
from tornado.test.util import unittest from tornado.test.util import unittest
from tornado.web import Application, RequestHandler from tornado.web import Application, RequestHandler
from tornado.util import u
try: try:
import tornado.websocket import tornado.websocket
@ -33,8 +35,12 @@ class TestWebSocketHandler(WebSocketHandler):
This allows for deterministic cleanup of the associated socket. This allows for deterministic cleanup of the associated socket.
""" """
def initialize(self, close_future): def initialize(self, close_future, compression_options=None):
self.close_future = close_future self.close_future = close_future
self.compression_options = compression_options
def get_compression_options(self):
return self.compression_options
def on_close(self): def on_close(self):
self.close_future.set_result((self.close_code, self.close_reason)) self.close_future.set_result((self.close_code, self.close_reason))
@ -45,6 +51,11 @@ class EchoHandler(TestWebSocketHandler):
self.write_message(message, isinstance(message, bytes)) self.write_message(message, isinstance(message, bytes))
class ErrorInOnMessageHandler(TestWebSocketHandler):
def on_message(self, message):
1/0
class HeaderHandler(TestWebSocketHandler): class HeaderHandler(TestWebSocketHandler):
def open(self): def open(self):
try: try:
@ -67,7 +78,34 @@ class CloseReasonHandler(TestWebSocketHandler):
self.close(1001, "goodbye") self.close(1001, "goodbye")
class WebSocketTest(AsyncHTTPTestCase): class AsyncPrepareHandler(TestWebSocketHandler):
@gen.coroutine
def prepare(self):
yield gen.moment
def on_message(self, message):
self.write_message(message)
class WebSocketBaseTestCase(AsyncHTTPTestCase):
@gen.coroutine
def ws_connect(self, path, compression_options=None):
ws = yield websocket_connect(
'ws://localhost:%d%s' % (self.get_http_port(), path),
compression_options=compression_options)
raise gen.Return(ws)
@gen.coroutine
def close(self, ws):
"""Close a websocket connection and wait for the server side.
If we don't wait here, there are sometimes leak warnings in the
tests.
"""
ws.close()
yield self.close_future
class WebSocketTest(WebSocketBaseTestCase):
def get_app(self): def get_app(self):
self.close_future = Future() self.close_future = Future()
return Application([ return Application([
@ -76,6 +114,10 @@ class WebSocketTest(AsyncHTTPTestCase):
('/header', HeaderHandler, dict(close_future=self.close_future)), ('/header', HeaderHandler, dict(close_future=self.close_future)),
('/close_reason', CloseReasonHandler, ('/close_reason', CloseReasonHandler,
dict(close_future=self.close_future)), dict(close_future=self.close_future)),
('/error_in_on_message', ErrorInOnMessageHandler,
dict(close_future=self.close_future)),
('/async_prepare', AsyncPrepareHandler,
dict(close_future=self.close_future)),
]) ])
def test_http_request(self): def test_http_request(self):
@ -85,14 +127,11 @@ class WebSocketTest(AsyncHTTPTestCase):
@gen_test @gen_test
def test_websocket_gen(self): def test_websocket_gen(self):
ws = yield websocket_connect( ws = yield self.ws_connect('/echo')
'ws://localhost:%d/echo' % self.get_http_port(),
io_loop=self.io_loop)
ws.write_message('hello') ws.write_message('hello')
response = yield ws.read_message() response = yield ws.read_message()
self.assertEqual(response, 'hello') self.assertEqual(response, 'hello')
ws.close() yield self.close(ws)
yield self.close_future
def test_websocket_callbacks(self): def test_websocket_callbacks(self):
websocket_connect( websocket_connect(
@ -107,20 +146,41 @@ class WebSocketTest(AsyncHTTPTestCase):
ws.close() ws.close()
self.wait() self.wait()
@gen_test
def test_binary_message(self):
ws = yield self.ws_connect('/echo')
ws.write_message(b'hello \xe9', binary=True)
response = yield ws.read_message()
self.assertEqual(response, b'hello \xe9')
yield self.close(ws)
@gen_test
def test_unicode_message(self):
ws = yield self.ws_connect('/echo')
ws.write_message(u('hello \u00e9'))
response = yield ws.read_message()
self.assertEqual(response, u('hello \u00e9'))
yield self.close(ws)
@gen_test
def test_error_in_on_message(self):
ws = yield self.ws_connect('/error_in_on_message')
ws.write_message('hello')
with ExpectLog(app_log, "Uncaught exception"):
response = yield ws.read_message()
self.assertIs(response, None)
yield self.close(ws)
@gen_test @gen_test
def test_websocket_http_fail(self): def test_websocket_http_fail(self):
with self.assertRaises(HTTPError) as cm: with self.assertRaises(HTTPError) as cm:
yield websocket_connect( yield self.ws_connect('/notfound')
'ws://localhost:%d/notfound' % self.get_http_port(),
io_loop=self.io_loop)
self.assertEqual(cm.exception.code, 404) self.assertEqual(cm.exception.code, 404)
@gen_test @gen_test
def test_websocket_http_success(self): def test_websocket_http_success(self):
with self.assertRaises(WebSocketError): with self.assertRaises(WebSocketError):
yield websocket_connect( yield self.ws_connect('/non_ws')
'ws://localhost:%d/non_ws' % self.get_http_port(),
io_loop=self.io_loop)
@gen_test @gen_test
def test_websocket_network_fail(self): def test_websocket_network_fail(self):
@ -139,6 +199,7 @@ class WebSocketTest(AsyncHTTPTestCase):
'ws://localhost:%d/echo' % self.get_http_port()) 'ws://localhost:%d/echo' % self.get_http_port())
ws.write_message('hello') ws.write_message('hello')
ws.write_message('world') ws.write_message('world')
# Close the underlying stream.
ws.stream.close() ws.stream.close()
yield self.close_future yield self.close_future
@ -150,13 +211,11 @@ class WebSocketTest(AsyncHTTPTestCase):
headers={'X-Test': 'hello'})) headers={'X-Test': 'hello'}))
response = yield ws.read_message() response = yield ws.read_message()
self.assertEqual(response, 'hello') self.assertEqual(response, 'hello')
ws.close() yield self.close(ws)
yield self.close_future
@gen_test @gen_test
def test_server_close_reason(self): def test_server_close_reason(self):
ws = yield websocket_connect( ws = yield self.ws_connect('/close_reason')
'ws://localhost:%d/close_reason' % self.get_http_port())
msg = yield ws.read_message() msg = yield ws.read_message()
# A message of None means the other side closed the connection. # A message of None means the other side closed the connection.
self.assertIs(msg, None) self.assertIs(msg, None)
@ -165,13 +224,21 @@ class WebSocketTest(AsyncHTTPTestCase):
@gen_test @gen_test
def test_client_close_reason(self): def test_client_close_reason(self):
ws = yield websocket_connect( ws = yield self.ws_connect('/echo')
'ws://localhost:%d/echo' % self.get_http_port())
ws.close(1001, 'goodbye') ws.close(1001, 'goodbye')
code, reason = yield self.close_future code, reason = yield self.close_future
self.assertEqual(code, 1001) self.assertEqual(code, 1001)
self.assertEqual(reason, 'goodbye') self.assertEqual(reason, 'goodbye')
@gen_test
def test_async_prepare(self):
# Previously, an async prepare method triggered a bug that would
# result in a timeout on test shutdown (and a memory leak).
ws = yield self.ws_connect('/async_prepare')
ws.write_message('hello')
res = yield ws.read_message()
self.assertEqual(res, 'hello')
@gen_test @gen_test
def test_check_origin_valid_no_path(self): def test_check_origin_valid_no_path(self):
port = self.get_http_port() port = self.get_http_port()
@ -184,8 +251,7 @@ class WebSocketTest(AsyncHTTPTestCase):
ws.write_message('hello') ws.write_message('hello')
response = yield ws.read_message() response = yield ws.read_message()
self.assertEqual(response, 'hello') self.assertEqual(response, 'hello')
ws.close() yield self.close(ws)
yield self.close_future
@gen_test @gen_test
def test_check_origin_valid_with_path(self): def test_check_origin_valid_with_path(self):
@ -199,8 +265,7 @@ class WebSocketTest(AsyncHTTPTestCase):
ws.write_message('hello') ws.write_message('hello')
response = yield ws.read_message() response = yield ws.read_message()
self.assertEqual(response, 'hello') self.assertEqual(response, 'hello')
ws.close() yield self.close(ws)
yield self.close_future
@gen_test @gen_test
def test_check_origin_invalid_partial_url(self): def test_check_origin_invalid_partial_url(self):
@ -245,6 +310,78 @@ class WebSocketTest(AsyncHTTPTestCase):
self.assertEqual(cm.exception.code, 403) self.assertEqual(cm.exception.code, 403)
class CompressionTestMixin(object):
MESSAGE = 'Hello world. Testing 123 123'
def get_app(self):
self.close_future = Future()
return Application([
('/echo', EchoHandler, dict(
close_future=self.close_future,
compression_options=self.get_server_compression_options())),
])
def get_server_compression_options(self):
return None
def get_client_compression_options(self):
return None
@gen_test
def test_message_sizes(self):
ws = yield self.ws_connect(
'/echo',
compression_options=self.get_client_compression_options())
# Send the same message three times so we can measure the
# effect of the context_takeover options.
for i in range(3):
ws.write_message(self.MESSAGE)
response = yield ws.read_message()
self.assertEqual(response, self.MESSAGE)
self.assertEqual(ws.protocol._message_bytes_out, len(self.MESSAGE) * 3)
self.assertEqual(ws.protocol._message_bytes_in, len(self.MESSAGE) * 3)
self.verify_wire_bytes(ws.protocol._wire_bytes_in,
ws.protocol._wire_bytes_out)
yield self.close(ws)
class UncompressedTestMixin(CompressionTestMixin):
"""Specialization of CompressionTestMixin when we expect no compression."""
def verify_wire_bytes(self, bytes_in, bytes_out):
# Bytes out includes the 4-byte mask key per message.
self.assertEqual(bytes_out, 3 * (len(self.MESSAGE) + 6))
self.assertEqual(bytes_in, 3 * (len(self.MESSAGE) + 2))
class NoCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
pass
# If only one side tries to compress, the extension is not negotiated.
class ServerOnlyCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
def get_server_compression_options(self):
return {}
class ClientOnlyCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
def get_client_compression_options(self):
return {}
class DefaultCompressionTest(CompressionTestMixin, WebSocketBaseTestCase):
def get_server_compression_options(self):
return {}
def get_client_compression_options(self):
return {}
def verify_wire_bytes(self, bytes_in, bytes_out):
self.assertLess(bytes_out, 3 * (len(self.MESSAGE) + 6))
self.assertLess(bytes_in, 3 * (len(self.MESSAGE) + 2))
# Bytes out includes the 4 bytes mask key per message.
self.assertEqual(bytes_out, bytes_in + 12)
class MaskFunctionMixin(object): class MaskFunctionMixin(object):
# Subclasses should define self.mask(mask, data) # Subclasses should define self.mask(mask, data)
def test_mask(self): def test_mask(self):

View File

@ -28,7 +28,7 @@ except ImportError:
IOLoop = None IOLoop = None
netutil = None netutil = None
SimpleAsyncHTTPClient = None SimpleAsyncHTTPClient = None
from tornado.log import gen_log from tornado.log import gen_log, app_log
from tornado.stack_context import ExceptionStackContext from tornado.stack_context import ExceptionStackContext
from tornado.util import raise_exc_info, basestring_type from tornado.util import raise_exc_info, basestring_type
import functools import functools
@ -114,8 +114,8 @@ class _TestMethodWrapper(object):
def __init__(self, orig_method): def __init__(self, orig_method):
self.orig_method = orig_method self.orig_method = orig_method
def __call__(self): def __call__(self, *args, **kwargs):
result = self.orig_method() result = self.orig_method(*args, **kwargs)
if isinstance(result, types.GeneratorType): if isinstance(result, types.GeneratorType):
raise TypeError("Generator test methods should be decorated with " raise TypeError("Generator test methods should be decorated with "
"tornado.testing.gen_test") "tornado.testing.gen_test")
@ -237,7 +237,11 @@ class AsyncTestCase(unittest.TestCase):
return IOLoop() return IOLoop()
def _handle_exception(self, typ, value, tb): def _handle_exception(self, typ, value, tb):
self.__failure = (typ, value, tb) if self.__failure is None:
self.__failure = (typ, value, tb)
else:
app_log.error("multiple unhandled exceptions in test",
exc_info=(typ, value, tb))
self.stop() self.stop()
return True return True
@ -395,7 +399,8 @@ class AsyncHTTPTestCase(AsyncTestCase):
def tearDown(self): def tearDown(self):
self.http_server.stop() self.http_server.stop()
self.io_loop.run_sync(self.http_server.close_all_connections) self.io_loop.run_sync(self.http_server.close_all_connections,
timeout=get_async_test_timeout())
if (not IOLoop.initialized() or if (not IOLoop.initialized() or
self.http_client.io_loop is not IOLoop.instance()): self.http_client.io_loop is not IOLoop.instance()):
self.http_client.close() self.http_client.close()

View File

@ -115,16 +115,17 @@ def import_object(name):
if type('') is not type(b''): if type('') is not type(b''):
def u(s): def u(s):
return s return s
bytes_type = bytes
unicode_type = str unicode_type = str
basestring_type = str basestring_type = str
else: else:
def u(s): def u(s):
return s.decode('unicode_escape') return s.decode('unicode_escape')
bytes_type = str
unicode_type = unicode unicode_type = unicode
basestring_type = basestring basestring_type = basestring
# Deprecated alias that was used before we dropped py25 support.
# Left here in case anyone outside Tornado is using it.
bytes_type = bytes
if sys.version_info > (3,): if sys.version_info > (3,):
exec(""" exec("""
@ -154,7 +155,7 @@ def errno_from_exception(e):
"""Provides the errno from an Exception object. """Provides the errno from an Exception object.
There are cases that the errno attribute was not set so we pull There are cases that the errno attribute was not set so we pull
the errno out of the args but if someone instatiates an Exception the errno out of the args but if someone instantiates an Exception
without any args you will get a tuple error. So this function without any args you will get a tuple error. So this function
abstracts all that behavior to give you a safe way to get the abstracts all that behavior to give you a safe way to get the
errno. errno.
@ -202,7 +203,7 @@ class Configurable(object):
impl = cls impl = cls
args.update(kwargs) args.update(kwargs)
instance = super(Configurable, cls).__new__(impl) instance = super(Configurable, cls).__new__(impl)
# initialize vs __init__ chosen for compatiblity with AsyncHTTPClient # initialize vs __init__ chosen for compatibility with AsyncHTTPClient
# singleton magic. If we get rid of that we can switch to __init__ # singleton magic. If we get rid of that we can switch to __init__
# here too. # here too.
instance.initialize(**args) instance.initialize(**args)
@ -237,7 +238,7 @@ class Configurable(object):
some parameters. some parameters.
""" """
base = cls.configurable_base() base = cls.configurable_base()
if isinstance(impl, (unicode_type, bytes_type)): if isinstance(impl, (unicode_type, bytes)):
impl = import_object(impl) impl = import_object(impl)
if impl is not None and not issubclass(impl, cls): if impl is not None and not issubclass(impl, cls):
raise ValueError("Invalid subclass of %s" % cls) raise ValueError("Invalid subclass of %s" % cls)

View File

@ -35,8 +35,7 @@ Here is a simple "Hello, world" example app::
application.listen(8888) application.listen(8888)
tornado.ioloop.IOLoop.instance().start() tornado.ioloop.IOLoop.instance().start()
See the :doc:`Tornado overview <overview>` for more details and a good getting See the :doc:`guide` for additional information.
started guide.
Thread-safety notes Thread-safety notes
------------------- -------------------
@ -48,6 +47,7 @@ not thread-safe. In particular, methods such as
you use multiple threads it is important to use `.IOLoop.add_callback` you use multiple threads it is important to use `.IOLoop.add_callback`
to transfer control back to the main thread before finishing the to transfer control back to the main thread before finishing the
request. request.
""" """
from __future__ import absolute_import, division, print_function, with_statement from __future__ import absolute_import, division, print_function, with_statement
@ -72,6 +72,7 @@ import time
import tornado import tornado
import traceback import traceback
import types import types
from io import BytesIO
from tornado.concurrent import Future, is_future from tornado.concurrent import Future, is_future
from tornado import escape from tornado import escape
@ -83,12 +84,8 @@ from tornado.log import access_log, app_log, gen_log
from tornado import stack_context from tornado import stack_context
from tornado import template from tornado import template
from tornado.escape import utf8, _unicode from tornado.escape import utf8, _unicode
from tornado.util import bytes_type, import_object, ObjectDict, raise_exc_info, unicode_type, _websocket_mask from tornado.util import import_object, ObjectDict, raise_exc_info, unicode_type, _websocket_mask
try:
from io import BytesIO # python 3
except ImportError:
from cStringIO import StringIO as BytesIO # python 2
try: try:
import Cookie # py2 import Cookie # py2
@ -344,7 +341,7 @@ class RequestHandler(object):
_INVALID_HEADER_CHAR_RE = re.compile(br"[\x00-\x1f]") _INVALID_HEADER_CHAR_RE = re.compile(br"[\x00-\x1f]")
def _convert_header_value(self, value): def _convert_header_value(self, value):
if isinstance(value, bytes_type): if isinstance(value, bytes):
pass pass
elif isinstance(value, unicode_type): elif isinstance(value, unicode_type):
value = value.encode('utf-8') value = value.encode('utf-8')
@ -652,7 +649,7 @@ class RequestHandler(object):
raise RuntimeError("Cannot write() after finish(). May be caused " raise RuntimeError("Cannot write() after finish(). May be caused "
"by using async operations without the " "by using async operations without the "
"@asynchronous decorator.") "@asynchronous decorator.")
if not isinstance(chunk, (bytes_type, unicode_type, dict)): if not isinstance(chunk, (bytes, unicode_type, dict)):
raise TypeError("write() only accepts bytes, unicode, and dict objects") raise TypeError("write() only accepts bytes, unicode, and dict objects")
if isinstance(chunk, dict): if isinstance(chunk, dict):
chunk = escape.json_encode(chunk) chunk = escape.json_encode(chunk)
@ -677,7 +674,7 @@ class RequestHandler(object):
js_embed.append(utf8(embed_part)) js_embed.append(utf8(embed_part))
file_part = module.javascript_files() file_part = module.javascript_files()
if file_part: if file_part:
if isinstance(file_part, (unicode_type, bytes_type)): if isinstance(file_part, (unicode_type, bytes)):
js_files.append(file_part) js_files.append(file_part)
else: else:
js_files.extend(file_part) js_files.extend(file_part)
@ -686,7 +683,7 @@ class RequestHandler(object):
css_embed.append(utf8(embed_part)) css_embed.append(utf8(embed_part))
file_part = module.css_files() file_part = module.css_files()
if file_part: if file_part:
if isinstance(file_part, (unicode_type, bytes_type)): if isinstance(file_part, (unicode_type, bytes)):
css_files.append(file_part) css_files.append(file_part)
else: else:
css_files.extend(file_part) css_files.extend(file_part)
@ -919,7 +916,7 @@ class RequestHandler(object):
return return
self.clear() self.clear()
reason = None reason = kwargs.get('reason')
if 'exc_info' in kwargs: if 'exc_info' in kwargs:
exception = kwargs['exc_info'][1] exception = kwargs['exc_info'][1]
if isinstance(exception, HTTPError) and exception.reason: if isinstance(exception, HTTPError) and exception.reason:
@ -959,12 +956,15 @@ class RequestHandler(object):
@property @property
def locale(self): def locale(self):
"""The local for the current session. """The locale for the current session.
Determined by either `get_user_locale`, which you can override to Determined by either `get_user_locale`, which you can override to
set the locale based on, e.g., a user preference stored in a set the locale based on, e.g., a user preference stored in a
database, or `get_browser_locale`, which uses the ``Accept-Language`` database, or `get_browser_locale`, which uses the ``Accept-Language``
header. header.
.. versionchanged: 4.1
Added a property setter.
""" """
if not hasattr(self, "_locale"): if not hasattr(self, "_locale"):
self._locale = self.get_user_locale() self._locale = self.get_user_locale()
@ -973,6 +973,10 @@ class RequestHandler(object):
assert self._locale assert self._locale
return self._locale return self._locale
@locale.setter
def locale(self, value):
self._locale = value
def get_user_locale(self): def get_user_locale(self):
"""Override to determine the locale from the authenticated user. """Override to determine the locale from the authenticated user.
@ -1128,14 +1132,15 @@ class RequestHandler(object):
else: else:
# Treat unknown versions as not present instead of failing. # Treat unknown versions as not present instead of failing.
return None, None, None return None, None, None
elif len(cookie) == 32: else:
version = 1 version = 1
token = binascii.a2b_hex(utf8(cookie)) try:
token = binascii.a2b_hex(utf8(cookie))
except (binascii.Error, TypeError):
token = utf8(cookie)
# We don't have a usable timestamp in older versions. # We don't have a usable timestamp in older versions.
timestamp = int(time.time()) timestamp = int(time.time())
return (version, token, timestamp) return (version, token, timestamp)
else:
return None, None, None
def check_xsrf_cookie(self): def check_xsrf_cookie(self):
"""Verifies that the ``_xsrf`` cookie matches the ``_xsrf`` argument. """Verifies that the ``_xsrf`` cookie matches the ``_xsrf`` argument.
@ -1627,7 +1632,7 @@ class Application(httputil.HTTPServerConnectionDelegate):
**settings): **settings):
if transforms is None: if transforms is None:
self.transforms = [] self.transforms = []
if settings.get("gzip"): if settings.get("compress_response") or settings.get("gzip"):
self.transforms.append(GZipContentEncoding) self.transforms.append(GZipContentEncoding)
else: else:
self.transforms = transforms self.transforms = transforms
@ -2164,11 +2169,14 @@ class StaticFileHandler(RequestHandler):
if include_body: if include_body:
content = self.get_content(self.absolute_path, start, end) content = self.get_content(self.absolute_path, start, end)
if isinstance(content, bytes_type): if isinstance(content, bytes):
content = [content] content = [content]
for chunk in content: for chunk in content:
self.write(chunk) try:
yield self.flush() self.write(chunk)
yield self.flush()
except iostream.StreamClosedError:
return
else: else:
assert self.request.method == "HEAD" assert self.request.method == "HEAD"
@ -2335,7 +2343,7 @@ class StaticFileHandler(RequestHandler):
""" """
data = cls.get_content(abspath) data = cls.get_content(abspath)
hasher = hashlib.md5() hasher = hashlib.md5()
if isinstance(data, bytes_type): if isinstance(data, bytes):
hasher.update(data) hasher.update(data)
else: else:
for chunk in data: for chunk in data:
@ -2547,7 +2555,6 @@ class GZipContentEncoding(OutputTransform):
ctype = _unicode(headers.get("Content-Type", "")).split(";")[0] ctype = _unicode(headers.get("Content-Type", "")).split(";")[0]
self._gzipping = self._compressible_type(ctype) and \ self._gzipping = self._compressible_type(ctype) and \
(not finishing or len(chunk) >= self.MIN_LENGTH) and \ (not finishing or len(chunk) >= self.MIN_LENGTH) and \
(finishing or "Content-Length" not in headers) and \
("Content-Encoding" not in headers) ("Content-Encoding" not in headers)
if self._gzipping: if self._gzipping:
headers["Content-Encoding"] = "gzip" headers["Content-Encoding"] = "gzip"
@ -2555,7 +2562,14 @@ class GZipContentEncoding(OutputTransform):
self._gzip_file = gzip.GzipFile(mode="w", fileobj=self._gzip_value) self._gzip_file = gzip.GzipFile(mode="w", fileobj=self._gzip_value)
chunk = self.transform_chunk(chunk, finishing) chunk = self.transform_chunk(chunk, finishing)
if "Content-Length" in headers: if "Content-Length" in headers:
headers["Content-Length"] = str(len(chunk)) # The original content length is no longer correct.
# If this is the last (and only) chunk, we can set the new
# content-length; otherwise we remove it and fall back to
# chunked encoding.
if finishing:
headers["Content-Length"] = str(len(chunk))
else:
del headers["Content-Length"]
return status_code, headers, chunk return status_code, headers, chunk
def transform_chunk(self, chunk, finishing): def transform_chunk(self, chunk, finishing):
@ -2704,7 +2718,7 @@ class TemplateModule(UIModule):
def javascript_files(self): def javascript_files(self):
result = [] result = []
for f in self._get_resources("javascript_files"): for f in self._get_resources("javascript_files"):
if isinstance(f, (unicode_type, bytes_type)): if isinstance(f, (unicode_type, bytes)):
result.append(f) result.append(f)
else: else:
result.extend(f) result.extend(f)
@ -2716,7 +2730,7 @@ class TemplateModule(UIModule):
def css_files(self): def css_files(self):
result = [] result = []
for f in self._get_resources("css_files"): for f in self._get_resources("css_files"):
if isinstance(f, (unicode_type, bytes_type)): if isinstance(f, (unicode_type, bytes)):
result.append(f) result.append(f)
else: else:
result.extend(f) result.extend(f)
@ -2754,7 +2768,7 @@ class URLSpec(object):
in the regex will be passed in to the handler's get/post/etc in the regex will be passed in to the handler's get/post/etc
methods as arguments. methods as arguments.
* ``handler_class``: `RequestHandler` subclass to be invoked. * ``handler``: `RequestHandler` subclass to be invoked.
* ``kwargs`` (optional): A dictionary of additional arguments * ``kwargs`` (optional): A dictionary of additional arguments
to be passed to the handler's constructor. to be passed to the handler's constructor.
@ -2821,7 +2835,7 @@ class URLSpec(object):
return self._path return self._path
converted_args = [] converted_args = []
for a in args: for a in args:
if not isinstance(a, (unicode_type, bytes_type)): if not isinstance(a, (unicode_type, bytes)):
a = str(a) a = str(a)
converted_args.append(escape.url_escape(utf8(a), plus=False)) converted_args.append(escape.url_escape(utf8(a), plus=False))
return self._path % tuple(converted_args) return self._path % tuple(converted_args)

View File

@ -26,6 +26,7 @@ import os
import struct import struct
import tornado.escape import tornado.escape
import tornado.web import tornado.web
import zlib
from tornado.concurrent import TracebackFuture from tornado.concurrent import TracebackFuture
from tornado.escape import utf8, native_str, to_unicode from tornado.escape import utf8, native_str, to_unicode
@ -35,7 +36,7 @@ from tornado.iostream import StreamClosedError
from tornado.log import gen_log, app_log from tornado.log import gen_log, app_log
from tornado import simple_httpclient from tornado import simple_httpclient
from tornado.tcpclient import TCPClient from tornado.tcpclient import TCPClient
from tornado.util import bytes_type, _websocket_mask from tornado.util import _websocket_mask
try: try:
from urllib.parse import urlparse # py2 from urllib.parse import urlparse # py2
@ -105,6 +106,21 @@ class WebSocketHandler(tornado.web.RequestHandler):
}; };
This script pops up an alert box that says "You said: Hello, world". This script pops up an alert box that says "You said: Hello, world".
Web browsers allow any site to open a websocket connection to any other,
instead of using the same-origin policy that governs other network
access from javascript. This can be surprising and is a potential
security hole, so since Tornado 4.0 `WebSocketHandler` requires
applications that wish to receive cross-origin websockets to opt in
by overriding the `~WebSocketHandler.check_origin` method (see that
method's docs for details). Failure to do so is the most likely
cause of 403 errors when making a websocket connection.
When using a secure websocket connection (``wss://``) with a self-signed
certificate, the connection from a browser may fail because it wants
to show the "accept this certificate" dialog but has nowhere to show it.
You must first visit a regular HTML page using the same certificate
to accept it before the websocket connection will succeed.
""" """
def __init__(self, application, request, **kwargs): def __init__(self, application, request, **kwargs):
tornado.web.RequestHandler.__init__(self, application, request, tornado.web.RequestHandler.__init__(self, application, request,
@ -156,13 +172,15 @@ class WebSocketHandler(tornado.web.RequestHandler):
self.stream.set_close_callback(self.on_connection_close) self.stream.set_close_callback(self.on_connection_close)
if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"): if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"):
self.ws_connection = WebSocketProtocol13(self) self.ws_connection = WebSocketProtocol13(
self, compression_options=self.get_compression_options())
self.ws_connection.accept_connection() self.ws_connection.accept_connection()
else: else:
self.stream.write(tornado.escape.utf8( if not self.stream.closed():
"HTTP/1.1 426 Upgrade Required\r\n" self.stream.write(tornado.escape.utf8(
"Sec-WebSocket-Version: 8\r\n\r\n")) "HTTP/1.1 426 Upgrade Required\r\n"
self.stream.close() "Sec-WebSocket-Version: 8\r\n\r\n"))
self.stream.close()
def write_message(self, message, binary=False): def write_message(self, message, binary=False):
@ -198,6 +216,19 @@ class WebSocketHandler(tornado.web.RequestHandler):
""" """
return None return None
def get_compression_options(self):
"""Override to return compression options for the connection.
If this method returns None (the default), compression will
be disabled. If it returns a dict (even an empty one), it
will be enabled. The contents of the dict may be used to
control the memory and CPU usage of the compression,
but no such options are currently implemented.
.. versionadded:: 4.1
"""
return None
def open(self): def open(self):
"""Invoked when a new WebSocket is opened. """Invoked when a new WebSocket is opened.
@ -275,6 +306,19 @@ class WebSocketHandler(tornado.web.RequestHandler):
browsers, since WebSockets are allowed to bypass the usual same-origin browsers, since WebSockets are allowed to bypass the usual same-origin
policies and don't use CORS headers. policies and don't use CORS headers.
To accept all cross-origin traffic (which was the default prior to
Tornado 4.0), simply override this method to always return true::
def check_origin(self, origin):
return True
To allow connections from any subdomain of your site, you might
do something like::
def check_origin(self, origin):
parsed_origin = urllib.parse.urlparse(origin)
return parsed_origin.netloc.endswith(".mydomain.com")
.. versionadded:: 4.0 .. versionadded:: 4.0
""" """
parsed_origin = urlparse(origin) parsed_origin = urlparse(origin)
@ -308,6 +352,15 @@ class WebSocketHandler(tornado.web.RequestHandler):
self.ws_connection = None self.ws_connection = None
self.on_close() self.on_close()
def send_error(self, *args, **kwargs):
if self.stream is None:
super(WebSocketHandler, self).send_error(*args, **kwargs)
else:
# If we get an uncaught exception during the handshake,
# we have no choice but to abruptly close the connection.
# TODO: for uncaught exceptions after the handshake,
# we can close the connection more gracefully.
self.stream.close()
def _wrap_method(method): def _wrap_method(method):
def _disallow_for_websocket(self, *args, **kwargs): def _disallow_for_websocket(self, *args, **kwargs):
@ -316,7 +369,7 @@ def _wrap_method(method):
else: else:
raise RuntimeError("Method not supported for Web Sockets") raise RuntimeError("Method not supported for Web Sockets")
return _disallow_for_websocket return _disallow_for_websocket
for method in ["write", "redirect", "set_header", "send_error", "set_cookie", for method in ["write", "redirect", "set_header", "set_cookie",
"set_status", "flush", "finish"]: "set_status", "flush", "finish"]:
setattr(WebSocketHandler, method, setattr(WebSocketHandler, method,
_wrap_method(getattr(WebSocketHandler, method))) _wrap_method(getattr(WebSocketHandler, method)))
@ -355,13 +408,68 @@ class WebSocketProtocol(object):
self.close() # let the subclass cleanup self.close() # let the subclass cleanup
class _PerMessageDeflateCompressor(object):
def __init__(self, persistent, max_wbits):
if max_wbits is None:
max_wbits = zlib.MAX_WBITS
# There is no symbolic constant for the minimum wbits value.
if not (8 <= max_wbits <= zlib.MAX_WBITS):
raise ValueError("Invalid max_wbits value %r; allowed range 8-%d",
max_wbits, zlib.MAX_WBITS)
self._max_wbits = max_wbits
if persistent:
self._compressor = self._create_compressor()
else:
self._compressor = None
def _create_compressor(self):
return zlib.compressobj(-1, zlib.DEFLATED, -self._max_wbits)
def compress(self, data):
compressor = self._compressor or self._create_compressor()
data = (compressor.compress(data) +
compressor.flush(zlib.Z_SYNC_FLUSH))
assert data.endswith(b'\x00\x00\xff\xff')
return data[:-4]
class _PerMessageDeflateDecompressor(object):
def __init__(self, persistent, max_wbits):
if max_wbits is None:
max_wbits = zlib.MAX_WBITS
if not (8 <= max_wbits <= zlib.MAX_WBITS):
raise ValueError("Invalid max_wbits value %r; allowed range 8-%d",
max_wbits, zlib.MAX_WBITS)
self._max_wbits = max_wbits
if persistent:
self._decompressor = self._create_decompressor()
else:
self._decompressor = None
def _create_decompressor(self):
return zlib.decompressobj(-self._max_wbits)
def decompress(self, data):
decompressor = self._decompressor or self._create_decompressor()
return decompressor.decompress(data + b'\x00\x00\xff\xff')
class WebSocketProtocol13(WebSocketProtocol): class WebSocketProtocol13(WebSocketProtocol):
"""Implementation of the WebSocket protocol from RFC 6455. """Implementation of the WebSocket protocol from RFC 6455.
This class supports versions 7 and 8 of the protocol in addition to the This class supports versions 7 and 8 of the protocol in addition to the
final version 13. final version 13.
""" """
def __init__(self, handler, mask_outgoing=False): # Bit masks for the first byte of a frame.
FIN = 0x80
RSV1 = 0x40
RSV2 = 0x20
RSV3 = 0x10
RSV_MASK = RSV1 | RSV2 | RSV3
OPCODE_MASK = 0x0f
def __init__(self, handler, mask_outgoing=False,
compression_options=None):
WebSocketProtocol.__init__(self, handler) WebSocketProtocol.__init__(self, handler)
self.mask_outgoing = mask_outgoing self.mask_outgoing = mask_outgoing
self._final_frame = False self._final_frame = False
@ -372,6 +480,19 @@ class WebSocketProtocol13(WebSocketProtocol):
self._fragmented_message_buffer = None self._fragmented_message_buffer = None
self._fragmented_message_opcode = None self._fragmented_message_opcode = None
self._waiting = None self._waiting = None
self._compression_options = compression_options
self._decompressor = None
self._compressor = None
self._frame_compressed = None
# The total uncompressed size of all messages received or sent.
# Unicode messages are encoded to utf8.
# Only for testing; subject to change.
self._message_bytes_in = 0
self._message_bytes_out = 0
# The total size of all packets received or sent. Includes
# the effect of compression, frame overhead, and control frames.
self._wire_bytes_in = 0
self._wire_bytes_out = 0
def accept_connection(self): def accept_connection(self):
try: try:
@ -416,24 +537,99 @@ class WebSocketProtocol13(WebSocketProtocol):
assert selected in subprotocols assert selected in subprotocols
subprotocol_header = "Sec-WebSocket-Protocol: %s\r\n" % selected subprotocol_header = "Sec-WebSocket-Protocol: %s\r\n" % selected
extension_header = ''
extensions = self._parse_extensions_header(self.request.headers)
for ext in extensions:
if (ext[0] == 'permessage-deflate' and
self._compression_options is not None):
# TODO: negotiate parameters if compression_options
# specifies limits.
self._create_compressors('server', ext[1])
if ('client_max_window_bits' in ext[1] and
ext[1]['client_max_window_bits'] is None):
# Don't echo an offered client_max_window_bits
# parameter with no value.
del ext[1]['client_max_window_bits']
extension_header = ('Sec-WebSocket-Extensions: %s\r\n' %
httputil._encode_header(
'permessage-deflate', ext[1]))
break
if self.stream.closed():
self._abort()
return
self.stream.write(tornado.escape.utf8( self.stream.write(tornado.escape.utf8(
"HTTP/1.1 101 Switching Protocols\r\n" "HTTP/1.1 101 Switching Protocols\r\n"
"Upgrade: websocket\r\n" "Upgrade: websocket\r\n"
"Connection: Upgrade\r\n" "Connection: Upgrade\r\n"
"Sec-WebSocket-Accept: %s\r\n" "Sec-WebSocket-Accept: %s\r\n"
"%s" "%s%s"
"\r\n" % (self._challenge_response(), subprotocol_header))) "\r\n" % (self._challenge_response(),
subprotocol_header, extension_header)))
self._run_callback(self.handler.open, *self.handler.open_args, self._run_callback(self.handler.open, *self.handler.open_args,
**self.handler.open_kwargs) **self.handler.open_kwargs)
self._receive_frame() self._receive_frame()
def _write_frame(self, fin, opcode, data): def _parse_extensions_header(self, headers):
extensions = headers.get("Sec-WebSocket-Extensions", '')
if extensions:
return [httputil._parse_header(e.strip())
for e in extensions.split(',')]
return []
def _process_server_headers(self, key, headers):
"""Process the headers sent by the server to this client connection.
'key' is the websocket handshake challenge/response key.
"""
assert headers['Upgrade'].lower() == 'websocket'
assert headers['Connection'].lower() == 'upgrade'
accept = self.compute_accept_value(key)
assert headers['Sec-Websocket-Accept'] == accept
extensions = self._parse_extensions_header(headers)
for ext in extensions:
if (ext[0] == 'permessage-deflate' and
self._compression_options is not None):
self._create_compressors('client', ext[1])
else:
raise ValueError("unsupported extension %r", ext)
def _get_compressor_options(self, side, agreed_parameters):
"""Converts a websocket agreed_parameters set to keyword arguments
for our compressor objects.
"""
options = dict(
persistent=(side + '_no_context_takeover') not in agreed_parameters)
wbits_header = agreed_parameters.get(side + '_max_window_bits', None)
if wbits_header is None:
options['max_wbits'] = zlib.MAX_WBITS
else:
options['max_wbits'] = int(wbits_header)
return options
def _create_compressors(self, side, agreed_parameters):
# TODO: handle invalid parameters gracefully
allowed_keys = set(['server_no_context_takeover',
'client_no_context_takeover',
'server_max_window_bits',
'client_max_window_bits'])
for key in agreed_parameters:
if key not in allowed_keys:
raise ValueError("unsupported compression parameter %r" % key)
other_side = 'client' if (side == 'server') else 'server'
self._compressor = _PerMessageDeflateCompressor(
**self._get_compressor_options(side, agreed_parameters))
self._decompressor = _PerMessageDeflateDecompressor(
**self._get_compressor_options(other_side, agreed_parameters))
def _write_frame(self, fin, opcode, data, flags=0):
if fin: if fin:
finbit = 0x80 finbit = self.FIN
else: else:
finbit = 0 finbit = 0
frame = struct.pack("B", finbit | opcode) frame = struct.pack("B", finbit | opcode | flags)
l = len(data) l = len(data)
if self.mask_outgoing: if self.mask_outgoing:
mask_bit = 0x80 mask_bit = 0x80
@ -449,7 +645,11 @@ class WebSocketProtocol13(WebSocketProtocol):
mask = os.urandom(4) mask = os.urandom(4)
data = mask + _websocket_mask(mask, data) data = mask + _websocket_mask(mask, data)
frame += data frame += data
self.stream.write(frame) self._wire_bytes_out += len(frame)
try:
self.stream.write(frame)
except StreamClosedError:
self._abort()
def write_message(self, message, binary=False): def write_message(self, message, binary=False):
"""Sends the given message to the client of this Web Socket.""" """Sends the given message to the client of this Web Socket."""
@ -458,15 +658,17 @@ class WebSocketProtocol13(WebSocketProtocol):
else: else:
opcode = 0x1 opcode = 0x1
message = tornado.escape.utf8(message) message = tornado.escape.utf8(message)
assert isinstance(message, bytes_type) assert isinstance(message, bytes)
try: self._message_bytes_out += len(message)
self._write_frame(True, opcode, message) flags = 0
except StreamClosedError: if self._compressor:
self._abort() message = self._compressor.compress(message)
flags |= self.RSV1
self._write_frame(True, opcode, message, flags=flags)
def write_ping(self, data): def write_ping(self, data):
"""Send ping frame.""" """Send ping frame."""
assert isinstance(data, bytes_type) assert isinstance(data, bytes)
self._write_frame(True, 0x9, data) self._write_frame(True, 0x9, data)
def _receive_frame(self): def _receive_frame(self):
@ -476,11 +678,15 @@ class WebSocketProtocol13(WebSocketProtocol):
self._abort() self._abort()
def _on_frame_start(self, data): def _on_frame_start(self, data):
self._wire_bytes_in += len(data)
header, payloadlen = struct.unpack("BB", data) header, payloadlen = struct.unpack("BB", data)
self._final_frame = header & 0x80 self._final_frame = header & self.FIN
reserved_bits = header & 0x70 reserved_bits = header & self.RSV_MASK
self._frame_opcode = header & 0xf self._frame_opcode = header & self.OPCODE_MASK
self._frame_opcode_is_control = self._frame_opcode & 0x8 self._frame_opcode_is_control = self._frame_opcode & 0x8
if self._decompressor is not None:
self._frame_compressed = bool(reserved_bits & self.RSV1)
reserved_bits &= ~self.RSV1
if reserved_bits: if reserved_bits:
# client is using as-yet-undefined extensions; abort # client is using as-yet-undefined extensions; abort
self._abort() self._abort()
@ -506,6 +712,7 @@ class WebSocketProtocol13(WebSocketProtocol):
self._abort() self._abort()
def _on_frame_length_16(self, data): def _on_frame_length_16(self, data):
self._wire_bytes_in += len(data)
self._frame_length = struct.unpack("!H", data)[0] self._frame_length = struct.unpack("!H", data)[0]
try: try:
if self._masked_frame: if self._masked_frame:
@ -516,6 +723,7 @@ class WebSocketProtocol13(WebSocketProtocol):
self._abort() self._abort()
def _on_frame_length_64(self, data): def _on_frame_length_64(self, data):
self._wire_bytes_in += len(data)
self._frame_length = struct.unpack("!Q", data)[0] self._frame_length = struct.unpack("!Q", data)[0]
try: try:
if self._masked_frame: if self._masked_frame:
@ -526,6 +734,7 @@ class WebSocketProtocol13(WebSocketProtocol):
self._abort() self._abort()
def _on_masking_key(self, data): def _on_masking_key(self, data):
self._wire_bytes_in += len(data)
self._frame_mask = data self._frame_mask = data
try: try:
self.stream.read_bytes(self._frame_length, self._on_masked_frame_data) self.stream.read_bytes(self._frame_length, self._on_masked_frame_data)
@ -533,9 +742,11 @@ class WebSocketProtocol13(WebSocketProtocol):
self._abort() self._abort()
def _on_masked_frame_data(self, data): def _on_masked_frame_data(self, data):
# Don't touch _wire_bytes_in; we'll do it in _on_frame_data.
self._on_frame_data(_websocket_mask(self._frame_mask, data)) self._on_frame_data(_websocket_mask(self._frame_mask, data))
def _on_frame_data(self, data): def _on_frame_data(self, data):
self._wire_bytes_in += len(data)
if self._frame_opcode_is_control: if self._frame_opcode_is_control:
# control frames may be interleaved with a series of fragmented # control frames may be interleaved with a series of fragmented
# data frames, so control frames must not interact with # data frames, so control frames must not interact with
@ -576,8 +787,12 @@ class WebSocketProtocol13(WebSocketProtocol):
if self.client_terminated: if self.client_terminated:
return return
if self._frame_compressed:
data = self._decompressor.decompress(data)
if opcode == 0x1: if opcode == 0x1:
# UTF-8 data # UTF-8 data
self._message_bytes_in += len(data)
try: try:
decoded = data.decode("utf-8") decoded = data.decode("utf-8")
except UnicodeDecodeError: except UnicodeDecodeError:
@ -586,7 +801,8 @@ class WebSocketProtocol13(WebSocketProtocol):
self._run_callback(self.handler.on_message, decoded) self._run_callback(self.handler.on_message, decoded)
elif opcode == 0x2: elif opcode == 0x2:
# Binary data # Binary data
self._run_callback(self.handler.on_message, decoded) self._message_bytes_in += len(data)
self._run_callback(self.handler.on_message, data)
elif opcode == 0x8: elif opcode == 0x8:
# Close # Close
self.client_terminated = True self.client_terminated = True
@ -636,7 +852,8 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
This class should not be instantiated directly; use the This class should not be instantiated directly; use the
`websocket_connect` function instead. `websocket_connect` function instead.
""" """
def __init__(self, io_loop, request): def __init__(self, io_loop, request, compression_options=None):
self.compression_options = compression_options
self.connect_future = TracebackFuture() self.connect_future = TracebackFuture()
self.read_future = None self.read_future = None
self.read_queue = collections.deque() self.read_queue = collections.deque()
@ -651,6 +868,14 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
'Sec-WebSocket-Key': self.key, 'Sec-WebSocket-Key': self.key,
'Sec-WebSocket-Version': '13', 'Sec-WebSocket-Version': '13',
}) })
if self.compression_options is not None:
# Always offer to let the server set our max_wbits (and even though
# we don't offer it, we will accept a client_no_context_takeover
# from the server).
# TODO: set server parameters for deflate extension
# if requested in self.compression_options.
request.headers['Sec-WebSocket-Extensions'] = (
'permessage-deflate; client_max_window_bits')
self.tcp_client = TCPClient(io_loop=io_loop) self.tcp_client = TCPClient(io_loop=io_loop)
super(WebSocketClientConnection, self).__init__( super(WebSocketClientConnection, self).__init__(
@ -673,10 +898,12 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
self.protocol.close(code, reason) self.protocol.close(code, reason)
self.protocol = None self.protocol = None
def _on_close(self): def on_connection_close(self):
if not self.connect_future.done():
self.connect_future.set_exception(StreamClosedError())
self.on_message(None) self.on_message(None)
self.resolver.close() self.tcp_client.close()
super(WebSocketClientConnection, self)._on_close() super(WebSocketClientConnection, self).on_connection_close()
def _on_http_response(self, response): def _on_http_response(self, response):
if not self.connect_future.done(): if not self.connect_future.done():
@ -692,12 +919,10 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
start_line, headers) start_line, headers)
self.headers = headers self.headers = headers
assert self.headers['Upgrade'].lower() == 'websocket' self.protocol = WebSocketProtocol13(
assert self.headers['Connection'].lower() == 'upgrade' self, mask_outgoing=True,
accept = WebSocketProtocol13.compute_accept_value(self.key) compression_options=self.compression_options)
assert self.headers['Sec-Websocket-Accept'] == accept self.protocol._process_server_headers(self.key, self.headers)
self.protocol = WebSocketProtocol13(self, mask_outgoing=True)
self.protocol._receive_frame() self.protocol._receive_frame()
if self._timeout is not None: if self._timeout is not None:
@ -705,7 +930,12 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
self._timeout = None self._timeout = None
self.stream = self.connection.detach() self.stream = self.connection.detach()
self.stream.set_close_callback(self._on_close) self.stream.set_close_callback(self.on_connection_close)
# Once we've taken over the connection, clear the final callback
# we set on the http request. This deactivates the error handling
# in simple_httpclient that would otherwise interfere with our
# ability to see exceptions.
self.final_callback = None
self.connect_future.set_result(self) self.connect_future.set_result(self)
@ -742,14 +972,21 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
pass pass
def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None): def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None,
compression_options=None):
"""Client-side websocket support. """Client-side websocket support.
Takes a url and returns a Future whose result is a Takes a url and returns a Future whose result is a
`WebSocketClientConnection`. `WebSocketClientConnection`.
``compression_options`` is interpreted in the same way as the
return value of `.WebSocketHandler.get_compression_options`.
.. versionchanged:: 3.2 .. versionchanged:: 3.2
Also accepts ``HTTPRequest`` objects in place of urls. Also accepts ``HTTPRequest`` objects in place of urls.
.. versionchanged:: 4.1
Added ``compression_options``.
""" """
if io_loop is None: if io_loop is None:
io_loop = IOLoop.current() io_loop = IOLoop.current()
@ -763,7 +1000,7 @@ def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None):
request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout) request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout)
request = httpclient._RequestProxy( request = httpclient._RequestProxy(
request, httpclient.HTTPRequest._DEFAULTS) request, httpclient.HTTPRequest._DEFAULTS)
conn = WebSocketClientConnection(io_loop, request) conn = WebSocketClientConnection(io_loop, request, compression_options)
if callback is not None: if callback is not None:
io_loop.add_future(conn.connect_future, callback) io_loop.add_future(conn.connect_future, callback)
return conn.connect_future return conn.connect_future

View File

@ -32,6 +32,7 @@ provides WSGI support in two ways:
from __future__ import absolute_import, division, print_function, with_statement from __future__ import absolute_import, division, print_function, with_statement
import sys import sys
from io import BytesIO
import tornado import tornado
from tornado.concurrent import Future from tornado.concurrent import Future
@ -40,12 +41,8 @@ from tornado import httputil
from tornado.log import access_log from tornado.log import access_log
from tornado import web from tornado import web
from tornado.escape import native_str from tornado.escape import native_str
from tornado.util import bytes_type, unicode_type from tornado.util import unicode_type
try:
from io import BytesIO # python 3
except ImportError:
from cStringIO import StringIO as BytesIO # python 2
try: try:
import urllib.parse as urllib_parse # py3 import urllib.parse as urllib_parse # py3
@ -58,7 +55,7 @@ except ImportError:
# here to minimize the temptation to use them in non-wsgi contexts. # here to minimize the temptation to use them in non-wsgi contexts.
if str is unicode_type: if str is unicode_type:
def to_wsgi_str(s): def to_wsgi_str(s):
assert isinstance(s, bytes_type) assert isinstance(s, bytes)
return s.decode('latin1') return s.decode('latin1')
def from_wsgi_str(s): def from_wsgi_str(s):
@ -66,7 +63,7 @@ if str is unicode_type:
return s.encode('latin1') return s.encode('latin1')
else: else:
def to_wsgi_str(s): def to_wsgi_str(s):
assert isinstance(s, bytes_type) assert isinstance(s, bytes)
return s return s
def from_wsgi_str(s): def from_wsgi_str(s):