diff options
author | Frank Krick <frank.krick@gmail.com> | 2018-10-28 16:03:09 -0400 |
---|---|---|
committer | Frank Krick <frank.krick@gmail.com> | 2018-10-28 16:03:09 -0400 |
commit | de0a6c1ef714ce42c07a595238663931a056a9b7 (patch) | |
tree | d9e843a8d7d89c00fe4f35736adac903ffded04d /waitress | |
parent | 5274c8b95c47704462ab2c52d647c90230d5d801 (diff) | |
download | waitress-de0a6c1ef714ce42c07a595238663931a056a9b7.tar.gz |
Added sockets and bind_sockets parameters for server creation
Diffstat (limited to 'waitress')
-rw-r--r-- | waitress/adjustments.py | 24 | ||||
-rw-r--r-- | waitress/server.py | 90 | ||||
-rw-r--r-- | waitress/tests/test_adjustments.py | 73 | ||||
-rw-r--r-- | waitress/tests/test_server.py | 89 |
4 files changed, 242 insertions, 34 deletions
diff --git a/waitress/adjustments.py b/waitress/adjustments.py index c55823a..92d5a4f 100644 --- a/waitress/adjustments.py +++ b/waitress/adjustments.py @@ -69,6 +69,11 @@ def slash_fixed_str(s): def str_iftruthy(s): return str(s) if s else None +def as_socket_list(sockets): + """Checks if the elements in the list are of type socket and + returns None if not.""" + return [sock for sock in sockets if isinstance(sock, socket.socket)] + class _str_marker(str): pass @@ -106,6 +111,8 @@ class Adjustments(object): ('asyncore_use_poll', asbool), ('unix_socket', str), ('unix_socket_perms', asoctal), + ('sockets', as_socket_list), + ('bind_sockets', asbool), ) _param_map = dict(_params) @@ -216,11 +223,28 @@ class Adjustments(object): # Enable IPv6 by default ipv6 = True + # A list of sockets that waitress will use to accept connections. They can + # be used for e.g. socket activation + sockets = [] + + # Enable binding to sockets by default. This can be turned off for sockets + # that are supplied from the outside, e.g. using socket activation + bind_sockets = True + def __init__(self, **kw): if 'listen' in kw and ('host' in kw or 'port' in kw): raise ValueError('host and or port may not be set if listen is set.') + if 'listen' in kw and 'sockets' in kw: + raise ValueError('socket may not be set if listen is set.') + + if 'sockets' in kw and ('host' in kw or 'port' in kw): + raise ValueError('host and or port may not be set if sockets is set.') + + if 'sockets' in kw and ('bind_sockets' not in kw or kw['bind_sockets']): + raise ValueError('Sockets passed should be bound already, please turn bind of with bind_sockets=False') + for k, v in kw.items(): if k not in self._param_map: raise ValueError('Unknown adjustment %r' % k) diff --git a/waitress/server.py b/waitress/server.py index 198d5fa..ec12417 100644 --- a/waitress/server.py +++ b/waitress/server.py @@ -69,23 +69,47 @@ def create_server(application, effective_listen = [] last_serv = None - for sockinfo in adj.listen: - # When TcpWSGIServer is called, it registers itself in the map. This - # side-effect is all we need it for, so we don't store a reference to - # or return it to the user. - last_serv = TcpWSGIServer( - application, - map, - _start, - _sock, - dispatcher=dispatcher, - adj=adj, - sockinfo=sockinfo) - effective_listen.append((last_serv.effective_host, last_serv.effective_port)) + if not adj.sockets: + for sockinfo in adj.listen: + # When TcpWSGIServer is called, it registers itself in the map. This + # side-effect is all we need it for, so we don't store a reference to + # or return it to the user. + last_serv = TcpWSGIServer( + application, + map, + _start, + _sock, + dispatcher=dispatcher, + adj=adj, + sockinfo=sockinfo) + effective_listen.append((last_serv.effective_host, last_serv.effective_port)) + + for sock in adj.sockets: + if sock.family == socket.AF_INET or sock.family == socket.AF_INET6: + sockinfo = (sock.family, sock.type, sock.proto, sock.getsockname()) + last_serv = TcpWSGIServer( + application, + map, + _start, + sock, + dispatcher=dispatcher, + adj=adj, + sockinfo=sockinfo) + effective_listen.append((last_serv.effective_host, last_serv.effective_port)) + elif hasattr(socket, 'AF_UNIX') and sock.family == socket.AF_UNIX: + last_serv = UnixWSGIServer( + application, + map, + _start, + sock, + dispatcher=dispatcher, + adj=adj, + sockinfo=(sock.family, sock.type, sock.proto, sock.getsockname())) + effective_listen.append((last_serv.effective_host, last_serv.effective_port)) # We are running a single server, so we can just return the last server, # saves us from having to create one more object - if len(adj.listen) == 1: + if len(adj.listen) == 1 and len(adj.sockets) == 0 or len(adj.sockets) == 1: # In this case we have no need to use a MultiSocketServer return last_serv @@ -181,7 +205,10 @@ class BaseWSGIServer(wasyncore.dispatcher, object): self.socket.setsockopt(IPPROTO_IPV6, IPV6_V6ONLY, 1) self.set_reuse_addr() - self.bind_server_socket() + + if adj.bind_sockets: + self.bind_server_socket() + self.effective_host, self.effective_port = self.getsockname() self.server_name = self.get_server_name(self.effective_host) self.active_channels = {} @@ -225,7 +252,21 @@ class BaseWSGIServer(wasyncore.dispatcher, object): return server_name def getsockname(self): - raise NotImplementedError # pragma: no cover + try: + return self.socketmod.getnameinfo( + self.socket.getsockname(), + self.socketmod.NI_NUMERICSERV + ) + except: # pragma: no cover + # This only happens on Linux because a DNS issue is considered a + # temporary failure that will raise (even when NI_NAMEREQD is not + # set). Instead we try again, but this time we just ask for the + # numerichost and the numericserv (port) and return those. It is + # better than nothing. + return self.socketmod.getnameinfo( + self.socket.getsockname(), + self.socketmod.NI_NUMERICHOST | self.socketmod.NI_NUMERICSERV + ) def accept_connections(self): self.accepting = True @@ -313,23 +354,6 @@ class TcpWSGIServer(BaseWSGIServer): (_, _, _, sockaddr) = self.sockinfo self.bind(sockaddr) - def getsockname(self): - try: - return self.socketmod.getnameinfo( - self.socket.getsockname(), - self.socketmod.NI_NUMERICSERV - ) - except: # pragma: no cover - # This only happens on Linux because a DNS issue is considered a - # temporary failure that will raise (even when NI_NAMEREQD is not - # set). Instead we try again, but this time we just ask for the - # numerichost and the numericserv (port) and return those. It is - # better than nothing. - return self.socketmod.getnameinfo( - self.socket.getsockname(), - self.socketmod.NI_NUMERICHOST | self.socketmod.NI_NUMERICSERV - ) - def set_socket_options(self, conn): for (level, optname, value) in self.adj.socket_options: conn.setsockopt(level, optname, value) diff --git a/waitress/tests/test_adjustments.py b/waitress/tests/test_adjustments.py index 09c60ef..05c4a23 100644 --- a/waitress/tests/test_adjustments.py +++ b/waitress/tests/test_adjustments.py @@ -49,6 +49,31 @@ class Test_asbool(unittest.TestCase): result = self._callFUT(1) self.assertEqual(result, True) +class Test_as_socket_list(unittest.TestCase): + + def test_only_sockets_in_list(self): + from waitress.adjustments import as_socket_list + sockets = [ + socket.socket(socket.AF_INET, socket.SOCK_STREAM), + socket.socket(socket.AF_INET6, socket.SOCK_STREAM)] + if hasattr(sockets, 'AF_UNIX'): + sockets.append(socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)) + new_sockets = as_socket_list(sockets) + self.assertEqual(sockets, new_sockets) + for sock in sockets: + sock.close() + + def test_not_only_sockets_in_list(self): + from waitress.adjustments import as_socket_list + sockets = [ + socket.socket(socket.AF_INET, socket.SOCK_STREAM), + socket.socket(socket.AF_INET6, socket.SOCK_STREAM), + {'something': 'else'}] + new_sockets = as_socket_list(sockets) + self.assertEqual(new_sockets, [sockets[0], sockets[1]]) + for sock in [sock for sock in sockets if isinstance(sock, socket.socket)]: + sock.close() + class TestAdjustments(unittest.TestCase): def _hasIPv6(self): # pragma: nocover @@ -210,6 +235,54 @@ class TestAdjustments(unittest.TestCase): listen='127.0.0.1:8080', ) + def test_good_sockets(self): + sockets = [ + socket.socket(socket.AF_INET6, socket.SOCK_STREAM), + socket.socket(socket.AF_INET, socket.SOCK_STREAM)] + inst = self._makeOne(sockets=sockets, bind_sockets=False) + self.assertEqual(inst.sockets, sockets) + sockets[0].close() + sockets[1].close() + + def test_dont_use_sockets_with_bind_enabled(self): + sockets = [ + socket.socket(socket.AF_INET6, socket.SOCK_STREAM), + socket.socket(socket.AF_INET, socket.SOCK_STREAM)] + self.assertRaises( + ValueError, + self._makeOne, + sockets=sockets, + bind_sockets=True) + sockets[0].close() + sockets[1].close() + + def test_dont_mix_sockets_and_listen(self): + sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] + self.assertRaises( + ValueError, + self._makeOne, + listen='127.0.0.1:8080', + sockets=sockets) + sockets[0].close() + + def test_dont_mix_sockets_and_host_port(self): + sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] + self.assertRaises( + ValueError, + self._makeOne, + host='localhost', + port='8080', + sockets=sockets) + sockets[0].close() + + def test_default_bind_sockets(self): + inst = self._makeOne() + self.assertEqual(inst.bind_sockets, True) + + def test_good_bind_sockets(self): + inst = self._makeOne(bind_sockets=False) + self.assertEqual(inst.bind_sockets, False) + def test_badvar(self): self.assertRaises(ValueError, self._makeOne, nope=True) diff --git a/waitress/tests/test_server.py b/waitress/tests/test_server.py index 0aa217a..38ee46d 100644 --- a/waitress/tests/test_server.py +++ b/waitress/tests/test_server.py @@ -50,6 +50,22 @@ class TestWSGIServer(unittest.TestCase): _sock=sock) return self.inst + def _makeWithSockets(self, application=dummy_app, _dispatcher=None, map=None, + _start=True, _sock=None, _server=None, sockets=None): + from waitress.server import create_server + _sockets = [] + if sockets is not None: + _sockets = sockets + self.inst = create_server( + application, + map=map, + _dispatcher=_dispatcher, + _start=_start, + _sock=_sock, + bind_sockets=False, + sockets=_sockets) + return self.inst + def tearDown(self): if self.inst is not None: self.inst.close() @@ -237,6 +253,44 @@ class TestWSGIServer(unittest.TestCase): self.assertNotEqual(Adjustments.port, 1234) self.assertEqual(self.inst.adj.port, 1234) + def test_create_with_one_tcp_socket(self): + from waitress.server import TcpWSGIServer + sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] + inst = self._makeWithSockets(_start=False, sockets=sockets) + self.assertTrue(isinstance(inst, TcpWSGIServer)) + + def test_create_with_multiple_tcp_sockets(self): + from waitress.server import MultiSocketServer + sockets = [ + socket.socket(socket.AF_INET, socket.SOCK_STREAM), + socket.socket(socket.AF_INET6, socket.SOCK_STREAM)] + inst = self._makeWithSockets(_start=False, sockets=sockets) + self.assertTrue(isinstance(inst, MultiSocketServer)) + self.assertEqual(len(inst.effective_listen), 2) + + def test_create_with_one_socket_should_not_bind_socket(self): + innersock = DummySock() + sockets = [DummySock(acceptresult=(innersock, None))] + sockets[0].bind(('127.0.0.1', 80)) + sockets[0].bind_called = False + inst = self._makeWithSockets(_start=False, sockets=sockets) + self.assertEqual(inst.socket.bound, ('127.0.0.1', 80)) + self.assertFalse(inst.socket.bind_called) + + def test_create_with_one_socket_handle_accept_noerror(self): + innersock = DummySock() + sockets = [DummySock(acceptresult=(innersock, None))] + sockets[0].bind(('127.0.0.1', 80)) + inst = self._makeWithSockets(sockets=sockets) + L = [] + inst.channel_class = lambda *arg, **kw: L.append(arg) + inst.adj = DummyAdj + inst.handle_accept() + self.assertEqual(sockets[0].accepted, True) + self.assertEqual(innersock.opts, [('level', 'optname', 'value')]) + self.assertEqual(L, [(inst, innersock, None, inst.adj)]) + + if hasattr(socket, 'AF_UNIX'): class TestUnixWSGIServer(unittest.TestCase): @@ -255,6 +309,22 @@ if hasattr(socket, 'AF_UNIX'): ) return self.inst + def _makeWithSockets(self, application=dummy_app, _dispatcher=None, map=None, + _start=True, _sock=None, _server=None, sockets=None): + from waitress.server import create_server + _sockets = [] + if sockets is not None: + _sockets = sockets + self.inst = create_server( + application, + map=map, + _dispatcher=_dispatcher, + _start=_start, + _sock=_sock, + bind_sockets=False, + sockets=_sockets) + return self.inst + def tearDown(self): self.inst.close() @@ -297,18 +367,35 @@ if hasattr(socket, 'AF_UNIX'): self.assertEqual(self.inst.sockinfo[0], socket.AF_UNIX) -class DummySock(object): + def test_create_with_unix_socket(self): + from waitress.server import MultiSocketServer, BaseWSGIServer, \ + TcpWSGIServer, UnixWSGIServer + sockets = [ + socket.socket(socket.AF_UNIX, socket.SOCK_STREAM), + socket.socket(socket.AF_INET, socket.SOCK_STREAM)] + inst = self._makeWithSockets(sockets=sockets, _start=False) + self.assertTrue(isinstance(inst, MultiSocketServer)) + server = list(filter(lambda s: isinstance(s, BaseWSGIServer), inst.map.values())) + self.assertTrue(isinstance(server[0], UnixWSGIServer)) + self.assertTrue(isinstance(server[1], TcpWSGIServer)) + + +class DummySock(socket.socket): accepted = False blocking = False family = socket.AF_INET + type = socket.SOCK_STREAM + proto = 0 def __init__(self, toraise=None, acceptresult=(None, None)): self.toraise = toraise self.acceptresult = acceptresult self.bound = None self.opts = [] + self.bind_called = False def bind(self, addr): + self.bind_called = True self.bound = addr def accept(self): |