2014-06-11 04:34:28 -04:00
|
|
|
from __future__ import absolute_import, division, print_function, with_statement
|
|
|
|
|
|
|
|
import traceback
|
|
|
|
|
|
|
|
from tornado.concurrent import Future
|
|
|
|
from tornado.httpclient import HTTPError, HTTPRequest
|
|
|
|
from tornado.log import gen_log
|
|
|
|
from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
|
2014-06-17 00:54:00 -04:00
|
|
|
from tornado.test.util import unittest
|
2014-06-11 04:34:28 -04:00
|
|
|
from tornado.web import Application, RequestHandler
|
|
|
|
|
|
|
|
try:
|
|
|
|
import tornado.websocket
|
|
|
|
from tornado.util import _websocket_mask_python
|
|
|
|
except ImportError:
|
|
|
|
# The unittest module presents misleading errors on ImportError
|
|
|
|
# (it acts as if websocket_test could not be found, hiding the underlying
|
|
|
|
# error). If we get an ImportError here (which could happen due to
|
|
|
|
# TORNADO_EXTENSION=1), print some extra information before failing.
|
|
|
|
traceback.print_exc()
|
|
|
|
raise
|
|
|
|
|
|
|
|
from tornado.websocket import WebSocketHandler, websocket_connect, WebSocketError
|
|
|
|
|
|
|
|
try:
|
|
|
|
from tornado import speedups
|
|
|
|
except ImportError:
|
|
|
|
speedups = None
|
|
|
|
|
|
|
|
|
|
|
|
class TestWebSocketHandler(WebSocketHandler):
    """Common base for the websocket handlers in this test module.

    Each handler receives a ``close_future`` through ``initialize`` and
    resolves it from ``on_close`` with the ``(close_code, close_reason)``
    pair.  A test can wait on that future to know the server side of the
    connection has been closed, giving deterministic socket cleanup.
    """

    def initialize(self, close_future):
        # The future is handed in through the route's handler kwargs.
        self.close_future = close_future

    def on_close(self):
        outcome = (self.close_code, self.close_reason)
        self.close_future.set_result(outcome)
|
2014-06-11 04:34:28 -04:00
|
|
|
|
|
|
|
|
|
|
|
class EchoHandler(TestWebSocketHandler):
    """Send every received message straight back to the client.

    A message delivered as ``bytes`` is written back with the binary flag
    set; any other message is written back with it cleared.
    """

    def on_message(self, message):
        is_binary = isinstance(message, bytes)
        self.write_message(message, is_binary)
|
|
|
|
|
|
|
|
|
|
|
|
class HeaderHandler(TestWebSocketHandler):
    """Send the request's ``X-Test`` header value as the first message.

    ``open`` also checks that ``set_status`` is refused with a
    ``RuntimeError`` once the handler is in a websocket context.
    """

    def open(self):
        # In a websocket context, many RequestHandler methods
        # raise RuntimeErrors; set_status is expected to be one of them.
        try:
            self.set_status(503)
        except RuntimeError:
            pass
        else:
            raise Exception("did not get expected exception")
        # An absent header is reported as the empty string.
        header_value = self.request.headers.get('X-Test', '')
        self.write_message(header_value)
|
|
|
|
|
|
|
|
|
|
|
|
class NonWebSocketHandler(RequestHandler):
    """Ordinary (non-websocket) HTTP handler that answers GET with ``ok``."""

    def get(self):
        body = 'ok'
        self.write(body)
|
|
|
|
|
|
|
|
|
2014-06-17 00:54:00 -04:00
|
|
|
class CloseReasonHandler(TestWebSocketHandler):
    """Close the connection from the server side as soon as it opens.

    The close uses code ``1001`` with reason ``"goodbye"``, so a client can
    check that both values are transmitted.
    """

    def open(self):
        # Server-initiated close carrying an explicit code and reason.
        self.close(1001, "goodbye")
|
|
|
|
|
|
|
|
|
2014-06-11 04:34:28 -04:00
|
|
|
class WebSocketTest(AsyncHTTPTestCase):
    """End-to-end tests for the websocket server handler and client.

    All websocket routes share ``self.close_future``, which the handlers
    resolve from ``on_close`` with ``(close_code, close_reason)``.  Tests
    that open a connection wait on it so the server side is torn down
    before the test finishes.
    """

    def get_app(self):
        # Each call creates a new future shared by all websocket routes.
        self.close_future = Future()
        return Application([
            ('/echo', EchoHandler, dict(close_future=self.close_future)),
            ('/non_ws', NonWebSocketHandler),
            ('/header', HeaderHandler, dict(close_future=self.close_future)),
            ('/close_reason', CloseReasonHandler,
             dict(close_future=self.close_future)),
        ])

    def _ws_url(self, path):
        """Return the ``ws://`` URL for *path* on this test's HTTP server."""
        return 'ws://localhost:%d%s' % (self.get_http_port(), path)

    def _origin_request(self, origin):
        """Build a request for ``/echo`` that carries the given Origin header."""
        return HTTPRequest(self._ws_url('/echo'), headers={'Origin': origin})

    def test_http_request(self):
        # WS server, HTTP client.
        response = self.fetch('/echo')
        self.assertEqual(response.code, 400)

    @gen_test
    def test_websocket_gen(self):
        ws = yield websocket_connect(self._ws_url('/echo'),
                                     io_loop=self.io_loop)
        ws.write_message('hello')
        response = yield ws.read_message()
        self.assertEqual(response, 'hello')
        ws.close()
        yield self.close_future

    def test_websocket_callbacks(self):
        websocket_connect(self._ws_url('/echo'),
                          io_loop=self.io_loop, callback=self.stop)
        ws = self.wait().result()
        ws.write_message('hello')
        ws.read_message(self.stop)
        response = self.wait().result()
        self.assertEqual(response, 'hello')
        # Wait for the server-side on_close before finishing the test.
        self.close_future.add_done_callback(lambda f: self.stop())
        ws.close()
        self.wait()

    @gen_test
    def test_websocket_http_fail(self):
        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(self._ws_url('/notfound'),
                                    io_loop=self.io_loop)
        self.assertEqual(cm.exception.code, 404)

    @gen_test
    def test_websocket_http_success(self):
        # A successful plain HTTP response is not a websocket handshake.
        with self.assertRaises(WebSocketError):
            yield websocket_connect(self._ws_url('/non_ws'),
                                    io_loop=self.io_loop)

    @gen_test
    def test_websocket_network_fail(self):
        # Bind and immediately release a port so that, presumably, nothing
        # is listening on it; the long connect_timeout means the error seen
        # should be the connection failure rather than a timeout.
        sock, port = bind_unused_port()
        sock.close()
        with self.assertRaises(IOError):
            with ExpectLog(gen_log, ".*"):
                yield websocket_connect(
                    'ws://localhost:%d/' % port,
                    io_loop=self.io_loop,
                    connect_timeout=3600)

    @gen_test
    def test_websocket_close_buffered_data(self):
        ws = yield websocket_connect(self._ws_url('/echo'),
                                     io_loop=self.io_loop)
        ws.write_message('hello')
        ws.write_message('world')
        # Drop the underlying stream directly instead of calling ws.close().
        ws.stream.close()
        yield self.close_future

    @gen_test
    def test_websocket_headers(self):
        # Ensure that arbitrary headers can be passed through websocket_connect.
        ws = yield websocket_connect(
            HTTPRequest(self._ws_url('/header'),
                        headers={'X-Test': 'hello'}),
            io_loop=self.io_loop)
        response = yield ws.read_message()
        self.assertEqual(response, 'hello')
        ws.close()
        yield self.close_future

    @gen_test
    def test_server_close_reason(self):
        ws = yield websocket_connect(self._ws_url('/close_reason'),
                                     io_loop=self.io_loop)
        msg = yield ws.read_message()
        # A message of None means the other side closed the connection.
        self.assertIs(msg, None)
        self.assertEqual(ws.close_code, 1001)
        self.assertEqual(ws.close_reason, "goodbye")

    @gen_test
    def test_client_close_reason(self):
        ws = yield websocket_connect(self._ws_url('/echo'),
                                     io_loop=self.io_loop)
        ws.close(1001, 'goodbye')
        code, reason = yield self.close_future
        self.assertEqual(code, 1001)
        self.assertEqual(reason, 'goodbye')

    @gen_test
    def test_check_origin_valid_no_path(self):
        port = self.get_http_port()
        ws = yield websocket_connect(
            self._origin_request('http://localhost:%d' % port),
            io_loop=self.io_loop)
        ws.write_message('hello')
        response = yield ws.read_message()
        self.assertEqual(response, 'hello')
        ws.close()
        yield self.close_future

    @gen_test
    def test_check_origin_valid_with_path(self):
        port = self.get_http_port()
        ws = yield websocket_connect(
            self._origin_request('http://localhost:%d/something' % port),
            io_loop=self.io_loop)
        ws.write_message('hello')
        response = yield ws.read_message()
        self.assertEqual(response, 'hello')
        ws.close()
        yield self.close_future

    @gen_test
    def test_check_origin_invalid_partial_url(self):
        port = self.get_http_port()
        # An Origin without a scheme is rejected.
        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(
                self._origin_request('localhost:%d' % port),
                io_loop=self.io_loop)
        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_check_origin_invalid(self):
        # Host is localhost, which should not be accessible from some other
        # domain
        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(
                self._origin_request('http://somewhereelse.com'),
                io_loop=self.io_loop)
        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_check_origin_invalid_subdomains(self):
        # Subdomains should be disallowed by default. If we could pass a
        # resolver to websocket_connect we could test sibling domains as well.
        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(
                self._origin_request('http://subtenant.localhost'),
                io_loop=self.io_loop)
        self.assertEqual(cm.exception.code, 403)
|
|
|
|
|
2014-06-11 04:34:28 -04:00
|
|
|
|
|
|
|
class MaskFunctionMixin(object):
    """Shared test cases for websocket masking implementations.

    Subclasses combine this with a ``TestCase`` and define
    ``mask(self, mask, data)`` returning the masked bytes.
    """

    def test_mask(self):
        # Each entry is (mask, data, expected masked output).
        cases = [
            (b'abcd', b'', b''),
            (b'abcd', b'b', b'\x03'),
            (b'abcd', b'54321', b'TVPVP'),
            (b'ZXCV', b'98765432', b'c`t`olpd'),
            # Include test cases with \x00 bytes (to ensure that the C
            # extension isn't depending on null-terminated strings) and
            # bytes with the high bit set (to smoke out signedness issues).
            (b'\x00\x01\x02\x03', b'\xff\xfb\xfd\xfc\xfe\xfa',
             b'\xff\xfa\xff\xff\xfe\xfb'),
            (b'\xff\xfb\xfd\xfc', b'\x00\x01\x02\x03\x04\x05',
             b'\xff\xfa\xff\xff\xfb\xfe'),
        ]
        for key, data, expected in cases:
            self.assertEqual(self.mask(key, data), expected)
|
|
|
|
|
|
|
|
|
|
|
|
class PythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
    """Run the shared masking cases against ``_websocket_mask_python``."""

    def mask(self, mask, data):
        masked = _websocket_mask_python(mask, data)
        return masked
|
|
|
|
|
|
|
|
|
|
|
|
@unittest.skipIf(speedups is None, "tornado.speedups module not present")
class CythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
    """Run the shared masking cases against ``speedups.websocket_mask``.

    Skipped when the optional ``tornado.speedups`` extension failed to
    import (``speedups`` is then ``None``).
    """

    def mask(self, mask, data):
        masked = speedups.websocket_mask(mask, data)
        return masked
|