summaryrefslogtreecommitdiff
path: root/waitress
diff options
context:
space:
mode:
authorFrank Krick <frank.krick@gmail.com>2018-10-28 16:03:09 -0400
committerFrank Krick <frank.krick@gmail.com>2018-10-28 16:03:09 -0400
commitde0a6c1ef714ce42c07a595238663931a056a9b7 (patch)
treed9e843a8d7d89c00fe4f35736adac903ffded04d /waitress
parent5274c8b95c47704462ab2c52d647c90230d5d801 (diff)
downloadwaitress-de0a6c1ef714ce42c07a595238663931a056a9b7.tar.gz
Added sockets and bind_sockets parameters for server creation
Diffstat (limited to 'waitress')
-rw-r--r--waitress/adjustments.py24
-rw-r--r--waitress/server.py90
-rw-r--r--waitress/tests/test_adjustments.py73
-rw-r--r--waitress/tests/test_server.py89
4 files changed, 242 insertions, 34 deletions
diff --git a/waitress/adjustments.py b/waitress/adjustments.py
index c55823a..92d5a4f 100644
--- a/waitress/adjustments.py
+++ b/waitress/adjustments.py
@@ -69,6 +69,11 @@ def slash_fixed_str(s):
def str_iftruthy(s):
return str(s) if s else None
+def as_socket_list(sockets):
+ """Checks if the elements in the list are of type socket and
+ returns None if not."""
+ return [sock for sock in sockets if isinstance(sock, socket.socket)]
+
class _str_marker(str):
pass
@@ -106,6 +111,8 @@ class Adjustments(object):
('asyncore_use_poll', asbool),
('unix_socket', str),
('unix_socket_perms', asoctal),
+ ('sockets', as_socket_list),
+ ('bind_sockets', asbool),
)
_param_map = dict(_params)
@@ -216,11 +223,28 @@ class Adjustments(object):
# Enable IPv6 by default
ipv6 = True
+ # A list of sockets that waitress will use to accept connections. They can
+ # be used for e.g. socket activation
+ sockets = []
+
+ # Enable binding to sockets by default. This can be turned off for sockets
+ # that are supplied from the outside, e.g. using socket activation
+ bind_sockets = True
+
def __init__(self, **kw):
if 'listen' in kw and ('host' in kw or 'port' in kw):
raise ValueError('host and or port may not be set if listen is set.')
+ if 'listen' in kw and 'sockets' in kw:
+ raise ValueError('socket may not be set if listen is set.')
+
+ if 'sockets' in kw and ('host' in kw or 'port' in kw):
+ raise ValueError('host and or port may not be set if sockets is set.')
+
+ if 'sockets' in kw and ('bind_sockets' not in kw or kw['bind_sockets']):
+ raise ValueError('Sockets passed should be bound already, please turn bind of with bind_sockets=False')
+
for k, v in kw.items():
if k not in self._param_map:
raise ValueError('Unknown adjustment %r' % k)
diff --git a/waitress/server.py b/waitress/server.py
index 198d5fa..ec12417 100644
--- a/waitress/server.py
+++ b/waitress/server.py
@@ -69,23 +69,47 @@ def create_server(application,
effective_listen = []
last_serv = None
- for sockinfo in adj.listen:
- # When TcpWSGIServer is called, it registers itself in the map. This
- # side-effect is all we need it for, so we don't store a reference to
- # or return it to the user.
- last_serv = TcpWSGIServer(
- application,
- map,
- _start,
- _sock,
- dispatcher=dispatcher,
- adj=adj,
- sockinfo=sockinfo)
- effective_listen.append((last_serv.effective_host, last_serv.effective_port))
+ if not adj.sockets:
+ for sockinfo in adj.listen:
+ # When TcpWSGIServer is called, it registers itself in the map. This
+ # side-effect is all we need it for, so we don't store a reference to
+ # or return it to the user.
+ last_serv = TcpWSGIServer(
+ application,
+ map,
+ _start,
+ _sock,
+ dispatcher=dispatcher,
+ adj=adj,
+ sockinfo=sockinfo)
+ effective_listen.append((last_serv.effective_host, last_serv.effective_port))
+
+ for sock in adj.sockets:
+ if sock.family == socket.AF_INET or sock.family == socket.AF_INET6:
+ sockinfo = (sock.family, sock.type, sock.proto, sock.getsockname())
+ last_serv = TcpWSGIServer(
+ application,
+ map,
+ _start,
+ sock,
+ dispatcher=dispatcher,
+ adj=adj,
+ sockinfo=sockinfo)
+ effective_listen.append((last_serv.effective_host, last_serv.effective_port))
+ elif hasattr(socket, 'AF_UNIX') and sock.family == socket.AF_UNIX:
+ last_serv = UnixWSGIServer(
+ application,
+ map,
+ _start,
+ sock,
+ dispatcher=dispatcher,
+ adj=adj,
+ sockinfo=(sock.family, sock.type, sock.proto, sock.getsockname()))
+ effective_listen.append((last_serv.effective_host, last_serv.effective_port))
# We are running a single server, so we can just return the last server,
# saves us from having to create one more object
- if len(adj.listen) == 1:
+ if len(adj.listen) == 1 and len(adj.sockets) == 0 or len(adj.sockets) == 1:
# In this case we have no need to use a MultiSocketServer
return last_serv
@@ -181,7 +205,10 @@ class BaseWSGIServer(wasyncore.dispatcher, object):
self.socket.setsockopt(IPPROTO_IPV6, IPV6_V6ONLY, 1)
self.set_reuse_addr()
- self.bind_server_socket()
+
+ if adj.bind_sockets:
+ self.bind_server_socket()
+
self.effective_host, self.effective_port = self.getsockname()
self.server_name = self.get_server_name(self.effective_host)
self.active_channels = {}
@@ -225,7 +252,21 @@ class BaseWSGIServer(wasyncore.dispatcher, object):
return server_name
def getsockname(self):
- raise NotImplementedError # pragma: no cover
+ try:
+ return self.socketmod.getnameinfo(
+ self.socket.getsockname(),
+ self.socketmod.NI_NUMERICSERV
+ )
+ except: # pragma: no cover
+ # This only happens on Linux because a DNS issue is considered a
+ # temporary failure that will raise (even when NI_NAMEREQD is not
+ # set). Instead we try again, but this time we just ask for the
+ # numerichost and the numericserv (port) and return those. It is
+ # better than nothing.
+ return self.socketmod.getnameinfo(
+ self.socket.getsockname(),
+ self.socketmod.NI_NUMERICHOST | self.socketmod.NI_NUMERICSERV
+ )
def accept_connections(self):
self.accepting = True
@@ -313,23 +354,6 @@ class TcpWSGIServer(BaseWSGIServer):
(_, _, _, sockaddr) = self.sockinfo
self.bind(sockaddr)
- def getsockname(self):
- try:
- return self.socketmod.getnameinfo(
- self.socket.getsockname(),
- self.socketmod.NI_NUMERICSERV
- )
- except: # pragma: no cover
- # This only happens on Linux because a DNS issue is considered a
- # temporary failure that will raise (even when NI_NAMEREQD is not
- # set). Instead we try again, but this time we just ask for the
- # numerichost and the numericserv (port) and return those. It is
- # better than nothing.
- return self.socketmod.getnameinfo(
- self.socket.getsockname(),
- self.socketmod.NI_NUMERICHOST | self.socketmod.NI_NUMERICSERV
- )
-
def set_socket_options(self, conn):
for (level, optname, value) in self.adj.socket_options:
conn.setsockopt(level, optname, value)
diff --git a/waitress/tests/test_adjustments.py b/waitress/tests/test_adjustments.py
index 09c60ef..05c4a23 100644
--- a/waitress/tests/test_adjustments.py
+++ b/waitress/tests/test_adjustments.py
@@ -49,6 +49,31 @@ class Test_asbool(unittest.TestCase):
result = self._callFUT(1)
self.assertEqual(result, True)
+class Test_as_socket_list(unittest.TestCase):
+
+ def test_only_sockets_in_list(self):
+ from waitress.adjustments import as_socket_list
+ sockets = [
+ socket.socket(socket.AF_INET, socket.SOCK_STREAM),
+ socket.socket(socket.AF_INET6, socket.SOCK_STREAM)]
+ if hasattr(sockets, 'AF_UNIX'):
+ sockets.append(socket.socket(socket.AF_UNIX, socket.SOCK_STREAM))
+ new_sockets = as_socket_list(sockets)
+ self.assertEqual(sockets, new_sockets)
+ for sock in sockets:
+ sock.close()
+
+ def test_not_only_sockets_in_list(self):
+ from waitress.adjustments import as_socket_list
+ sockets = [
+ socket.socket(socket.AF_INET, socket.SOCK_STREAM),
+ socket.socket(socket.AF_INET6, socket.SOCK_STREAM),
+ {'something': 'else'}]
+ new_sockets = as_socket_list(sockets)
+ self.assertEqual(new_sockets, [sockets[0], sockets[1]])
+ for sock in [sock for sock in sockets if isinstance(sock, socket.socket)]:
+ sock.close()
+
class TestAdjustments(unittest.TestCase):
def _hasIPv6(self): # pragma: nocover
@@ -210,6 +235,54 @@ class TestAdjustments(unittest.TestCase):
listen='127.0.0.1:8080',
)
+ def test_good_sockets(self):
+ sockets = [
+ socket.socket(socket.AF_INET6, socket.SOCK_STREAM),
+ socket.socket(socket.AF_INET, socket.SOCK_STREAM)]
+ inst = self._makeOne(sockets=sockets, bind_sockets=False)
+ self.assertEqual(inst.sockets, sockets)
+ sockets[0].close()
+ sockets[1].close()
+
+ def test_dont_use_sockets_with_bind_enabled(self):
+ sockets = [
+ socket.socket(socket.AF_INET6, socket.SOCK_STREAM),
+ socket.socket(socket.AF_INET, socket.SOCK_STREAM)]
+ self.assertRaises(
+ ValueError,
+ self._makeOne,
+ sockets=sockets,
+ bind_sockets=True)
+ sockets[0].close()
+ sockets[1].close()
+
+ def test_dont_mix_sockets_and_listen(self):
+ sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)]
+ self.assertRaises(
+ ValueError,
+ self._makeOne,
+ listen='127.0.0.1:8080',
+ sockets=sockets)
+ sockets[0].close()
+
+ def test_dont_mix_sockets_and_host_port(self):
+ sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)]
+ self.assertRaises(
+ ValueError,
+ self._makeOne,
+ host='localhost',
+ port='8080',
+ sockets=sockets)
+ sockets[0].close()
+
+ def test_default_bind_sockets(self):
+ inst = self._makeOne()
+ self.assertEqual(inst.bind_sockets, True)
+
+ def test_good_bind_sockets(self):
+ inst = self._makeOne(bind_sockets=False)
+ self.assertEqual(inst.bind_sockets, False)
+
def test_badvar(self):
self.assertRaises(ValueError, self._makeOne, nope=True)
diff --git a/waitress/tests/test_server.py b/waitress/tests/test_server.py
index 0aa217a..38ee46d 100644
--- a/waitress/tests/test_server.py
+++ b/waitress/tests/test_server.py
@@ -50,6 +50,22 @@ class TestWSGIServer(unittest.TestCase):
_sock=sock)
return self.inst
+ def _makeWithSockets(self, application=dummy_app, _dispatcher=None, map=None,
+ _start=True, _sock=None, _server=None, sockets=None):
+ from waitress.server import create_server
+ _sockets = []
+ if sockets is not None:
+ _sockets = sockets
+ self.inst = create_server(
+ application,
+ map=map,
+ _dispatcher=_dispatcher,
+ _start=_start,
+ _sock=_sock,
+ bind_sockets=False,
+ sockets=_sockets)
+ return self.inst
+
def tearDown(self):
if self.inst is not None:
self.inst.close()
@@ -237,6 +253,44 @@ class TestWSGIServer(unittest.TestCase):
self.assertNotEqual(Adjustments.port, 1234)
self.assertEqual(self.inst.adj.port, 1234)
+ def test_create_with_one_tcp_socket(self):
+ from waitress.server import TcpWSGIServer
+ sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)]
+ inst = self._makeWithSockets(_start=False, sockets=sockets)
+ self.assertTrue(isinstance(inst, TcpWSGIServer))
+
+ def test_create_with_multiple_tcp_sockets(self):
+ from waitress.server import MultiSocketServer
+ sockets = [
+ socket.socket(socket.AF_INET, socket.SOCK_STREAM),
+ socket.socket(socket.AF_INET6, socket.SOCK_STREAM)]
+ inst = self._makeWithSockets(_start=False, sockets=sockets)
+ self.assertTrue(isinstance(inst, MultiSocketServer))
+ self.assertEqual(len(inst.effective_listen), 2)
+
+ def test_create_with_one_socket_should_not_bind_socket(self):
+ innersock = DummySock()
+ sockets = [DummySock(acceptresult=(innersock, None))]
+ sockets[0].bind(('127.0.0.1', 80))
+ sockets[0].bind_called = False
+ inst = self._makeWithSockets(_start=False, sockets=sockets)
+ self.assertEqual(inst.socket.bound, ('127.0.0.1', 80))
+ self.assertFalse(inst.socket.bind_called)
+
+ def test_create_with_one_socket_handle_accept_noerror(self):
+ innersock = DummySock()
+ sockets = [DummySock(acceptresult=(innersock, None))]
+ sockets[0].bind(('127.0.0.1', 80))
+ inst = self._makeWithSockets(sockets=sockets)
+ L = []
+ inst.channel_class = lambda *arg, **kw: L.append(arg)
+ inst.adj = DummyAdj
+ inst.handle_accept()
+ self.assertEqual(sockets[0].accepted, True)
+ self.assertEqual(innersock.opts, [('level', 'optname', 'value')])
+ self.assertEqual(L, [(inst, innersock, None, inst.adj)])
+
+
if hasattr(socket, 'AF_UNIX'):
class TestUnixWSGIServer(unittest.TestCase):
@@ -255,6 +309,22 @@ if hasattr(socket, 'AF_UNIX'):
)
return self.inst
+ def _makeWithSockets(self, application=dummy_app, _dispatcher=None, map=None,
+ _start=True, _sock=None, _server=None, sockets=None):
+ from waitress.server import create_server
+ _sockets = []
+ if sockets is not None:
+ _sockets = sockets
+ self.inst = create_server(
+ application,
+ map=map,
+ _dispatcher=_dispatcher,
+ _start=_start,
+ _sock=_sock,
+ bind_sockets=False,
+ sockets=_sockets)
+ return self.inst
+
def tearDown(self):
self.inst.close()
@@ -297,18 +367,35 @@ if hasattr(socket, 'AF_UNIX'):
self.assertEqual(self.inst.sockinfo[0], socket.AF_UNIX)
-class DummySock(object):
+ def test_create_with_unix_socket(self):
+ from waitress.server import MultiSocketServer, BaseWSGIServer, \
+ TcpWSGIServer, UnixWSGIServer
+ sockets = [
+ socket.socket(socket.AF_UNIX, socket.SOCK_STREAM),
+ socket.socket(socket.AF_INET, socket.SOCK_STREAM)]
+ inst = self._makeWithSockets(sockets=sockets, _start=False)
+ self.assertTrue(isinstance(inst, MultiSocketServer))
+ server = list(filter(lambda s: isinstance(s, BaseWSGIServer), inst.map.values()))
+ self.assertTrue(isinstance(server[0], UnixWSGIServer))
+ self.assertTrue(isinstance(server[1], TcpWSGIServer))
+
+
+class DummySock(socket.socket):
accepted = False
blocking = False
family = socket.AF_INET
+ type = socket.SOCK_STREAM
+ proto = 0
def __init__(self, toraise=None, acceptresult=(None, None)):
self.toraise = toraise
self.acceptresult = acceptresult
self.bound = None
self.opts = []
+ self.bind_called = False
def bind(self, addr):
+ self.bind_called = True
self.bound = addr
def accept(self):