diff options
Diffstat (limited to 'tests/test_server.py')
-rw-r--r-- | tests/test_server.py | 533 |
1 files changed, 533 insertions, 0 deletions
diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 0000000..9134fb8 --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,533 @@ +import errno +import socket +import unittest + +dummy_app = object() + + +class TestWSGIServer(unittest.TestCase): + def _makeOne( + self, + application=dummy_app, + host="127.0.0.1", + port=0, + _dispatcher=None, + adj=None, + map=None, + _start=True, + _sock=None, + _server=None, + ): + from waitress.server import create_server + + self.inst = create_server( + application, + host=host, + port=port, + map=map, + _dispatcher=_dispatcher, + _start=_start, + _sock=_sock, + ) + return self.inst + + def _makeOneWithMap( + self, adj=None, _start=True, host="127.0.0.1", port=0, app=dummy_app + ): + sock = DummySock() + task_dispatcher = DummyTaskDispatcher() + map = {} + return self._makeOne( + app, + host=host, + port=port, + map=map, + _sock=sock, + _dispatcher=task_dispatcher, + _start=_start, + ) + + def _makeOneWithMulti( + self, adj=None, _start=True, app=dummy_app, listen="127.0.0.1:0 127.0.0.1:0" + ): + sock = DummySock() + task_dispatcher = DummyTaskDispatcher() + map = {} + from waitress.server import create_server + + self.inst = create_server( + app, + listen=listen, + map=map, + _dispatcher=task_dispatcher, + _start=_start, + _sock=sock, + ) + return self.inst + + def _makeWithSockets( + self, + application=dummy_app, + _dispatcher=None, + map=None, + _start=True, + _sock=None, + _server=None, + sockets=None, + ): + from waitress.server import create_server + + _sockets = [] + if sockets is not None: + _sockets = sockets + self.inst = create_server( + application, + map=map, + _dispatcher=_dispatcher, + _start=_start, + _sock=_sock, + sockets=_sockets, + ) + return self.inst + + def tearDown(self): + if self.inst is not None: + self.inst.close() + + def test_ctor_app_is_None(self): + self.inst = None + self.assertRaises(ValueError, self._makeOneWithMap, app=None) + + def test_ctor_start_true(self): + inst = self._makeOneWithMap(_start=True) + self.assertEqual(inst.accepting, True) + self.assertEqual(inst.socket.listened, 1024) + + def test_ctor_makes_dispatcher(self): + inst = self._makeOne(_start=False, map={}) + self.assertEqual( + inst.task_dispatcher.__class__.__name__, "ThreadedTaskDispatcher" + ) + + def test_ctor_start_false(self): + inst = self._makeOneWithMap(_start=False) + self.assertEqual(inst.accepting, False) + + def test_get_server_name_empty(self): + inst = self._makeOneWithMap(_start=False) + self.assertRaises(ValueError, inst.get_server_name, "") + + def test_get_server_name_with_ip(self): + inst = self._makeOneWithMap(_start=False) + result = inst.get_server_name("127.0.0.1") + self.assertTrue(result) + + def test_get_server_name_with_hostname(self): + inst = self._makeOneWithMap(_start=False) + result = inst.get_server_name("fred.flintstone.com") + self.assertEqual(result, "fred.flintstone.com") + + def test_get_server_name_0000(self): + inst = self._makeOneWithMap(_start=False) + result = inst.get_server_name("0.0.0.0") + self.assertTrue(len(result) != 0) + + def test_get_server_name_double_colon(self): + inst = self._makeOneWithMap(_start=False) + result = inst.get_server_name("::") + self.assertTrue(len(result) != 0) + + def test_get_server_name_ipv6(self): + inst = self._makeOneWithMap(_start=False) + result = inst.get_server_name("2001:DB8::ffff") + self.assertEqual("[2001:DB8::ffff]", result) + + def test_get_server_multi(self): + inst = self._makeOneWithMulti() + self.assertEqual(inst.__class__.__name__, "MultiSocketServer") + + def test_run(self): + inst = self._makeOneWithMap(_start=False) + inst.asyncore = DummyAsyncore() + inst.task_dispatcher = DummyTaskDispatcher() + inst.run() + self.assertTrue(inst.task_dispatcher.was_shutdown) + + def test_run_base_server(self): + inst = self._makeOneWithMulti(_start=False) + inst.asyncore = DummyAsyncore() + inst.task_dispatcher = DummyTaskDispatcher() + inst.run() + self.assertTrue(inst.task_dispatcher.was_shutdown) + + def test_pull_trigger(self): + inst = self._makeOneWithMap(_start=False) + inst.trigger.close() + inst.trigger = DummyTrigger() + inst.pull_trigger() + self.assertEqual(inst.trigger.pulled, True) + + def test_add_task(self): + task = DummyTask() + inst = self._makeOneWithMap() + inst.add_task(task) + self.assertEqual(inst.task_dispatcher.tasks, [task]) + self.assertFalse(task.serviced) + + def test_readable_not_accepting(self): + inst = self._makeOneWithMap() + inst.accepting = False + self.assertFalse(inst.readable()) + + def test_readable_maplen_gt_connection_limit(self): + inst = self._makeOneWithMap() + inst.accepting = True + inst.adj = DummyAdj + inst._map = {"a": 1, "b": 2} + self.assertFalse(inst.readable()) + + def test_readable_maplen_lt_connection_limit(self): + inst = self._makeOneWithMap() + inst.accepting = True + inst.adj = DummyAdj + inst._map = {} + self.assertTrue(inst.readable()) + + def test_readable_maintenance_false(self): + import time + + inst = self._makeOneWithMap() + then = time.time() + 1000 + inst.next_channel_cleanup = then + L = [] + inst.maintenance = lambda t: L.append(t) + inst.readable() + self.assertEqual(L, []) + self.assertEqual(inst.next_channel_cleanup, then) + + def test_readable_maintenance_true(self): + inst = self._makeOneWithMap() + inst.next_channel_cleanup = 0 + L = [] + inst.maintenance = lambda t: L.append(t) + inst.readable() + self.assertEqual(len(L), 1) + self.assertNotEqual(inst.next_channel_cleanup, 0) + + def test_writable(self): + inst = self._makeOneWithMap() + self.assertFalse(inst.writable()) + + def test_handle_read(self): + inst = self._makeOneWithMap() + self.assertEqual(inst.handle_read(), None) + + def test_handle_connect(self): + inst = self._makeOneWithMap() + self.assertEqual(inst.handle_connect(), None) + + def test_handle_accept_wouldblock_socket_error(self): + inst = self._makeOneWithMap() + ewouldblock = socket.error(errno.EWOULDBLOCK) + inst.socket = DummySock(toraise=ewouldblock) + inst.handle_accept() + self.assertEqual(inst.socket.accepted, False) + + def test_handle_accept_other_socket_error(self): + inst = self._makeOneWithMap() + eaborted = socket.error(errno.ECONNABORTED) + inst.socket = DummySock(toraise=eaborted) + inst.adj = DummyAdj + + def foo(): + raise socket.error + + inst.accept = foo + inst.logger = DummyLogger() + inst.handle_accept() + self.assertEqual(inst.socket.accepted, False) + self.assertEqual(len(inst.logger.logged), 1) + + def test_handle_accept_noerror(self): + inst = self._makeOneWithMap() + innersock = DummySock() + inst.socket = DummySock(acceptresult=(innersock, None)) + inst.adj = DummyAdj + L = [] + inst.channel_class = lambda *arg, **kw: L.append(arg) + inst.handle_accept() + self.assertEqual(inst.socket.accepted, True) + self.assertEqual(innersock.opts, [("level", "optname", "value")]) + self.assertEqual(L, [(inst, innersock, None, inst.adj)]) + + def test_maintenance(self): + inst = self._makeOneWithMap() + + class DummyChannel(object): + requests = [] + + zombie = DummyChannel() + zombie.last_activity = 0 + zombie.running_tasks = False + inst.active_channels[100] = zombie + inst.maintenance(10000) + self.assertEqual(zombie.will_close, True) + + def test_backward_compatibility(self): + from waitress.server import WSGIServer, TcpWSGIServer + from waitress.adjustments import Adjustments + + self.assertTrue(WSGIServer is TcpWSGIServer) + self.inst = WSGIServer(None, _start=False, port=1234) + # Ensure the adjustment was actually applied. + self.assertNotEqual(Adjustments.port, 1234) + self.assertEqual(self.inst.adj.port, 1234) + + def test_create_with_one_tcp_socket(self): + from waitress.server import TcpWSGIServer + + sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] + sockets[0].bind(("127.0.0.1", 0)) + inst = self._makeWithSockets(_start=False, sockets=sockets) + self.assertTrue(isinstance(inst, TcpWSGIServer)) + + def test_create_with_multiple_tcp_sockets(self): + from waitress.server import MultiSocketServer + + sockets = [ + socket.socket(socket.AF_INET, socket.SOCK_STREAM), + socket.socket(socket.AF_INET, socket.SOCK_STREAM), + ] + sockets[0].bind(("127.0.0.1", 0)) + sockets[1].bind(("127.0.0.1", 0)) + inst = self._makeWithSockets(_start=False, sockets=sockets) + self.assertTrue(isinstance(inst, MultiSocketServer)) + self.assertEqual(len(inst.effective_listen), 2) + + def test_create_with_one_socket_should_not_bind_socket(self): + innersock = DummySock() + sockets = [DummySock(acceptresult=(innersock, None))] + sockets[0].bind(("127.0.0.1", 80)) + sockets[0].bind_called = False + inst = self._makeWithSockets(_start=False, sockets=sockets) + self.assertEqual(inst.socket.bound, ("127.0.0.1", 80)) + self.assertFalse(inst.socket.bind_called) + + def test_create_with_one_socket_handle_accept_noerror(self): + innersock = DummySock() + sockets = [DummySock(acceptresult=(innersock, None))] + sockets[0].bind(("127.0.0.1", 80)) + inst = self._makeWithSockets(sockets=sockets) + L = [] + inst.channel_class = lambda *arg, **kw: L.append(arg) + inst.adj = DummyAdj + inst.handle_accept() + self.assertEqual(sockets[0].accepted, True) + self.assertEqual(innersock.opts, [("level", "optname", "value")]) + self.assertEqual(L, [(inst, innersock, None, inst.adj)]) + + +if hasattr(socket, "AF_UNIX"): + + class TestUnixWSGIServer(unittest.TestCase): + unix_socket = "/tmp/waitress.test.sock" + + def _makeOne(self, _start=True, _sock=None): + from waitress.server import create_server + + self.inst = create_server( + dummy_app, + map={}, + _start=_start, + _sock=_sock, + _dispatcher=DummyTaskDispatcher(), + unix_socket=self.unix_socket, + unix_socket_perms="600", + ) + return self.inst + + def _makeWithSockets( + self, + application=dummy_app, + _dispatcher=None, + map=None, + _start=True, + _sock=None, + _server=None, + sockets=None, + ): + from waitress.server import create_server + + _sockets = [] + if sockets is not None: + _sockets = sockets + self.inst = create_server( + application, + map=map, + _dispatcher=_dispatcher, + _start=_start, + _sock=_sock, + sockets=_sockets, + ) + return self.inst + + def tearDown(self): + self.inst.close() + + def _makeDummy(self, *args, **kwargs): + sock = DummySock(*args, **kwargs) + sock.family = socket.AF_UNIX + return sock + + def test_unix(self): + inst = self._makeOne(_start=False) + self.assertEqual(inst.socket.family, socket.AF_UNIX) + self.assertEqual(inst.socket.getsockname(), self.unix_socket) + + def test_handle_accept(self): + # Working on the assumption that we only have to test the happy path + # for Unix domain sockets as the other paths should've been covered + # by inet sockets. + client = self._makeDummy() + listen = self._makeDummy(acceptresult=(client, None)) + inst = self._makeOne(_sock=listen) + self.assertEqual(inst.accepting, True) + self.assertEqual(inst.socket.listened, 1024) + L = [] + inst.channel_class = lambda *arg, **kw: L.append(arg) + inst.handle_accept() + self.assertEqual(inst.socket.accepted, True) + self.assertEqual(client.opts, []) + self.assertEqual(L, [(inst, client, ("localhost", None), inst.adj)]) + + def test_creates_new_sockinfo(self): + from waitress.server import UnixWSGIServer + + self.inst = UnixWSGIServer( + dummy_app, unix_socket=self.unix_socket, unix_socket_perms="600" + ) + + self.assertEqual(self.inst.sockinfo[0], socket.AF_UNIX) + + def test_create_with_unix_socket(self): + from waitress.server import ( + MultiSocketServer, + BaseWSGIServer, + TcpWSGIServer, + UnixWSGIServer, + ) + + sockets = [ + socket.socket(socket.AF_UNIX, socket.SOCK_STREAM), + socket.socket(socket.AF_UNIX, socket.SOCK_STREAM), + ] + inst = self._makeWithSockets(sockets=sockets, _start=False) + self.assertTrue(isinstance(inst, MultiSocketServer)) + server = list( + filter(lambda s: isinstance(s, BaseWSGIServer), inst.map.values()) + ) + self.assertTrue(isinstance(server[0], UnixWSGIServer)) + self.assertTrue(isinstance(server[1], UnixWSGIServer)) + + +class DummySock(socket.socket): + accepted = False + blocking = False + family = socket.AF_INET + type = socket.SOCK_STREAM + proto = 0 + + def __init__(self, toraise=None, acceptresult=(None, None)): + self.toraise = toraise + self.acceptresult = acceptresult + self.bound = None + self.opts = [] + self.bind_called = False + + def bind(self, addr): + self.bind_called = True + self.bound = addr + + def accept(self): + if self.toraise: + raise self.toraise + self.accepted = True + return self.acceptresult + + def setblocking(self, x): + self.blocking = True + + def fileno(self): + return 10 + + def getpeername(self): + return "127.0.0.1" + + def setsockopt(self, *arg): + self.opts.append(arg) + + def getsockopt(self, *arg): + return 1 + + def listen(self, num): + self.listened = num + + def getsockname(self): + return self.bound + + def close(self): + pass + + +class DummyTaskDispatcher(object): + def __init__(self): + self.tasks = [] + + def add_task(self, task): + self.tasks.append(task) + + def shutdown(self): + self.was_shutdown = True + + +class DummyTask(object): + serviced = False + start_response_called = False + wrote_header = False + status = "200 OK" + + def __init__(self): + self.response_headers = {} + self.written = "" + + def service(self): # pragma: no cover + self.serviced = True + + +class DummyAdj: + connection_limit = 1 + log_socket_errors = True + socket_options = [("level", "optname", "value")] + cleanup_interval = 900 + channel_timeout = 300 + + +class DummyAsyncore(object): + def loop(self, timeout=30.0, use_poll=False, map=None, count=None): + raise SystemExit + + +class DummyTrigger(object): + def pull_trigger(self): + self.pulled = True + + def close(self): + pass + + +class DummyLogger(object): + def __init__(self): + self.logged = [] + + def warning(self, msg, **kw): + self.logged.append(msg) |