summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFrank Krick <frank.krick@gmail.com>2018-11-13 16:01:17 -0500
committerFrank Krick <frank.krick@gmail.com>2018-11-13 16:01:17 -0500
commitb048393c33e77a6f369d3345b67c25ba191480d7 (patch)
tree0bd26d8ea718871c54012b24732838dda629c7e6
parent3e240bba7a3ed91792ede70ce6dbdc1c47ca0f62 (diff)
downloadwaitress-b048393c33e77a6f369d3345b67c25ba191480d7.tar.gz
Check in adjustments to prevent mixing of Internet and UNIX sockets as well as the use of unsupported sockets
-rw-r--r--docs/arguments.rst3
-rw-r--r--waitress/adjustments.py20
-rw-r--r--waitress/tests/test_adjustments.py26
-rw-r--r--waitress/tests/test_server.py4
4 files changed, 50 insertions, 3 deletions
diff --git a/docs/arguments.rst b/docs/arguments.rst
index b176c00..690daef 100644
--- a/docs/arguments.rst
+++ b/docs/arguments.rst
@@ -53,7 +53,8 @@ unix_socket_perms
sockets
.. versionadded:: 1.1.1
- A list of sockets. The sockets can be Internet or UNIX sockets and have to be bound.
+ A list of sockets. The sockets can be either Internet or UNIX sockets and have
+ to be bound. Internet and UNIX sockets cannot be mixed.
If the socket list is not empty, waitress creates one server for each socket.
Default is ``[]``.
diff --git a/waitress/adjustments.py b/waitress/adjustments.py
index eda3718..3ab3d00 100644
--- a/waitress/adjustments.py
+++ b/waitress/adjustments.py
@@ -326,6 +326,8 @@ class Adjustments(object):
self.listen = wanted_sockets
+ self.check_sockets(self.sockets)
+
@classmethod
def parse_args(cls, argv):
"""Pre-parse command line arguments for input into __init__. Note that
@@ -366,3 +368,21 @@ class Adjustments(object):
kw[param] = value
return kw, args
+
+ @classmethod
+ def check_sockets(cls, sockets):
+ has_unix_socket = False
+ has_inet_socket = False
+ has_unsupported_socket = False
+ for sock in sockets:
+ if (sock.family == socket.AF_INET or sock.family == socket.AF_INET6) and \
+ sock.type == socket.SOCK_STREAM:
+ has_inet_socket = True
+ elif sock.family == socket.AF_UNIX and sock.type == socket.SOCK_STREAM:
+ has_unix_socket = True
+ else:
+ has_unsupported_socket = True
+ if has_unix_socket and has_inet_socket:
+ raise ValueError('Internet and UNIX sockets may not be mixed.')
+ if has_unsupported_socket:
+ raise ValueError('Only Internet or UNIX stream sockets may be used.')
diff --git a/waitress/tests/test_adjustments.py b/waitress/tests/test_adjustments.py
index c945c1a..05b4dbd 100644
--- a/waitress/tests/test_adjustments.py
+++ b/waitress/tests/test_adjustments.py
@@ -285,6 +285,14 @@ class TestAdjustments(unittest.TestCase):
unix_socket='./tmp/test',
listen='127.0.0.1:8080')
+ def test_dont_use_unsupported_socket_types(self):
+ sockets = [socket.socket(socket.AF_INET, socket.SOCK_DGRAM)]
+ self.assertRaises(
+ ValueError,
+ self._makeOne,
+ sockets=sockets)
+ sockets[0].close()
+
def test_badvar(self):
self.assertRaises(ValueError, self._makeOne, nope=True)
@@ -378,3 +386,21 @@ class TestCLI(unittest.TestCase):
def test_bad_param(self):
import getopt
self.assertRaises(getopt.GetoptError, self.parse, ['--no-host'])
+
+
+if hasattr(socket, 'AF_UNIX'):
+ class TestUnixSocket(unittest.TestCase):
+ def _makeOne(self, **kw):
+ from waitress.adjustments import Adjustments
+ return Adjustments(**kw)
+
+ def test_dont_mix_internet_and_unix_sockets(self):
+ sockets = [
+ socket.socket(socket.AF_INET, socket.SOCK_STREAM),
+ socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)]
+ self.assertRaises(
+ ValueError,
+ self._makeOne,
+ sockets=sockets)
+ sockets[0].close()
+ sockets[1].close()
diff --git a/waitress/tests/test_server.py b/waitress/tests/test_server.py
index 1052f08..7cd6345 100644
--- a/waitress/tests/test_server.py
+++ b/waitress/tests/test_server.py
@@ -373,12 +373,12 @@ if hasattr(socket, 'AF_UNIX'):
TcpWSGIServer, UnixWSGIServer
sockets = [
socket.socket(socket.AF_UNIX, socket.SOCK_STREAM),
- socket.socket(socket.AF_INET, socket.SOCK_STREAM)]
+ socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)]
inst = self._makeWithSockets(sockets=sockets, _start=False)
self.assertTrue(isinstance(inst, MultiSocketServer))
server = list(filter(lambda s: isinstance(s, BaseWSGIServer), inst.map.values()))
self.assertTrue(isinstance(server[0], UnixWSGIServer))
- self.assertTrue(isinstance(server[1], TcpWSGIServer))
+ self.assertTrue(isinstance(server[1], UnixWSGIServer))
class DummySock(socket.socket):