diff options
author | Frank Krick <frank.krick@gmail.com> | 2018-11-13 16:01:17 -0500 |
---|---|---|
committer | Frank Krick <frank.krick@gmail.com> | 2018-11-13 16:01:17 -0500 |
commit | b048393c33e77a6f369d3345b67c25ba191480d7 (patch) | |
tree | 0bd26d8ea718871c54012b24732838dda629c7e6 | |
parent | 3e240bba7a3ed91792ede70ce6dbdc1c47ca0f62 (diff) | |
download | waitress-b048393c33e77a6f369d3345b67c25ba191480d7.tar.gz |
Check in adjustments to prevent mixing of Internet and UNIX sockets as well as the use of unsupported sockets
-rw-r--r-- | docs/arguments.rst | 3 | ||||
-rw-r--r-- | waitress/adjustments.py | 20 | ||||
-rw-r--r-- | waitress/tests/test_adjustments.py | 26 | ||||
-rw-r--r-- | waitress/tests/test_server.py | 4 |
4 files changed, 50 insertions, 3 deletions
diff --git a/docs/arguments.rst b/docs/arguments.rst index b176c00..690daef 100644 --- a/docs/arguments.rst +++ b/docs/arguments.rst @@ -53,7 +53,8 @@ unix_socket_perms sockets .. versionadded:: 1.1.1 - A list of sockets. The sockets can be Internet or UNIX sockets and have to be bound. + A list of sockets. The sockets can be either Internet or UNIX sockets and have + to be bound. Internet and UNIX sockets cannot be mixed. If the socket list is not empty, waitress creates one server for each socket. Default is ``[]``. diff --git a/waitress/adjustments.py b/waitress/adjustments.py index eda3718..3ab3d00 100644 --- a/waitress/adjustments.py +++ b/waitress/adjustments.py @@ -326,6 +326,8 @@ class Adjustments(object): self.listen = wanted_sockets + self.check_sockets(self.sockets) + @classmethod def parse_args(cls, argv): """Pre-parse command line arguments for input into __init__. Note that @@ -366,3 +368,21 @@ class Adjustments(object): kw[param] = value return kw, args + + @classmethod + def check_sockets(cls, sockets): + has_unix_socket = False + has_inet_socket = False + has_unsupported_socket = False + for sock in sockets: + if (sock.family == socket.AF_INET or sock.family == socket.AF_INET6) and \ + sock.type == socket.SOCK_STREAM: + has_inet_socket = True + elif sock.family == socket.AF_UNIX and sock.type == socket.SOCK_STREAM: + has_unix_socket = True + else: + has_unsupported_socket = True + if has_unix_socket and has_inet_socket: + raise ValueError('Internet and UNIX sockets may not be mixed.') + if has_unsupported_socket: + raise ValueError('Only Internet or UNIX stream sockets may be used.') diff --git a/waitress/tests/test_adjustments.py b/waitress/tests/test_adjustments.py index c945c1a..05b4dbd 100644 --- a/waitress/tests/test_adjustments.py +++ b/waitress/tests/test_adjustments.py @@ -285,6 +285,14 @@ class TestAdjustments(unittest.TestCase): unix_socket='./tmp/test', listen='127.0.0.1:8080') + def test_dont_use_unsupported_socket_types(self): + sockets = [socket.socket(socket.AF_INET, socket.SOCK_DGRAM)] + self.assertRaises( + ValueError, + self._makeOne, + sockets=sockets) + sockets[0].close() + def test_badvar(self): self.assertRaises(ValueError, self._makeOne, nope=True) @@ -378,3 +386,21 @@ class TestCLI(unittest.TestCase): def test_bad_param(self): import getopt self.assertRaises(getopt.GetoptError, self.parse, ['--no-host']) + + +if hasattr(socket, 'AF_UNIX'): + class TestUnixSocket(unittest.TestCase): + def _makeOne(self, **kw): + from waitress.adjustments import Adjustments + return Adjustments(**kw) + + def test_dont_mix_internet_and_unix_sockets(self): + sockets = [ + socket.socket(socket.AF_INET, socket.SOCK_STREAM), + socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)] + self.assertRaises( + ValueError, + self._makeOne, + sockets=sockets) + sockets[0].close() + sockets[1].close() diff --git a/waitress/tests/test_server.py b/waitress/tests/test_server.py index 1052f08..7cd6345 100644 --- a/waitress/tests/test_server.py +++ b/waitress/tests/test_server.py @@ -373,12 +373,12 @@ if hasattr(socket, 'AF_UNIX'): TcpWSGIServer, UnixWSGIServer sockets = [ socket.socket(socket.AF_UNIX, socket.SOCK_STREAM), - socket.socket(socket.AF_INET, socket.SOCK_STREAM)] + socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)] inst = self._makeWithSockets(sockets=sockets, _start=False) self.assertTrue(isinstance(inst, MultiSocketServer)) server = list(filter(lambda s: isinstance(s, BaseWSGIServer), inst.map.values())) self.assertTrue(isinstance(server[0], UnixWSGIServer)) - self.assertTrue(isinstance(server[1], TcpWSGIServer)) + self.assertTrue(isinstance(server[1], UnixWSGIServer)) class DummySock(socket.socket): |