summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSolly Ross <sross@redhat.com>2015-05-06 13:49:13 -0400
committerSolly Ross <sross@redhat.com>2015-05-12 12:53:43 -0400
commit52f6830852216fc80ab2505b0ca0f5bbcd450d5a (patch)
treef5966c6f59c24a6f44bd3d1f309c26409207aed6
parentf6e9fbe5cf096d3bbab1ce64f765ee934a0adefe (diff)
downloadwebsockify-52f6830852216fc80ab2505b0ca0f5bbcd450d5a.tar.gz
Update Tests and Test Plugins
This commit updates the unit tests to work with the current code and adds in tests for the auth and token plugin functionality.
-rw-r--r--.gitignore2
-rw-r--r--tests/test_websocket.py420
-rw-r--r--tests/test_websocketproxy.py199
-rw-r--r--tox.ini (renamed from tests/tox.ini)5
-rw-r--r--websockify/websocket.py2
5 files changed, 370 insertions, 258 deletions
diff --git a/.gitignore b/.gitignore
index 3bf91dd..3ac5ff1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,3 +9,5 @@ other/node_modules
.pydevproject
target.cfg
target.cfg.d
+.tox
+*.egg-info
diff --git a/tests/test_websocket.py b/tests/test_websocket.py
index c7a106f..acd7699 100644
--- a/tests/test_websocket.py
+++ b/tests/test_websocket.py
@@ -26,201 +26,303 @@ import stubout
import sys
import tempfile
import unittest
-from ssl import SSLError
-from websockify import websocket as websocket
-from SimpleHTTPServer import SimpleHTTPRequestHandler
+import socket
+import signal
+from websockify import websocket
+
+try:
+ from SimpleHTTPServer import SimpleHTTPRequestHandler
+except ImportError:
+ from http.server import SimpleHTTPRequestHandler
+
+try:
+ from StringIO import StringIO
+ BytesIO = StringIO
+except ImportError:
+ from io import StringIO
+ from io import BytesIO
+
+
+
+
+def raise_oserror(*args, **kwargs):
+ raise OSError('fake error')
-class MockConnection(object):
- def __init__(self, path):
- self.path = path
+class FakeSocket(object):
+ def __init__(self, data=''):
+ if isinstance(data, bytes):
+ self._data = data
+ else:
+ self._data = data.encode('latin_1')
- def makefile(self, mode='r', bufsize=-1):
- return open(self.path, mode, bufsize)
+ def recv(self, amt, flags=None):
+ res = self._data[0:amt]
+ if not (flags & socket.MSG_PEEK):
+ self._data = self._data[amt:]
+ return res
-class WebSocketTestCase(unittest.TestCase):
+ def makefile(self, mode='r', buffsize=None):
+ if 'b' in mode:
+ return BytesIO(self._data)
+ else:
+ return StringIO(self._data.decode('latin_1'))
- def _init_logger(self, tmpdir):
- name = 'websocket-unittest'
- logger = logging.getLogger(name)
- logger.setLevel(logging.DEBUG)
- logger.propagate = True
- filename = "%s.log" % (name)
- handler = logging.FileHandler(filename)
- handler.setFormatter(logging.Formatter("%(message)s"))
- logger.addHandler(handler)
+class WebSocketRequestHandlerTestCase(unittest.TestCase):
def setUp(self):
- """Called automatically before each test."""
- super(WebSocketTestCase, self).setUp()
+ super(WebSocketRequestHandlerTestCase, self).setUp()
self.stubs = stubout.StubOutForTesting()
- # Temporary dir for test data
- self.tmpdir = tempfile.mkdtemp()
- # Put log somewhere persistent
- self._init_logger('./')
+ self.tmpdir = tempfile.mkdtemp('-websockify-tests')
# Mock this out cause it screws tests up
self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None)
- self.server = self._get_websockserver(daemon=True,
- ssl_only=False)
- self.soc = self.server.socket('localhost')
+ self.stubs.Set(SimpleHTTPRequestHandler, 'send_response',
+ lambda *args, **kwargs: None)
def tearDown(self):
"""Called automatically after each test."""
self.stubs.UnsetAll()
- shutil.rmtree(self.tmpdir)
- super(WebSocketTestCase, self).tearDown()
+ os.rmdir(self.tmpdir)
+ super(WebSocketRequestHandlerTestCase, self).tearDown()
+
+ def _get_server(self, handler_class=websocket.WebSocketRequestHandler,
+ **kwargs):
+ web = kwargs.pop('web', self.tmpdir)
+ return websocket.WebSocketServer(
+ handler_class, listen_host='localhost',
+ listen_port=80, key=self.tmpdir, web=web,
+ record=self.tmpdir, daemon=False, ssl_only=0, idle_timeout=1,
+ **kwargs)
- def _get_websockserver(self, **kwargs):
- return websocket.WebSocketServer(listen_host='localhost',
- listen_port=80,
- key=self.tmpdir,
- web=self.tmpdir,
- record=self.tmpdir,
- **kwargs)
+ def test_normal_get_with_only_upgrade_returns_error(self):
+ server = self._get_server(web=None)
+ handler = websocket.WebSocketRequestHandler(
+ FakeSocket('GET /tmp.txt HTTP/1.1'), '127.0.0.1', server)
- def _mock_os_open_oserror(self, file, flags):
- raise OSError('')
+ def fake_send_response(self, code, message=None):
+ self.last_code = code
- def _mock_os_close_oserror(self, fd):
- raise OSError('')
+ self.stubs.Set(SimpleHTTPRequestHandler, 'send_response',
+ fake_send_response)
- def _mock_os_close_oserror_EBADF(self, fd):
- raise OSError(errno.EBADF, '')
+ handler.do_GET()
+ self.assertEqual(handler.last_code, 405)
- def _mock_socket(self, *args, **kwargs):
- return self.soc
+ def test_list_dir_with_file_only_returns_error(self):
+ server = self._get_server(file_only=True)
+ handler = websocket.WebSocketRequestHandler(
+ FakeSocket('GET / HTTP/1.1'), '127.0.0.1', server)
- def _mock_select(self, rlist, wlist, xlist, timeout=None):
- return '_mock_select'
+ def fake_send_response(self, code, message=None):
+ self.last_code = code
- def _mock_select_exception(self, rlist, wlist, xlist, timeout=None):
- raise Exception
+ self.stubs.Set(SimpleHTTPRequestHandler, 'send_response',
+ fake_send_response)
- def _mock_select_keyboardinterrupt(self, rlist, wlist,
- xlist, timeout=None):
- raise KeyboardInterrupt
+ handler.path = '/'
+ handler.do_GET()
+ self.assertEqual(handler.last_code, 404)
- def _mock_select_systemexit(self, rlist, wlist, xlist, timeout=None):
- sys.exit()
- def test_daemonize_error(self):
- soc = self._get_websockserver(daemon=True, ssl_only=1, idle_timeout=1)
- self.stubs.Set(os, 'fork', lambda *args: None)
+class WebSocketServerTestCase(unittest.TestCase):
+ def setUp(self):
+ super(WebSocketServerTestCase, self).setUp()
+ self.stubs = stubout.StubOutForTesting()
+ self.tmpdir = tempfile.mkdtemp('-websockify-tests')
+ # Mock this out cause it screws tests up
+ self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None)
+
+ def tearDown(self):
+ """Called automatically after each test."""
+ self.stubs.UnsetAll()
+ os.rmdir(self.tmpdir)
+ super(WebSocketServerTestCase, self).tearDown()
+
+ def _get_server(self, handler_class=websocket.WebSocketRequestHandler,
+ **kwargs):
+ return websocket.WebSocketServer(
+ handler_class, listen_host='localhost',
+ listen_port=80, key=self.tmpdir, web=self.tmpdir,
+ record=self.tmpdir, **kwargs)
+
+ def test_daemonize_raises_error_while_closing_fds(self):
+ server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
+ self.stubs.Set(os, 'fork', lambda *args: 0)
+ self.stubs.Set(signal, 'signal', lambda *args: None)
self.stubs.Set(os, 'setsid', lambda *args: None)
- self.stubs.Set(os, 'close', self._mock_os_close_oserror)
- self.assertRaises(OSError, soc.daemonize, keepfd=None, chdir='./')
+ self.stubs.Set(os, 'close', raise_oserror)
+ self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./')
+
+ def test_daemonize_ignores_ebadf_error_while_closing_fds(self):
+ def raise_oserror_ebadf(fd):
+ raise OSError(errno.EBADF, 'fake error')
- def test_daemonize_EBADF_error(self):
- soc = self._get_websockserver(daemon=True, ssl_only=1, idle_timeout=1)
- self.stubs.Set(os, 'fork', lambda *args: None)
+ server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
+ self.stubs.Set(os, 'fork', lambda *args: 0)
self.stubs.Set(os, 'setsid', lambda *args: None)
- self.stubs.Set(os, 'close', self._mock_os_close_oserror_EBADF)
- self.stubs.Set(os, 'open', self._mock_os_open_oserror)
- self.assertRaises(OSError, soc.daemonize, keepfd=None, chdir='./')
-
- def test_decode_hybi(self):
- soc = self._get_websockserver(daemon=False, ssl_only=1, idle_timeout=1)
- self.assertRaises(Exception, soc.decode_hybi, 'a' * 128,
- base64=True)
-
- def test_do_websocket_handshake(self):
- soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1)
- soc.scheme = 'scheme'
- headers = {'Sec-WebSocket-Protocol': 'binary',
- 'Sec-WebSocket-Version': '7',
- 'Sec-WebSocket-Key': 'foo'}
- soc.do_websocket_handshake(headers, '127.0.0.1')
-
- def test_do_handshake(self):
- soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1)
- self.stubs.Set(select, 'select', self._mock_select)
- self.stubs.Set(socket._socketobject, 'recv', lambda *args: 'mock_recv')
- self.assertRaises(Exception, soc.do_handshake, self.soc, '127.0.0.1')
-
- def test_do_handshake_ssl_error(self):
- soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1)
-
- def _mock_wrap_socket(*args, **kwargs):
- from ssl import SSLError
- raise SSLError('unit test exception')
-
- self.stubs.Set(select, 'select', self._mock_select)
- self.stubs.Set(socket._socketobject, 'recv', lambda *args: '\x16')
- self.stubs.Set(ssl, 'wrap_socket', _mock_wrap_socket)
- self.assertRaises(SSLError, soc.do_handshake, self.soc, '127.0.0.1')
-
- def test_fallback_SIGCHILD(self):
- soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1)
- soc.fallback_SIGCHLD(None, None)
-
- def test_start_server_Exception(self):
- soc = self._get_websockserver(daemon=False, ssl_only=1, idle_timeout=1)
- self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket)
+ self.stubs.Set(signal, 'signal', lambda *args: None)
+ self.stubs.Set(os, 'close', raise_oserror_ebadf)
+ self.stubs.Set(os, 'open', raise_oserror)
+ self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./')
+
+ def test_handshake_fails_on_not_ready(self):
+ server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
+
+ def fake_select(rlist, wlist, xlist, timeout=None):
+ return ([], [], [])
+
+ self.stubs.Set(select, 'select', fake_select)
+ self.assertRaises(
+ websocket.WebSocketServer.EClose, server.do_handshake,
+ FakeSocket(), '127.0.0.1')
+
+ def test_empty_handshake_fails(self):
+ server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
+
+ sock = FakeSocket('')
+
+ def fake_select(rlist, wlist, xlist, timeout=None):
+ return ([sock], [], [])
+
+ self.stubs.Set(select, 'select', fake_select)
+ self.assertRaises(
+ websocket.WebSocketServer.EClose, server.do_handshake,
+ sock, '127.0.0.1')
+
+ def test_handshake_policy_request(self):
+ # TODO(directxman12): implement
+ pass
+
+ def test_handshake_ssl_only_without_ssl_raises_error(self):
+ server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
+
+ sock = FakeSocket('some initial data')
+
+ def fake_select(rlist, wlist, xlist, timeout=None):
+ return ([sock], [], [])
+
+ self.stubs.Set(select, 'select', fake_select)
+ self.assertRaises(
+ websocket.WebSocketServer.EClose, server.do_handshake,
+ sock, '127.0.0.1')
+
+ def test_do_handshake_no_ssl(self):
+ class FakeHandler(object):
+ CALLED = False
+ def __init__(self, *args, **kwargs):
+ type(self).CALLED = True
+
+ FakeHandler.CALLED = False
+
+ server = self._get_server(
+ handler_class=FakeHandler, daemon=True,
+ ssl_only=0, idle_timeout=1)
+
+ sock = FakeSocket('some initial data')
+
+ def fake_select(rlist, wlist, xlist, timeout=None):
+ return ([sock], [], [])
+
+ self.stubs.Set(select, 'select', fake_select)
+ self.assertEqual(server.do_handshake(sock, '127.0.0.1'), sock)
+ self.assertTrue(FakeHandler.CALLED, True)
+
+ def test_do_handshake_ssl(self):
+ # TODO(directxman12): implement this
+ pass
+
+ def test_do_handshake_ssl_without_ssl_raises_error(self):
+ # TODO(directxman12): implement this
+ pass
+
+ def test_do_handshake_ssl_without_cert_raises_error(self):
+ server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1,
+ cert='afdsfasdafdsafdsafdsafdas')
+
+ sock = FakeSocket("\x16some ssl data")
+
+ def fake_select(rlist, wlist, xlist, timeout=None):
+ return ([sock], [], [])
+
+ self.stubs.Set(select, 'select', fake_select)
+ self.assertRaises(
+ websocket.WebSocketServer.EClose, server.do_handshake,
+ sock, '127.0.0.1')
+
+ def test_do_handshake_ssl_error_eof_raises_close_error(self):
+ server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
+
+ sock = FakeSocket("\x16some ssl data")
+
+ def fake_select(rlist, wlist, xlist, timeout=None):
+ return ([sock], [], [])
+
+ def fake_wrap_socket(*args, **kwargs):
+ raise ssl.SSLError(ssl.SSL_ERROR_EOF)
+
+ self.stubs.Set(select, 'select', fake_select)
+ self.stubs.Set(ssl, 'wrap_socket', fake_wrap_socket)
+ self.assertRaises(
+ websocket.WebSocketServer.EClose, server.do_handshake,
+ sock, '127.0.0.1')
+
+ def test_fallback_sigchld_handler(self):
+ # TODO(directxman12): implement this
+ pass
+
+ def test_start_server_error(self):
+ server = self._get_server(daemon=False, ssl_only=1, idle_timeout=1)
+ sock = server.socket('localhost')
+
+ def fake_select(rlist, wlist, xlist, timeout=None):
+ raise Exception("fake error")
+
+ self.stubs.Set(websocket.WebSocketServer, 'socket',
+ lambda *args, **kwargs: sock)
self.stubs.Set(websocket.WebSocketServer, 'daemonize',
lambda *args, **kwargs: None)
- self.stubs.Set(select, 'select', self._mock_select_exception)
- self.assertEqual(None, soc.start_server())
+ self.stubs.Set(select, 'select', fake_select)
+ server.start_server()
+
+ def test_start_server_keyboardinterrupt(self):
+ server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
+ sock = server.socket('localhost')
+
+ def fake_select(rlist, wlist, xlist, timeout=None):
+ raise KeyboardInterrupt
- def test_start_server_KeyboardInterrupt(self):
- soc = self._get_websockserver(daemon=False, ssl_only=1, idle_timeout=1)
- self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket)
+ self.stubs.Set(websocket.WebSocketServer, 'socket',
+ lambda *args, **kwargs: sock)
self.stubs.Set(websocket.WebSocketServer, 'daemonize',
lambda *args, **kwargs: None)
- self.stubs.Set(select, 'select', self._mock_select_keyboardinterrupt)
- self.assertEqual(None, soc.start_server())
+ self.stubs.Set(select, 'select', fake_select)
+ server.start_server()
def test_start_server_systemexit(self):
- websocket.ssl = None
- self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket)
+ server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
+ sock = server.socket('localhost')
+
+ def fake_select(rlist, wlist, xlist, timeout=None):
+ sys.exit()
+
+ self.stubs.Set(websocket.WebSocketServer, 'socket',
+ lambda *args, **kwargs: sock)
self.stubs.Set(websocket.WebSocketServer, 'daemonize',
lambda *args, **kwargs: None)
- self.stubs.Set(select, 'select', self._mock_select_systemexit)
- soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1,
- verbose=True)
- self.assertEqual(None, soc.start_server())
-
- def test_WSRequestHandle_do_GET_nofile(self):
- request = 'GET /tmp.txt HTTP/0.9'
- with tempfile.NamedTemporaryFile() as test_file:
- test_file.write(request)
- test_file.flush()
- test_file.seek(0)
- con = MockConnection(test_file.name)
- soc = websocket.WSRequestHandler(con, "127.0.0.1", file_only=True)
- soc.path = ''
- soc.headers = {'upgrade': ''}
- self.stubs.Set(SimpleHTTPRequestHandler, 'send_response',
- lambda *args: None)
- soc.do_GET()
- self.assertEqual(404, soc.last_code)
-
- def test_WSRequestHandle_do_GET_hidden_resource(self):
- request = 'GET /tmp.txt HTTP/0.9'
- with tempfile.NamedTemporaryFile() as test_file:
- test_file.write(request)
- test_file.flush()
- test_file.seek(0)
- con = MockConnection(test_file.name)
- soc = websocket.WSRequestHandler(con, '127.0.0.1', no_parent=True)
- soc.path = test_file.name + '?'
- soc.headers = {'upgrade': ''}
- soc.webroot = 'no match startswith'
- self.stubs.Set(SimpleHTTPRequestHandler,
- 'send_response',
- lambda *args: None)
- soc.do_GET()
- self.assertEqual(403, soc.last_code)
-
- def testsocket_set_keepalive_options(self):
+ self.stubs.Set(select, 'select', fake_select)
+ server.start_server()
+
+ def test_socket_set_keepalive_options(self):
keepcnt = 12
keepidle = 34
keepintvl = 56
- sock = self.server.socket('localhost',
- tcp_keepcnt=keepcnt,
- tcp_keepidle=keepidle,
- tcp_keepintvl=keepintvl)
+ server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
+ sock = server.socket('localhost',
+ tcp_keepcnt=keepcnt,
+ tcp_keepidle=keepidle,
+ tcp_keepintvl=keepintvl)
self.assertEqual(sock.getsockopt(socket.SOL_TCP,
socket.TCP_KEEPCNT), keepcnt)
@@ -229,11 +331,11 @@ class WebSocketTestCase(unittest.TestCase):
self.assertEqual(sock.getsockopt(socket.SOL_TCP,
socket.TCP_KEEPINTVL), keepintvl)
- sock = self.server.socket('localhost',
- tcp_keepalive=False,
- tcp_keepcnt=keepcnt,
- tcp_keepidle=keepidle,
- tcp_keepintvl=keepintvl)
+ sock = server.socket('localhost',
+ tcp_keepalive=False,
+ tcp_keepcnt=keepcnt,
+ tcp_keepidle=keepidle,
+ tcp_keepintvl=keepintvl)
self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
socket.TCP_KEEPCNT), keepcnt)
diff --git a/tests/test_websocketproxy.py b/tests/test_websocketproxy.py
index cf940ae..8103ef6 100644
--- a/tests/test_websocketproxy.py
+++ b/tests/test_websocketproxy.py
@@ -1,6 +1,6 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
-# Copyright(c)2013 NTT corp. All Rights Reserved.
+# Copyright(c) 2015 Red Hat, Inc All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
@@ -15,113 +15,122 @@
# under the License.
""" Unit tests for websocketproxy """
-import os
-import logging
-import select
-import shutil
-import stubout
-import subprocess
-import tempfile
-import time
+
import unittest
+import unittest
+import socket
+import stubout
+
+from websockify import websocket
from websockify import websocketproxy
+from websockify import token_plugins
+from websockify import auth_plugins
+try:
+ from StringIO import StringIO
+ BytesIO = StringIO
+except ImportError:
+ from io import StringIO
+ from io import BytesIO
-class MockSocket(object):
- def __init__(*args, **kwargs):
- pass
- def shutdown(*args):
- pass
+class FakeSocket(object):
+ def __init__(self, data=''):
+ if isinstance(data, bytes):
+ self._data = data
+ else:
+ self._data = data.encode('latin_1')
- def close(*args):
- pass
+ def recv(self, amt, flags=None):
+ res = self._data[0:amt]
+ if not (flags & socket.MSG_PEEK):
+ self._data = self._data[amt:]
+
+ return res
+
+ def makefile(self, mode='r', buffsize=None):
+ if 'b' in mode:
+ return BytesIO(self._data)
+ else:
+ return StringIO(self._data.decode('latin_1'))
-class WebSocketProxyTest(unittest.TestCase):
+class FakeServer(object):
+ class EClose(Exception):
+ pass
- def _init_logger(self, tmpdir):
- name = 'websocket-unittest'
- logger = logging.getLogger(name)
- logger.setLevel(logging.DEBUG)
- logger.propagate = True
- filename = "%s.log" % (name)
- handler = logging.FileHandler(filename)
- handler.setFormatter(logging.Formatter("%(message)s"))
- logger.addHandler(handler)
+ def __init__(self):
+ self.token_plugin = None
+ self.auth_plugin = None
+ self.wrap_cmd = None
+ self.ssl_target = None
+ self.unix_target = None
+class ProxyRequestHandlerTestCase(unittest.TestCase):
def setUp(self):
- """Called automatically before each test."""
- super(WebSocketProxyTest, self).setUp()
- self.soc = ''
+ super(ProxyRequestHandlerTestCase, self).setUp()
self.stubs = stubout.StubOutForTesting()
- # Temporary dir for test data
- self.tmpdir = tempfile.mkdtemp()
- # Put log somewhere persistent
- self._init_logger('./')
- # Mock this out cause it screws tests up
- self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None)
+ self.handler = websocketproxy.ProxyRequestHandler(
+ FakeSocket(''), "127.0.0.1", FakeServer())
+ self.handler.path = "https://localhost:6080/websockify?token=blah"
+ self.handler.headers = None
+ self.stubs.Set(websocket.WebSocketServer, 'socket',
+ staticmethod(lambda *args, **kwargs: None))
def tearDown(self):
- """Called automatically after each test."""
self.stubs.UnsetAll()
- shutil.rmtree(self.tmpdir)
- super(WebSocketProxyTest, self).tearDown()
-
- def _get_websockproxy(self, **kwargs):
- return websocketproxy.WebSocketProxy(key=self.tmpdir,
- web=self.tmpdir,
- record=self.tmpdir,
- **kwargs)
-
- def test_run_wrap_cmd(self):
- web_socket_proxy = self._get_websockproxy()
- web_socket_proxy.__dict__["wrap_cmd"] = "wrap_cmd"
-
- def mock_Popen(*args, **kwargs):
- return '_mock_cmd'
-
- self.stubs.Set(subprocess, 'Popen', mock_Popen)
- web_socket_proxy.run_wrap_cmd()
- self.assertEquals(web_socket_proxy.spawn_message, True)
-
- def test_started(self):
- web_socket_proxy = self._get_websockproxy()
- web_socket_proxy.__dict__["spawn_message"] = False
- web_socket_proxy.__dict__["wrap_cmd"] = "wrap_cmd"
-
- def mock_run_wrap_cmd(*args, **kwargs):
- web_socket_proxy.__dict__["spawn_message"] = True
-
- self.stubs.Set(web_socket_proxy, 'run_wrap_cmd', mock_run_wrap_cmd)
- web_socket_proxy.started()
- self.assertEquals(web_socket_proxy.__dict__["spawn_message"], True)
-
- def test_poll(self):
- web_socket_proxy = self._get_websockproxy()
- web_socket_proxy.__dict__["wrap_cmd"] = "wrap_cmd"
- web_socket_proxy.__dict__["wrap_mode"] = "respawn"
- web_socket_proxy.__dict__["wrap_times"] = [99999999]
- web_socket_proxy.__dict__["spawn_message"] = True
- web_socket_proxy.__dict__["cmd"] = None
- self.stubs.Set(time, 'time', lambda: 100000000.000)
- web_socket_proxy.poll()
- self.assertEquals(web_socket_proxy.spawn_message, False)
-
- def test_new_client(self):
- web_socket_proxy = self._get_websockproxy()
- web_socket_proxy.__dict__["verbose"] = "verbose"
- web_socket_proxy.__dict__["daemon"] = None
- web_socket_proxy.__dict__["client"] = "client"
-
- self.stubs.Set(web_socket_proxy, 'socket', MockSocket)
-
- def mock_select(*args, **kwargs):
- ins = None
- outs = None
- excepts = "excepts"
- return ins, outs, excepts
-
- self.stubs.Set(select, 'select', mock_select)
- self.assertRaises(Exception, web_socket_proxy.new_websocket_client)
+ super(ProxyRequestHandlerTestCase, self).tearDown()
+
+ def test_get_target(self):
+ class TestPlugin(token_plugins.BasePlugin):
+ def lookup(self, token):
+ return ("some host", "some port")
+
+ host, port = self.handler.get_target(
+ TestPlugin(None), self.handler.path)
+
+ self.assertEqual(host, "some host")
+ self.assertEqual(port, "some port")
+
+ def test_get_target_raises_error_on_unknown_token(self):
+ class TestPlugin(token_plugins.BasePlugin):
+ def lookup(self, token):
+ return None
+
+ self.assertRaises(FakeServer.EClose, self.handler.get_target,
+ TestPlugin(None), "https://localhost:6080/websockify?token=blah")
+
+ def test_token_plugin(self):
+ class TestPlugin(token_plugins.BasePlugin):
+ def lookup(self, token):
+ return (self.source + token).split(',')
+
+ self.stubs.Set(websocketproxy.ProxyRequestHandler, 'do_proxy',
+ lambda *args, **kwargs: None)
+
+ self.handler.server.token_plugin = TestPlugin("somehost,")
+ self.handler.new_websocket_client()
+
+ self.assertEqual(self.handler.server.target_host, "somehost")
+ self.assertEqual(self.handler.server.target_port, "blah")
+
+ def test_auth_plugin(self):
+ class TestPlugin(auth_plugins.BasePlugin):
+ def authenticate(self, headers, target_host, target_port):
+ if target_host == self.source:
+ raise auth_plugins.AuthenticationError("some error")
+
+ self.stubs.Set(websocketproxy.ProxyRequestHandler, 'do_proxy',
+ staticmethod(lambda *args, **kwargs: None))
+
+ self.handler.server.auth_plugin = TestPlugin("somehost")
+ self.handler.server.target_host = "somehost"
+ self.handler.server.target_port = "someport"
+
+ self.assertRaises(auth_plugins.AuthenticationError,
+ self.handler.new_websocket_client)
+
+ self.handler.server.target_host = "someotherhost"
+ self.handler.new_websocket_client()
+
diff --git a/tests/tox.ini b/tox.ini
index 098e89c..012d349 100644
--- a/tests/tox.ini
+++ b/tox.ini
@@ -4,8 +4,7 @@
# and then run "tox" from this directory.
[tox]
-envlist = py24,py25,py26,py27,py30
-setupdir = ../
+envlist = py24,py26,py27,py33,py34
[testenv]
commands = nosetests {posargs}
@@ -13,7 +12,7 @@ deps =
mox
nose
-# At some point we should enable this since tox epdctes it to exist but
+# At some point we should enable this since tox expects it to exist but
# the code will need pep8ising first.
#[testenv:pep8]
#commands = flake8
diff --git a/websockify/websocket.py b/websockify/websocket.py
index 20305b8..727413a 100644
--- a/websockify/websocket.py
+++ b/websockify/websocket.py
@@ -790,7 +790,7 @@ class WebSocketServer(object):
handshake = sock.recv(1024, socket.MSG_PEEK)
#self.msg("Handshake [%s]" % handshake)
- if handshake == "":
+ if not handshake:
raise self.EClose("ignoring empty handshake")
elif handshake.startswith(s2b("<policy-file-request/>")):