From eb980ae020d9a4b16b719dd8a01737a32a5a01f2 Mon Sep 17 00:00:00 2001 From: Bert JW Regeer Date: Sun, 2 Feb 2020 22:36:59 -0800 Subject: Move tests to top directory --- tests/__init__.py | 2 + tests/fixtureapps/__init__.py | 1 + tests/fixtureapps/badcl.py | 11 + tests/fixtureapps/echo.py | 56 + tests/fixtureapps/error.py | 21 + tests/fixtureapps/filewrapper.py | 93 ++ tests/fixtureapps/getline.py | 17 + tests/fixtureapps/groundhog1.jpg | Bin 0 -> 45448 bytes tests/fixtureapps/nocl.py | 23 + tests/fixtureapps/runner.py | 6 + tests/fixtureapps/sleepy.py | 12 + tests/fixtureapps/toolarge.py | 7 + tests/fixtureapps/writecb.py | 14 + tests/test_adjustments.py | 481 ++++++++ tests/test_buffers.py | 523 +++++++++ tests/test_channel.py | 882 +++++++++++++++ tests/test_compat.py | 22 + tests/test_functional.py | 1667 +++++++++++++++++++++++++++ tests/test_init.py | 51 + tests/test_parser.py | 732 ++++++++++++ tests/test_proxy_headers.py | 724 ++++++++++++ tests/test_receiver.py | 242 ++++ tests/test_regression.py | 147 +++ tests/test_runner.py | 191 ++++ tests/test_server.py | 533 +++++++++ tests/test_task.py | 1001 ++++++++++++++++ tests/test_trigger.py | 111 ++ tests/test_utilities.py | 140 +++ tests/test_wasyncore.py | 1761 +++++++++++++++++++++++++++++ waitress/tests/__init__.py | 2 - waitress/tests/fixtureapps/__init__.py | 1 - waitress/tests/fixtureapps/badcl.py | 11 - waitress/tests/fixtureapps/echo.py | 56 - waitress/tests/fixtureapps/error.py | 21 - waitress/tests/fixtureapps/filewrapper.py | 93 -- waitress/tests/fixtureapps/getline.py | 17 - waitress/tests/fixtureapps/groundhog1.jpg | Bin 45448 -> 0 bytes waitress/tests/fixtureapps/nocl.py | 23 - waitress/tests/fixtureapps/runner.py | 6 - waitress/tests/fixtureapps/sleepy.py | 12 - waitress/tests/fixtureapps/toolarge.py | 7 - waitress/tests/fixtureapps/writecb.py | 14 - waitress/tests/test_adjustments.py | 481 -------- waitress/tests/test_buffers.py | 523 --------- waitress/tests/test_channel.py | 882 --------------- waitress/tests/test_compat.py | 22 - waitress/tests/test_functional.py | 1667 --------------------------- waitress/tests/test_init.py | 51 - waitress/tests/test_parser.py | 732 ------------ waitress/tests/test_proxy_headers.py | 724 ------------ waitress/tests/test_receiver.py | 242 ---- waitress/tests/test_regression.py | 147 --- waitress/tests/test_runner.py | 191 ---- waitress/tests/test_server.py | 533 --------- waitress/tests/test_task.py | 1001 ---------------- waitress/tests/test_trigger.py | 111 -- waitress/tests/test_utilities.py | 140 --- waitress/tests/test_wasyncore.py | 1761 ----------------------------- 58 files changed, 9471 insertions(+), 9471 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/fixtureapps/__init__.py create mode 100644 tests/fixtureapps/badcl.py create mode 100644 tests/fixtureapps/echo.py create mode 100644 tests/fixtureapps/error.py create mode 100644 tests/fixtureapps/filewrapper.py create mode 100644 tests/fixtureapps/getline.py create mode 100644 tests/fixtureapps/groundhog1.jpg create mode 100644 tests/fixtureapps/nocl.py create mode 100644 tests/fixtureapps/runner.py create mode 100644 tests/fixtureapps/sleepy.py create mode 100644 tests/fixtureapps/toolarge.py create mode 100644 tests/fixtureapps/writecb.py create mode 100644 tests/test_adjustments.py create mode 100644 tests/test_buffers.py create mode 100644 tests/test_channel.py create mode 100644 tests/test_compat.py create mode 100644 tests/test_functional.py create mode 100644 tests/test_init.py create mode 100644 tests/test_parser.py create mode 100644 tests/test_proxy_headers.py create mode 100644 tests/test_receiver.py create mode 100644 tests/test_regression.py create mode 100644 tests/test_runner.py create mode 100644 tests/test_server.py create mode 100644 tests/test_task.py create mode 100644 tests/test_trigger.py create mode 100644 tests/test_utilities.py create mode 100644 tests/test_wasyncore.py delete mode 100644 waitress/tests/__init__.py delete mode 100644 waitress/tests/fixtureapps/__init__.py delete mode 100644 waitress/tests/fixtureapps/badcl.py delete mode 100644 waitress/tests/fixtureapps/echo.py delete mode 100644 waitress/tests/fixtureapps/error.py delete mode 100644 waitress/tests/fixtureapps/filewrapper.py delete mode 100644 waitress/tests/fixtureapps/getline.py delete mode 100644 waitress/tests/fixtureapps/groundhog1.jpg delete mode 100644 waitress/tests/fixtureapps/nocl.py delete mode 100644 waitress/tests/fixtureapps/runner.py delete mode 100644 waitress/tests/fixtureapps/sleepy.py delete mode 100644 waitress/tests/fixtureapps/toolarge.py delete mode 100644 waitress/tests/fixtureapps/writecb.py delete mode 100644 waitress/tests/test_adjustments.py delete mode 100644 waitress/tests/test_buffers.py delete mode 100644 waitress/tests/test_channel.py delete mode 100644 waitress/tests/test_compat.py delete mode 100644 waitress/tests/test_functional.py delete mode 100644 waitress/tests/test_init.py delete mode 100644 waitress/tests/test_parser.py delete mode 100644 waitress/tests/test_proxy_headers.py delete mode 100644 waitress/tests/test_receiver.py delete mode 100644 waitress/tests/test_regression.py delete mode 100644 waitress/tests/test_runner.py delete mode 100644 waitress/tests/test_server.py delete mode 100644 waitress/tests/test_task.py delete mode 100644 waitress/tests/test_trigger.py delete mode 100644 waitress/tests/test_utilities.py delete mode 100644 waitress/tests/test_wasyncore.py diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..b711d36 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,2 @@ +# +# This file is necessary to make this directory a package. diff --git a/tests/fixtureapps/__init__.py b/tests/fixtureapps/__init__.py new file mode 100644 index 0000000..f215a2b --- /dev/null +++ b/tests/fixtureapps/__init__.py @@ -0,0 +1 @@ +# package (for -m) diff --git a/tests/fixtureapps/badcl.py b/tests/fixtureapps/badcl.py new file mode 100644 index 0000000..24067de --- /dev/null +++ b/tests/fixtureapps/badcl.py @@ -0,0 +1,11 @@ +def app(environ, start_response): # pragma: no cover + body = b"abcdefghi" + cl = len(body) + if environ["PATH_INFO"] == "/short_body": + cl = len(body) + 1 + if environ["PATH_INFO"] == "/long_body": + cl = len(body) - 1 + start_response( + "200 OK", [("Content-Length", str(cl)), ("Content-Type", "text/plain")] + ) + return [body] diff --git a/tests/fixtureapps/echo.py b/tests/fixtureapps/echo.py new file mode 100644 index 0000000..813bdac --- /dev/null +++ b/tests/fixtureapps/echo.py @@ -0,0 +1,56 @@ +from collections import namedtuple +import json + + +def app_body_only(environ, start_response): # pragma: no cover + cl = environ.get("CONTENT_LENGTH", None) + if cl is not None: + cl = int(cl) + body = environ["wsgi.input"].read(cl) + cl = str(len(body)) + start_response("200 OK", [("Content-Length", cl), ("Content-Type", "text/plain"),]) + return [body] + + +def app(environ, start_response): # pragma: no cover + cl = environ.get("CONTENT_LENGTH", None) + if cl is not None: + cl = int(cl) + request_body = environ["wsgi.input"].read(cl) + cl = str(len(request_body)) + meta = { + "method": environ["REQUEST_METHOD"], + "path_info": environ["PATH_INFO"], + "script_name": environ["SCRIPT_NAME"], + "query_string": environ["QUERY_STRING"], + "content_length": cl, + "scheme": environ["wsgi.url_scheme"], + "remote_addr": environ["REMOTE_ADDR"], + "remote_host": environ["REMOTE_HOST"], + "server_port": environ["SERVER_PORT"], + "server_name": environ["SERVER_NAME"], + "headers": { + k[len("HTTP_") :]: v for k, v in environ.items() if k.startswith("HTTP_") + }, + } + response = json.dumps(meta).encode("utf8") + b"\r\n\r\n" + request_body + start_response( + "200 OK", + [("Content-Length", str(len(response))), ("Content-Type", "text/plain"),], + ) + return [response] + + +Echo = namedtuple( + "Echo", + ( + "method path_info script_name query_string content_length scheme " + "remote_addr remote_host server_port server_name headers body" + ), +) + + +def parse_response(response): + meta, body = response.split(b"\r\n\r\n", 1) + meta = json.loads(meta.decode("utf8")) + return Echo(body=body, **meta) diff --git a/tests/fixtureapps/error.py b/tests/fixtureapps/error.py new file mode 100644 index 0000000..5afb1c5 --- /dev/null +++ b/tests/fixtureapps/error.py @@ -0,0 +1,21 @@ +def app(environ, start_response): # pragma: no cover + cl = environ.get("CONTENT_LENGTH", None) + if cl is not None: + cl = int(cl) + body = environ["wsgi.input"].read(cl) + cl = str(len(body)) + if environ["PATH_INFO"] == "/before_start_response": + raise ValueError("wrong") + write = start_response( + "200 OK", [("Content-Length", cl), ("Content-Type", "text/plain")] + ) + if environ["PATH_INFO"] == "/after_write_cb": + write("abc") + if environ["PATH_INFO"] == "/in_generator": + + def foo(): + yield "abc" + raise ValueError + + return foo() + raise ValueError("wrong") diff --git a/tests/fixtureapps/filewrapper.py b/tests/fixtureapps/filewrapper.py new file mode 100644 index 0000000..63df5a6 --- /dev/null +++ b/tests/fixtureapps/filewrapper.py @@ -0,0 +1,93 @@ +import io +import os + +here = os.path.dirname(os.path.abspath(__file__)) +fn = os.path.join(here, "groundhog1.jpg") + + +class KindaFilelike(object): # pragma: no cover + def __init__(self, bytes): + self.bytes = bytes + + def read(self, n): + bytes = self.bytes[:n] + self.bytes = self.bytes[n:] + return bytes + + +class UnseekableIOBase(io.RawIOBase): # pragma: no cover + def __init__(self, bytes): + self.buf = io.BytesIO(bytes) + + def writable(self): + return False + + def readable(self): + return True + + def seekable(self): + return False + + def read(self, n): + return self.buf.read(n) + + +def app(environ, start_response): # pragma: no cover + path_info = environ["PATH_INFO"] + if path_info.startswith("/filelike"): + f = open(fn, "rb") + f.seek(0, 2) + cl = f.tell() + f.seek(0) + if path_info == "/filelike": + headers = [ + ("Content-Length", str(cl)), + ("Content-Type", "image/jpeg"), + ] + elif path_info == "/filelike_nocl": + headers = [("Content-Type", "image/jpeg")] + elif path_info == "/filelike_shortcl": + # short content length + headers = [ + ("Content-Length", "1"), + ("Content-Type", "image/jpeg"), + ] + else: + # long content length (/filelike_longcl) + headers = [ + ("Content-Length", str(cl + 10)), + ("Content-Type", "image/jpeg"), + ] + else: + with open(fn, "rb") as fp: + data = fp.read() + cl = len(data) + f = KindaFilelike(data) + if path_info == "/notfilelike": + headers = [ + ("Content-Length", str(len(data))), + ("Content-Type", "image/jpeg"), + ] + elif path_info == "/notfilelike_iobase": + headers = [ + ("Content-Length", str(len(data))), + ("Content-Type", "image/jpeg"), + ] + f = UnseekableIOBase(data) + elif path_info == "/notfilelike_nocl": + headers = [("Content-Type", "image/jpeg")] + elif path_info == "/notfilelike_shortcl": + # short content length + headers = [ + ("Content-Length", "1"), + ("Content-Type", "image/jpeg"), + ] + else: + # long content length (/notfilelike_longcl) + headers = [ + ("Content-Length", str(cl + 10)), + ("Content-Type", "image/jpeg"), + ] + + start_response("200 OK", headers) + return environ["wsgi.file_wrapper"](f, 8192) diff --git a/tests/fixtureapps/getline.py b/tests/fixtureapps/getline.py new file mode 100644 index 0000000..5e0ad3a --- /dev/null +++ b/tests/fixtureapps/getline.py @@ -0,0 +1,17 @@ +import sys + +if __name__ == "__main__": + try: + from urllib.request import urlopen, URLError + except ImportError: + from urllib2 import urlopen, URLError + + url = sys.argv[1] + headers = {"Content-Type": "text/plain; charset=utf-8"} + try: + resp = urlopen(url) + line = resp.readline().decode("ascii") # py3 + except URLError: + line = "failed to read %s" % url + sys.stdout.write(line) + sys.stdout.flush() diff --git a/tests/fixtureapps/groundhog1.jpg b/tests/fixtureapps/groundhog1.jpg new file mode 100644 index 0000000..90f610e Binary files /dev/null and b/tests/fixtureapps/groundhog1.jpg differ diff --git a/tests/fixtureapps/nocl.py b/tests/fixtureapps/nocl.py new file mode 100644 index 0000000..f82bba0 --- /dev/null +++ b/tests/fixtureapps/nocl.py @@ -0,0 +1,23 @@ +def chunks(l, n): # pragma: no cover + """ Yield successive n-sized chunks from l. + """ + for i in range(0, len(l), n): + yield l[i : i + n] + + +def gen(body): # pragma: no cover + for chunk in chunks(body, 10): + yield chunk + + +def app(environ, start_response): # pragma: no cover + cl = environ.get("CONTENT_LENGTH", None) + if cl is not None: + cl = int(cl) + body = environ["wsgi.input"].read(cl) + start_response("200 OK", [("Content-Type", "text/plain")]) + if environ["PATH_INFO"] == "/list": + return [body] + if environ["PATH_INFO"] == "/list_lentwo": + return [body[0:1], body[1:]] + return gen(body) diff --git a/tests/fixtureapps/runner.py b/tests/fixtureapps/runner.py new file mode 100644 index 0000000..1d66ad1 --- /dev/null +++ b/tests/fixtureapps/runner.py @@ -0,0 +1,6 @@ +def app(): # pragma: no cover + return None + + +def returns_app(): # pragma: no cover + return app diff --git a/tests/fixtureapps/sleepy.py b/tests/fixtureapps/sleepy.py new file mode 100644 index 0000000..2d171d8 --- /dev/null +++ b/tests/fixtureapps/sleepy.py @@ -0,0 +1,12 @@ +import time + + +def app(environ, start_response): # pragma: no cover + if environ["PATH_INFO"] == "/sleepy": + time.sleep(2) + body = b"sleepy returned" + else: + body = b"notsleepy returned" + cl = str(len(body)) + start_response("200 OK", [("Content-Length", cl), ("Content-Type", "text/plain")]) + return [body] diff --git a/tests/fixtureapps/toolarge.py b/tests/fixtureapps/toolarge.py new file mode 100644 index 0000000..a0f36d2 --- /dev/null +++ b/tests/fixtureapps/toolarge.py @@ -0,0 +1,7 @@ +def app(environ, start_response): # pragma: no cover + body = b"abcdef" + cl = len(body) + start_response( + "200 OK", [("Content-Length", str(cl)), ("Content-Type", "text/plain")] + ) + return [body] diff --git a/tests/fixtureapps/writecb.py b/tests/fixtureapps/writecb.py new file mode 100644 index 0000000..e1d2792 --- /dev/null +++ b/tests/fixtureapps/writecb.py @@ -0,0 +1,14 @@ +def app(environ, start_response): # pragma: no cover + path_info = environ["PATH_INFO"] + if path_info == "/no_content_length": + headers = [] + else: + headers = [("Content-Length", "9")] + write = start_response("200 OK", headers) + if path_info == "/long_body": + write(b"abcdefghij") + elif path_info == "/short_body": + write(b"abcdefgh") + else: + write(b"abcdefghi") + return [] diff --git a/tests/test_adjustments.py b/tests/test_adjustments.py new file mode 100644 index 0000000..303c1aa --- /dev/null +++ b/tests/test_adjustments.py @@ -0,0 +1,481 @@ +import sys +import socket +import warnings + +from waitress.compat import ( + PY2, + WIN, +) + +if sys.version_info[:2] == (2, 6): # pragma: no cover + import unittest2 as unittest +else: # pragma: no cover + import unittest + + +class Test_asbool(unittest.TestCase): + def _callFUT(self, s): + from waitress.adjustments import asbool + + return asbool(s) + + def test_s_is_None(self): + result = self._callFUT(None) + self.assertEqual(result, False) + + def test_s_is_True(self): + result = self._callFUT(True) + self.assertEqual(result, True) + + def test_s_is_False(self): + result = self._callFUT(False) + self.assertEqual(result, False) + + def test_s_is_true(self): + result = self._callFUT("True") + self.assertEqual(result, True) + + def test_s_is_false(self): + result = self._callFUT("False") + self.assertEqual(result, False) + + def test_s_is_yes(self): + result = self._callFUT("yes") + self.assertEqual(result, True) + + def test_s_is_on(self): + result = self._callFUT("on") + self.assertEqual(result, True) + + def test_s_is_1(self): + result = self._callFUT(1) + self.assertEqual(result, True) + + +class Test_as_socket_list(unittest.TestCase): + def test_only_sockets_in_list(self): + from waitress.adjustments import as_socket_list + + sockets = [ + socket.socket(socket.AF_INET, socket.SOCK_STREAM), + socket.socket(socket.AF_INET6, socket.SOCK_STREAM), + ] + if hasattr(socket, "AF_UNIX"): + sockets.append(socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)) + new_sockets = as_socket_list(sockets) + self.assertEqual(sockets, new_sockets) + for sock in sockets: + sock.close() + + def test_not_only_sockets_in_list(self): + from waitress.adjustments import as_socket_list + + sockets = [ + socket.socket(socket.AF_INET, socket.SOCK_STREAM), + socket.socket(socket.AF_INET6, socket.SOCK_STREAM), + {"something": "else"}, + ] + new_sockets = as_socket_list(sockets) + self.assertEqual(new_sockets, [sockets[0], sockets[1]]) + for sock in [sock for sock in sockets if isinstance(sock, socket.socket)]: + sock.close() + + +class TestAdjustments(unittest.TestCase): + def _hasIPv6(self): # pragma: nocover + if not socket.has_ipv6: + return False + + try: + socket.getaddrinfo( + "::1", + 0, + socket.AF_UNSPEC, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + socket.AI_PASSIVE | socket.AI_ADDRCONFIG, + ) + + return True + except socket.gaierror as e: + # Check to see what the error is + if e.errno == socket.EAI_ADDRFAMILY: + return False + else: + raise e + + def _makeOne(self, **kw): + from waitress.adjustments import Adjustments + + return Adjustments(**kw) + + def test_goodvars(self): + inst = self._makeOne( + host="localhost", + port="8080", + threads="5", + trusted_proxy="192.168.1.1", + trusted_proxy_headers={"forwarded"}, + trusted_proxy_count=2, + log_untrusted_proxy_headers=True, + url_scheme="https", + backlog="20", + recv_bytes="200", + send_bytes="300", + outbuf_overflow="400", + inbuf_overflow="500", + connection_limit="1000", + cleanup_interval="1100", + channel_timeout="1200", + log_socket_errors="true", + max_request_header_size="1300", + max_request_body_size="1400", + expose_tracebacks="true", + ident="abc", + asyncore_loop_timeout="5", + asyncore_use_poll=True, + unix_socket_perms="777", + url_prefix="///foo/", + ipv4=True, + ipv6=False, + ) + + self.assertEqual(inst.host, "localhost") + self.assertEqual(inst.port, 8080) + self.assertEqual(inst.threads, 5) + self.assertEqual(inst.trusted_proxy, "192.168.1.1") + self.assertEqual(inst.trusted_proxy_headers, {"forwarded"}) + self.assertEqual(inst.trusted_proxy_count, 2) + self.assertEqual(inst.log_untrusted_proxy_headers, True) + self.assertEqual(inst.url_scheme, "https") + self.assertEqual(inst.backlog, 20) + self.assertEqual(inst.recv_bytes, 200) + self.assertEqual(inst.send_bytes, 300) + self.assertEqual(inst.outbuf_overflow, 400) + self.assertEqual(inst.inbuf_overflow, 500) + self.assertEqual(inst.connection_limit, 1000) + self.assertEqual(inst.cleanup_interval, 1100) + self.assertEqual(inst.channel_timeout, 1200) + self.assertEqual(inst.log_socket_errors, True) + self.assertEqual(inst.max_request_header_size, 1300) + self.assertEqual(inst.max_request_body_size, 1400) + self.assertEqual(inst.expose_tracebacks, True) + self.assertEqual(inst.asyncore_loop_timeout, 5) + self.assertEqual(inst.asyncore_use_poll, True) + self.assertEqual(inst.ident, "abc") + self.assertEqual(inst.unix_socket_perms, 0o777) + self.assertEqual(inst.url_prefix, "/foo") + self.assertEqual(inst.ipv4, True) + self.assertEqual(inst.ipv6, False) + + bind_pairs = [ + sockaddr[:2] + for (family, _, _, sockaddr) in inst.listen + if family == socket.AF_INET + ] + + # On Travis, somehow we start listening to two sockets when resolving + # localhost... + self.assertEqual(("127.0.0.1", 8080), bind_pairs[0]) + + def test_goodvar_listen(self): + inst = self._makeOne(listen="127.0.0.1") + + bind_pairs = [(host, port) for (_, _, _, (host, port)) in inst.listen] + + self.assertEqual(bind_pairs, [("127.0.0.1", 8080)]) + + def test_default_listen(self): + inst = self._makeOne() + + bind_pairs = [(host, port) for (_, _, _, (host, port)) in inst.listen] + + self.assertEqual(bind_pairs, [("0.0.0.0", 8080)]) + + def test_multiple_listen(self): + inst = self._makeOne(listen="127.0.0.1:9090 127.0.0.1:8080") + + bind_pairs = [sockaddr[:2] for (_, _, _, sockaddr) in inst.listen] + + self.assertEqual(bind_pairs, [("127.0.0.1", 9090), ("127.0.0.1", 8080)]) + + def test_wildcard_listen(self): + inst = self._makeOne(listen="*:8080") + + bind_pairs = [sockaddr[:2] for (_, _, _, sockaddr) in inst.listen] + + self.assertTrue(len(bind_pairs) >= 1) + + def test_ipv6_no_port(self): # pragma: nocover + if not self._hasIPv6(): + return + + inst = self._makeOne(listen="[::1]") + + bind_pairs = [sockaddr[:2] for (_, _, _, sockaddr) in inst.listen] + + self.assertEqual(bind_pairs, [("::1", 8080)]) + + def test_bad_port(self): + self.assertRaises(ValueError, self._makeOne, listen="127.0.0.1:test") + + def test_service_port(self): + if WIN and PY2: # pragma: no cover + # On Windows and Python 2 this is broken, so we raise a ValueError + self.assertRaises( + ValueError, self._makeOne, listen="127.0.0.1:http", + ) + return + + inst = self._makeOne(listen="127.0.0.1:http 0.0.0.0:https") + + bind_pairs = [sockaddr[:2] for (_, _, _, sockaddr) in inst.listen] + + self.assertEqual(bind_pairs, [("127.0.0.1", 80), ("0.0.0.0", 443)]) + + def test_dont_mix_host_port_listen(self): + self.assertRaises( + ValueError, + self._makeOne, + host="localhost", + port="8080", + listen="127.0.0.1:8080", + ) + + def test_good_sockets(self): + sockets = [ + socket.socket(socket.AF_INET6, socket.SOCK_STREAM), + socket.socket(socket.AF_INET, socket.SOCK_STREAM), + ] + inst = self._makeOne(sockets=sockets) + self.assertEqual(inst.sockets, sockets) + sockets[0].close() + sockets[1].close() + + def test_dont_mix_sockets_and_listen(self): + sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] + self.assertRaises( + ValueError, self._makeOne, listen="127.0.0.1:8080", sockets=sockets + ) + sockets[0].close() + + def test_dont_mix_sockets_and_host_port(self): + sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] + self.assertRaises( + ValueError, self._makeOne, host="localhost", port="8080", sockets=sockets + ) + sockets[0].close() + + def test_dont_mix_sockets_and_unix_socket(self): + sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] + self.assertRaises( + ValueError, self._makeOne, unix_socket="./tmp/test", sockets=sockets + ) + sockets[0].close() + + def test_dont_mix_unix_socket_and_host_port(self): + self.assertRaises( + ValueError, + self._makeOne, + unix_socket="./tmp/test", + host="localhost", + port="8080", + ) + + def test_dont_mix_unix_socket_and_listen(self): + self.assertRaises( + ValueError, self._makeOne, unix_socket="./tmp/test", listen="127.0.0.1:8080" + ) + + def test_dont_use_unsupported_socket_types(self): + sockets = [socket.socket(socket.AF_INET, socket.SOCK_DGRAM)] + self.assertRaises(ValueError, self._makeOne, sockets=sockets) + sockets[0].close() + + def test_dont_mix_forwarded_with_x_forwarded(self): + with self.assertRaises(ValueError) as cm: + self._makeOne( + trusted_proxy="localhost", + trusted_proxy_headers={"forwarded", "x-forwarded-for"}, + ) + + self.assertIn("The Forwarded proxy header", str(cm.exception)) + + def test_unknown_trusted_proxy_header(self): + with self.assertRaises(ValueError) as cm: + self._makeOne( + trusted_proxy="localhost", + trusted_proxy_headers={"forwarded", "x-forwarded-unknown"}, + ) + + self.assertIn( + "unknown trusted_proxy_headers value (x-forwarded-unknown)", + str(cm.exception), + ) + + def test_trusted_proxy_count_no_trusted_proxy(self): + with self.assertRaises(ValueError) as cm: + self._makeOne(trusted_proxy_count=1) + + self.assertIn("trusted_proxy_count has no meaning", str(cm.exception)) + + def test_trusted_proxy_headers_no_trusted_proxy(self): + with self.assertRaises(ValueError) as cm: + self._makeOne(trusted_proxy_headers={"forwarded"}) + + self.assertIn("trusted_proxy_headers has no meaning", str(cm.exception)) + + def test_trusted_proxy_headers_string_list(self): + inst = self._makeOne( + trusted_proxy="localhost", + trusted_proxy_headers="x-forwarded-for x-forwarded-by", + ) + self.assertEqual( + inst.trusted_proxy_headers, {"x-forwarded-for", "x-forwarded-by"} + ) + + def test_trusted_proxy_headers_string_list_newlines(self): + inst = self._makeOne( + trusted_proxy="localhost", + trusted_proxy_headers="x-forwarded-for\nx-forwarded-by\nx-forwarded-host", + ) + self.assertEqual( + inst.trusted_proxy_headers, + {"x-forwarded-for", "x-forwarded-by", "x-forwarded-host"}, + ) + + def test_no_trusted_proxy_headers_trusted_proxy(self): + with warnings.catch_warnings(record=True) as w: + warnings.resetwarnings() + warnings.simplefilter("always") + self._makeOne(trusted_proxy="localhost") + + self.assertGreaterEqual(len(w), 1) + self.assertTrue(issubclass(w[0].category, DeprecationWarning)) + self.assertIn("Implicitly trusting X-Forwarded-Proto", str(w[0])) + + def test_clear_untrusted_proxy_headers(self): + with warnings.catch_warnings(record=True) as w: + warnings.resetwarnings() + warnings.simplefilter("always") + self._makeOne( + trusted_proxy="localhost", trusted_proxy_headers={"x-forwarded-for"} + ) + + self.assertGreaterEqual(len(w), 1) + self.assertTrue(issubclass(w[0].category, DeprecationWarning)) + self.assertIn( + "clear_untrusted_proxy_headers will be set to True", str(w[0]) + ) + + def test_deprecated_send_bytes(self): + with warnings.catch_warnings(record=True) as w: + warnings.resetwarnings() + warnings.simplefilter("always") + self._makeOne(send_bytes=1) + + self.assertGreaterEqual(len(w), 1) + self.assertTrue(issubclass(w[0].category, DeprecationWarning)) + self.assertIn("send_bytes", str(w[0])) + + def test_badvar(self): + self.assertRaises(ValueError, self._makeOne, nope=True) + + def test_ipv4_disabled(self): + self.assertRaises( + ValueError, self._makeOne, ipv4=False, listen="127.0.0.1:8080" + ) + + def test_ipv6_disabled(self): + self.assertRaises(ValueError, self._makeOne, ipv6=False, listen="[::]:8080") + + def test_server_header_removable(self): + inst = self._makeOne(ident=None) + self.assertEqual(inst.ident, None) + + inst = self._makeOne(ident="") + self.assertEqual(inst.ident, None) + + inst = self._makeOne(ident="specific_header") + self.assertEqual(inst.ident, "specific_header") + + +class TestCLI(unittest.TestCase): + def parse(self, argv): + from waitress.adjustments import Adjustments + + return Adjustments.parse_args(argv) + + def test_noargs(self): + opts, args = self.parse([]) + self.assertDictEqual(opts, {"call": False, "help": False}) + self.assertSequenceEqual(args, []) + + def test_help(self): + opts, args = self.parse(["--help"]) + self.assertDictEqual(opts, {"call": False, "help": True}) + self.assertSequenceEqual(args, []) + + def test_call(self): + opts, args = self.parse(["--call"]) + self.assertDictEqual(opts, {"call": True, "help": False}) + self.assertSequenceEqual(args, []) + + def test_both(self): + opts, args = self.parse(["--call", "--help"]) + self.assertDictEqual(opts, {"call": True, "help": True}) + self.assertSequenceEqual(args, []) + + def test_positive_boolean(self): + opts, args = self.parse(["--expose-tracebacks"]) + self.assertDictContainsSubset({"expose_tracebacks": "true"}, opts) + self.assertSequenceEqual(args, []) + + def test_negative_boolean(self): + opts, args = self.parse(["--no-expose-tracebacks"]) + self.assertDictContainsSubset({"expose_tracebacks": "false"}, opts) + self.assertSequenceEqual(args, []) + + def test_cast_params(self): + opts, args = self.parse( + ["--host=localhost", "--port=80", "--unix-socket-perms=777"] + ) + self.assertDictContainsSubset( + {"host": "localhost", "port": "80", "unix_socket_perms": "777",}, opts + ) + self.assertSequenceEqual(args, []) + + def test_listen_params(self): + opts, args = self.parse(["--listen=test:80",]) + + self.assertDictContainsSubset({"listen": " test:80"}, opts) + self.assertSequenceEqual(args, []) + + def test_multiple_listen_params(self): + opts, args = self.parse(["--listen=test:80", "--listen=test:8080",]) + + self.assertDictContainsSubset({"listen": " test:80 test:8080"}, opts) + self.assertSequenceEqual(args, []) + + def test_bad_param(self): + import getopt + + self.assertRaises(getopt.GetoptError, self.parse, ["--no-host"]) + + +if hasattr(socket, "AF_UNIX"): + + class TestUnixSocket(unittest.TestCase): + def _makeOne(self, **kw): + from waitress.adjustments import Adjustments + + return Adjustments(**kw) + + def test_dont_mix_internet_and_unix_sockets(self): + sockets = [ + socket.socket(socket.AF_INET, socket.SOCK_STREAM), + socket.socket(socket.AF_UNIX, socket.SOCK_STREAM), + ] + self.assertRaises(ValueError, self._makeOne, sockets=sockets) + sockets[0].close() + sockets[1].close() diff --git a/tests/test_buffers.py b/tests/test_buffers.py new file mode 100644 index 0000000..a1330ac --- /dev/null +++ b/tests/test_buffers.py @@ -0,0 +1,523 @@ +import unittest +import io + + +class TestFileBasedBuffer(unittest.TestCase): + def _makeOne(self, file=None, from_buffer=None): + from waitress.buffers import FileBasedBuffer + + buf = FileBasedBuffer(file, from_buffer=from_buffer) + self.buffers_to_close.append(buf) + return buf + + def setUp(self): + self.buffers_to_close = [] + + def tearDown(self): + for buf in self.buffers_to_close: + buf.close() + + def test_ctor_from_buffer_None(self): + inst = self._makeOne("file") + self.assertEqual(inst.file, "file") + + def test_ctor_from_buffer(self): + from_buffer = io.BytesIO(b"data") + from_buffer.getfile = lambda *x: from_buffer + f = io.BytesIO() + inst = self._makeOne(f, from_buffer) + self.assertEqual(inst.file, f) + del from_buffer.getfile + self.assertEqual(inst.remain, 4) + from_buffer.close() + + def test___len__(self): + inst = self._makeOne() + inst.remain = 10 + self.assertEqual(len(inst), 10) + + def test___nonzero__(self): + inst = self._makeOne() + inst.remain = 10 + self.assertEqual(bool(inst), True) + inst.remain = 0 + self.assertEqual(bool(inst), True) + + def test_append(self): + f = io.BytesIO(b"data") + inst = self._makeOne(f) + inst.append(b"data2") + self.assertEqual(f.getvalue(), b"datadata2") + self.assertEqual(inst.remain, 5) + + def test_get_skip_true(self): + f = io.BytesIO(b"data") + inst = self._makeOne(f) + result = inst.get(100, skip=True) + self.assertEqual(result, b"data") + self.assertEqual(inst.remain, -4) + + def test_get_skip_false(self): + f = io.BytesIO(b"data") + inst = self._makeOne(f) + result = inst.get(100, skip=False) + self.assertEqual(result, b"data") + self.assertEqual(inst.remain, 0) + + def test_get_skip_bytes_less_than_zero(self): + f = io.BytesIO(b"data") + inst = self._makeOne(f) + result = inst.get(-1, skip=False) + self.assertEqual(result, b"data") + self.assertEqual(inst.remain, 0) + + def test_skip_remain_gt_bytes(self): + f = io.BytesIO(b"d") + inst = self._makeOne(f) + inst.remain = 1 + inst.skip(1) + self.assertEqual(inst.remain, 0) + + def test_skip_remain_lt_bytes(self): + f = io.BytesIO(b"d") + inst = self._makeOne(f) + inst.remain = 1 + self.assertRaises(ValueError, inst.skip, 2) + + def test_newfile(self): + inst = self._makeOne() + self.assertRaises(NotImplementedError, inst.newfile) + + def test_prune_remain_notzero(self): + f = io.BytesIO(b"d") + inst = self._makeOne(f) + inst.remain = 1 + nf = io.BytesIO() + inst.newfile = lambda *x: nf + inst.prune() + self.assertTrue(inst.file is not f) + self.assertEqual(nf.getvalue(), b"d") + + def test_prune_remain_zero_tell_notzero(self): + f = io.BytesIO(b"d") + inst = self._makeOne(f) + nf = io.BytesIO(b"d") + inst.newfile = lambda *x: nf + inst.remain = 0 + inst.prune() + self.assertTrue(inst.file is not f) + self.assertEqual(nf.getvalue(), b"d") + + def test_prune_remain_zero_tell_zero(self): + f = io.BytesIO() + inst = self._makeOne(f) + inst.remain = 0 + inst.prune() + self.assertTrue(inst.file is f) + + def test_close(self): + f = io.BytesIO() + inst = self._makeOne(f) + inst.close() + self.assertTrue(f.closed) + self.buffers_to_close.remove(inst) + + +class TestTempfileBasedBuffer(unittest.TestCase): + def _makeOne(self, from_buffer=None): + from waitress.buffers import TempfileBasedBuffer + + buf = TempfileBasedBuffer(from_buffer=from_buffer) + self.buffers_to_close.append(buf) + return buf + + def setUp(self): + self.buffers_to_close = [] + + def tearDown(self): + for buf in self.buffers_to_close: + buf.close() + + def test_newfile(self): + inst = self._makeOne() + r = inst.newfile() + self.assertTrue(hasattr(r, "fileno")) # file + r.close() + + +class TestBytesIOBasedBuffer(unittest.TestCase): + def _makeOne(self, from_buffer=None): + from waitress.buffers import BytesIOBasedBuffer + + return BytesIOBasedBuffer(from_buffer=from_buffer) + + def test_ctor_from_buffer_not_None(self): + f = io.BytesIO() + f.getfile = lambda *x: f + inst = self._makeOne(f) + self.assertTrue(hasattr(inst.file, "read")) + + def test_ctor_from_buffer_None(self): + inst = self._makeOne() + self.assertTrue(hasattr(inst.file, "read")) + + def test_newfile(self): + inst = self._makeOne() + r = inst.newfile() + self.assertTrue(hasattr(r, "read")) + + +class TestReadOnlyFileBasedBuffer(unittest.TestCase): + def _makeOne(self, file, block_size=8192): + from waitress.buffers import ReadOnlyFileBasedBuffer + + buf = ReadOnlyFileBasedBuffer(file, block_size) + self.buffers_to_close.append(buf) + return buf + + def setUp(self): + self.buffers_to_close = [] + + def tearDown(self): + for buf in self.buffers_to_close: + buf.close() + + def test_prepare_not_seekable(self): + f = KindaFilelike(b"abc") + inst = self._makeOne(f) + result = inst.prepare() + self.assertEqual(result, False) + self.assertEqual(inst.remain, 0) + + def test_prepare_not_seekable_closeable(self): + f = KindaFilelike(b"abc", close=1) + inst = self._makeOne(f) + result = inst.prepare() + self.assertEqual(result, False) + self.assertEqual(inst.remain, 0) + self.assertTrue(hasattr(inst, "close")) + + def test_prepare_seekable_closeable(self): + f = Filelike(b"abc", close=1, tellresults=[0, 10]) + inst = self._makeOne(f) + result = inst.prepare() + self.assertEqual(result, 10) + self.assertEqual(inst.remain, 10) + self.assertEqual(inst.file.seeked, 0) + self.assertTrue(hasattr(inst, "close")) + + def test_get_numbytes_neg_one(self): + f = io.BytesIO(b"abcdef") + inst = self._makeOne(f) + inst.remain = 2 + result = inst.get(-1) + self.assertEqual(result, b"ab") + self.assertEqual(inst.remain, 2) + self.assertEqual(f.tell(), 0) + + def test_get_numbytes_gt_remain(self): + f = io.BytesIO(b"abcdef") + inst = self._makeOne(f) + inst.remain = 2 + result = inst.get(3) + self.assertEqual(result, b"ab") + self.assertEqual(inst.remain, 2) + self.assertEqual(f.tell(), 0) + + def test_get_numbytes_lt_remain(self): + f = io.BytesIO(b"abcdef") + inst = self._makeOne(f) + inst.remain = 2 + result = inst.get(1) + self.assertEqual(result, b"a") + self.assertEqual(inst.remain, 2) + self.assertEqual(f.tell(), 0) + + def test_get_numbytes_gt_remain_withskip(self): + f = io.BytesIO(b"abcdef") + inst = self._makeOne(f) + inst.remain = 2 + result = inst.get(3, skip=True) + self.assertEqual(result, b"ab") + self.assertEqual(inst.remain, 0) + self.assertEqual(f.tell(), 2) + + def test_get_numbytes_lt_remain_withskip(self): + f = io.BytesIO(b"abcdef") + inst = self._makeOne(f) + inst.remain = 2 + result = inst.get(1, skip=True) + self.assertEqual(result, b"a") + self.assertEqual(inst.remain, 1) + self.assertEqual(f.tell(), 1) + + def test___iter__(self): + data = b"a" * 10000 + f = io.BytesIO(data) + inst = self._makeOne(f) + r = b"" + for val in inst: + r += val + self.assertEqual(r, data) + + def test_append(self): + inst = self._makeOne(None) + self.assertRaises(NotImplementedError, inst.append, "a") + + +class TestOverflowableBuffer(unittest.TestCase): + def _makeOne(self, overflow=10): + from waitress.buffers import OverflowableBuffer + + buf = OverflowableBuffer(overflow) + self.buffers_to_close.append(buf) + return buf + + def setUp(self): + self.buffers_to_close = [] + + def tearDown(self): + for buf in self.buffers_to_close: + buf.close() + + def test___len__buf_is_None(self): + inst = self._makeOne() + self.assertEqual(len(inst), 0) + + def test___len__buf_is_not_None(self): + inst = self._makeOne() + inst.buf = b"abc" + self.assertEqual(len(inst), 3) + self.buffers_to_close.remove(inst) + + def test___nonzero__(self): + inst = self._makeOne() + inst.buf = b"abc" + self.assertEqual(bool(inst), True) + inst.buf = b"" + self.assertEqual(bool(inst), False) + self.buffers_to_close.remove(inst) + + def test___nonzero___on_int_overflow_buffer(self): + inst = self._makeOne() + + class int_overflow_buf(bytes): + def __len__(self): + # maxint + 1 + return 0x7FFFFFFFFFFFFFFF + 1 + + inst.buf = int_overflow_buf() + self.assertEqual(bool(inst), True) + inst.buf = b"" + self.assertEqual(bool(inst), False) + self.buffers_to_close.remove(inst) + + def test__create_buffer_large(self): + from waitress.buffers import TempfileBasedBuffer + + inst = self._makeOne() + inst.strbuf = b"x" * 11 + inst._create_buffer() + self.assertEqual(inst.buf.__class__, TempfileBasedBuffer) + self.assertEqual(inst.buf.get(100), b"x" * 11) + self.assertEqual(inst.strbuf, b"") + + def test__create_buffer_small(self): + from waitress.buffers import BytesIOBasedBuffer + + inst = self._makeOne() + inst.strbuf = b"x" * 5 + inst._create_buffer() + self.assertEqual(inst.buf.__class__, BytesIOBasedBuffer) + self.assertEqual(inst.buf.get(100), b"x" * 5) + self.assertEqual(inst.strbuf, b"") + + def test_append_with_len_more_than_max_int(self): + from waitress.compat import MAXINT + + inst = self._makeOne() + inst.overflowed = True + buf = DummyBuffer(length=MAXINT) + inst.buf = buf + result = inst.append(b"x") + # we don't want this to throw an OverflowError on Python 2 (see + # https://github.com/Pylons/waitress/issues/47) + self.assertEqual(result, None) + self.buffers_to_close.remove(inst) + + def test_append_buf_None_not_longer_than_srtbuf_limit(self): + inst = self._makeOne() + inst.strbuf = b"x" * 5 + inst.append(b"hello") + self.assertEqual(inst.strbuf, b"xxxxxhello") + + def test_append_buf_None_longer_than_strbuf_limit(self): + inst = self._makeOne(10000) + inst.strbuf = b"x" * 8192 + inst.append(b"hello") + self.assertEqual(inst.strbuf, b"") + self.assertEqual(len(inst.buf), 8197) + + def test_append_overflow(self): + inst = self._makeOne(10) + inst.strbuf = b"x" * 8192 + inst.append(b"hello") + self.assertEqual(inst.strbuf, b"") + self.assertEqual(len(inst.buf), 8197) + + def test_append_sz_gt_overflow(self): + from waitress.buffers import BytesIOBasedBuffer + + f = io.BytesIO(b"data") + inst = self._makeOne(f) + buf = BytesIOBasedBuffer() + inst.buf = buf + inst.overflow = 2 + inst.append(b"data2") + self.assertEqual(f.getvalue(), b"data") + self.assertTrue(inst.overflowed) + self.assertNotEqual(inst.buf, buf) + + def test_get_buf_None_skip_False(self): + inst = self._makeOne() + inst.strbuf = b"x" * 5 + r = inst.get(5) + self.assertEqual(r, b"xxxxx") + + def test_get_buf_None_skip_True(self): + inst = self._makeOne() + inst.strbuf = b"x" * 5 + r = inst.get(5, skip=True) + self.assertFalse(inst.buf is None) + self.assertEqual(r, b"xxxxx") + + def test_skip_buf_None(self): + inst = self._makeOne() + inst.strbuf = b"data" + inst.skip(4) + self.assertEqual(inst.strbuf, b"") + self.assertNotEqual(inst.buf, None) + + def test_skip_buf_None_allow_prune_True(self): + inst = self._makeOne() + inst.strbuf = b"data" + inst.skip(4, True) + self.assertEqual(inst.strbuf, b"") + self.assertEqual(inst.buf, None) + + def test_prune_buf_None(self): + inst = self._makeOne() + inst.prune() + self.assertEqual(inst.strbuf, b"") + + def test_prune_with_buf(self): + inst = self._makeOne() + + class Buf(object): + def prune(self): + self.pruned = True + + inst.buf = Buf() + inst.prune() + self.assertEqual(inst.buf.pruned, True) + self.buffers_to_close.remove(inst) + + def test_prune_with_buf_overflow(self): + inst = self._makeOne() + + class DummyBuffer(io.BytesIO): + def getfile(self): + return self + + def prune(self): + return True + + def __len__(self): + return 5 + + def close(self): + pass + + buf = DummyBuffer(b"data") + inst.buf = buf + inst.overflowed = True + inst.overflow = 10 + inst.prune() + self.assertNotEqual(inst.buf, buf) + + def test_prune_with_buflen_more_than_max_int(self): + from waitress.compat import MAXINT + + inst = self._makeOne() + inst.overflowed = True + buf = DummyBuffer(length=MAXINT + 1) + inst.buf = buf + result = inst.prune() + # we don't want this to throw an OverflowError on Python 2 (see + # https://github.com/Pylons/waitress/issues/47) + self.assertEqual(result, None) + + def test_getfile_buf_None(self): + inst = self._makeOne() + f = inst.getfile() + self.assertTrue(hasattr(f, "read")) + + def test_getfile_buf_not_None(self): + inst = self._makeOne() + buf = io.BytesIO() + buf.getfile = lambda *x: buf + inst.buf = buf + f = inst.getfile() + self.assertEqual(f, buf) + + def test_close_nobuf(self): + inst = self._makeOne() + inst.buf = None + self.assertEqual(inst.close(), None) # doesnt raise + self.buffers_to_close.remove(inst) + + def test_close_withbuf(self): + class Buffer(object): + def close(self): + self.closed = True + + buf = Buffer() + inst = self._makeOne() + inst.buf = buf + inst.close() + self.assertTrue(buf.closed) + self.buffers_to_close.remove(inst) + + +class KindaFilelike(object): + def __init__(self, bytes, close=None, tellresults=None): + self.bytes = bytes + self.tellresults = tellresults + if close is not None: + self.close = lambda: close + + +class Filelike(KindaFilelike): + def seek(self, v, whence=0): + self.seeked = v + + def tell(self): + v = self.tellresults.pop(0) + return v + + +class DummyBuffer(object): + def __init__(self, length=0): + self.length = length + + def __len__(self): + return self.length + + def append(self, s): + self.length = self.length + len(s) + + def prune(self): + pass + + def close(self): + pass diff --git a/tests/test_channel.py b/tests/test_channel.py new file mode 100644 index 0000000..14ef5a0 --- /dev/null +++ b/tests/test_channel.py @@ -0,0 +1,882 @@ +import unittest +import io + + +class TestHTTPChannel(unittest.TestCase): + def _makeOne(self, sock, addr, adj, map=None): + from waitress.channel import HTTPChannel + + server = DummyServer() + return HTTPChannel(server, sock, addr, adj=adj, map=map) + + def _makeOneWithMap(self, adj=None): + if adj is None: + adj = DummyAdjustments() + sock = DummySock() + map = {} + inst = self._makeOne(sock, "127.0.0.1", adj, map=map) + inst.outbuf_lock = DummyLock() + return inst, sock, map + + def test_ctor(self): + inst, _, map = self._makeOneWithMap() + self.assertEqual(inst.addr, "127.0.0.1") + self.assertEqual(inst.sendbuf_len, 2048) + self.assertEqual(map[100], inst) + + def test_total_outbufs_len_an_outbuf_size_gt_sys_maxint(self): + from waitress.compat import MAXINT + + inst, _, map = self._makeOneWithMap() + + class DummyBuffer(object): + chunks = [] + + def append(self, data): + self.chunks.append(data) + + class DummyData(object): + def __len__(self): + return MAXINT + + inst.total_outbufs_len = 1 + inst.outbufs = [DummyBuffer()] + inst.write_soon(DummyData()) + # we are testing that this method does not raise an OverflowError + # (see https://github.com/Pylons/waitress/issues/47) + self.assertEqual(inst.total_outbufs_len, MAXINT + 1) + + def test_writable_something_in_outbuf(self): + inst, sock, map = self._makeOneWithMap() + inst.total_outbufs_len = 3 + self.assertTrue(inst.writable()) + + def test_writable_nothing_in_outbuf(self): + inst, sock, map = self._makeOneWithMap() + self.assertFalse(inst.writable()) + + def test_writable_nothing_in_outbuf_will_close(self): + inst, sock, map = self._makeOneWithMap() + inst.will_close = True + self.assertTrue(inst.writable()) + + def test_handle_write_not_connected(self): + inst, sock, map = self._makeOneWithMap() + inst.connected = False + self.assertFalse(inst.handle_write()) + + def test_handle_write_with_requests(self): + inst, sock, map = self._makeOneWithMap() + inst.requests = True + inst.last_activity = 0 + result = inst.handle_write() + self.assertEqual(result, None) + self.assertEqual(inst.last_activity, 0) + + def test_handle_write_no_request_with_outbuf(self): + inst, sock, map = self._makeOneWithMap() + inst.requests = [] + inst.outbufs = [DummyBuffer(b"abc")] + inst.total_outbufs_len = len(inst.outbufs[0]) + inst.last_activity = 0 + result = inst.handle_write() + self.assertEqual(result, None) + self.assertNotEqual(inst.last_activity, 0) + self.assertEqual(sock.sent, b"abc") + + def test_handle_write_outbuf_raises_socketerror(self): + import socket + + inst, sock, map = self._makeOneWithMap() + inst.requests = [] + outbuf = DummyBuffer(b"abc", socket.error) + inst.outbufs = [outbuf] + inst.total_outbufs_len = len(outbuf) + inst.last_activity = 0 + inst.logger = DummyLogger() + result = inst.handle_write() + self.assertEqual(result, None) + self.assertEqual(inst.last_activity, 0) + self.assertEqual(sock.sent, b"") + self.assertEqual(len(inst.logger.exceptions), 1) + self.assertTrue(outbuf.closed) + + def test_handle_write_outbuf_raises_othererror(self): + inst, sock, map = self._makeOneWithMap() + inst.requests = [] + outbuf = DummyBuffer(b"abc", IOError) + inst.outbufs = [outbuf] + inst.total_outbufs_len = len(outbuf) + inst.last_activity = 0 + inst.logger = DummyLogger() + result = inst.handle_write() + self.assertEqual(result, None) + self.assertEqual(inst.last_activity, 0) + self.assertEqual(sock.sent, b"") + self.assertEqual(len(inst.logger.exceptions), 1) + self.assertTrue(outbuf.closed) + + def test_handle_write_no_requests_no_outbuf_will_close(self): + inst, sock, map = self._makeOneWithMap() + inst.requests = [] + outbuf = DummyBuffer(b"") + inst.outbufs = [outbuf] + inst.will_close = True + inst.last_activity = 0 + result = inst.handle_write() + self.assertEqual(result, None) + self.assertEqual(inst.connected, False) + self.assertEqual(sock.closed, True) + self.assertEqual(inst.last_activity, 0) + self.assertTrue(outbuf.closed) + + def test_handle_write_no_requests_outbuf_gt_send_bytes(self): + inst, sock, map = self._makeOneWithMap() + inst.requests = [True] + inst.outbufs = [DummyBuffer(b"abc")] + inst.total_outbufs_len = len(inst.outbufs[0]) + inst.adj.send_bytes = 2 + inst.will_close = False + inst.last_activity = 0 + result = inst.handle_write() + self.assertEqual(result, None) + self.assertEqual(inst.will_close, False) + self.assertTrue(inst.outbuf_lock.acquired) + self.assertEqual(sock.sent, b"abc") + + def test_handle_write_close_when_flushed(self): + inst, sock, map = self._makeOneWithMap() + outbuf = DummyBuffer(b"abc") + inst.outbufs = [outbuf] + inst.total_outbufs_len = len(outbuf) + inst.will_close = False + inst.close_when_flushed = True + inst.last_activity = 0 + result = inst.handle_write() + self.assertEqual(result, None) + self.assertEqual(inst.will_close, True) + self.assertEqual(inst.close_when_flushed, False) + self.assertEqual(sock.sent, b"abc") + self.assertTrue(outbuf.closed) + + def test_readable_no_requests_not_will_close(self): + inst, sock, map = self._makeOneWithMap() + inst.requests = [] + inst.will_close = False + self.assertEqual(inst.readable(), True) + + def test_readable_no_requests_will_close(self): + inst, sock, map = self._makeOneWithMap() + inst.requests = [] + inst.will_close = True + self.assertEqual(inst.readable(), False) + + def test_readable_with_requests(self): + inst, sock, map = self._makeOneWithMap() + inst.requests = True + self.assertEqual(inst.readable(), False) + + def test_handle_read_no_error(self): + inst, sock, map = self._makeOneWithMap() + inst.will_close = False + inst.recv = lambda *arg: b"abc" + inst.last_activity = 0 + L = [] + inst.received = lambda x: L.append(x) + result = inst.handle_read() + self.assertEqual(result, None) + self.assertNotEqual(inst.last_activity, 0) + self.assertEqual(L, [b"abc"]) + + def test_handle_read_error(self): + import socket + + inst, sock, map = self._makeOneWithMap() + inst.will_close = False + + def recv(b): + raise socket.error + + inst.recv = recv + inst.last_activity = 0 + inst.logger = DummyLogger() + result = inst.handle_read() + self.assertEqual(result, None) + self.assertEqual(inst.last_activity, 0) + self.assertEqual(len(inst.logger.exceptions), 1) + + def test_write_soon_empty_byte(self): + inst, sock, map = self._makeOneWithMap() + wrote = inst.write_soon(b"") + self.assertEqual(wrote, 0) + self.assertEqual(len(inst.outbufs[0]), 0) + + def test_write_soon_nonempty_byte(self): + inst, sock, map = self._makeOneWithMap() + wrote = inst.write_soon(b"a") + self.assertEqual(wrote, 1) + self.assertEqual(len(inst.outbufs[0]), 1) + + def test_write_soon_filewrapper(self): + from waitress.buffers import ReadOnlyFileBasedBuffer + + f = io.BytesIO(b"abc") + wrapper = ReadOnlyFileBasedBuffer(f, 8192) + wrapper.prepare() + inst, sock, map = self._makeOneWithMap() + outbufs = inst.outbufs + orig_outbuf = outbufs[0] + wrote = inst.write_soon(wrapper) + self.assertEqual(wrote, 3) + self.assertEqual(len(outbufs), 3) + self.assertEqual(outbufs[0], orig_outbuf) + self.assertEqual(outbufs[1], wrapper) + self.assertEqual(outbufs[2].__class__.__name__, "OverflowableBuffer") + + def test_write_soon_disconnected(self): + from waitress.channel import ClientDisconnected + + inst, sock, map = self._makeOneWithMap() + inst.connected = False + self.assertRaises(ClientDisconnected, lambda: inst.write_soon(b"stuff")) + + def test_write_soon_disconnected_while_over_watermark(self): + from waitress.channel import ClientDisconnected + + inst, sock, map = self._makeOneWithMap() + + def dummy_flush(): + inst.connected = False + + inst._flush_outbufs_below_high_watermark = dummy_flush + self.assertRaises(ClientDisconnected, lambda: inst.write_soon(b"stuff")) + + def test_write_soon_rotates_outbuf_on_overflow(self): + inst, sock, map = self._makeOneWithMap() + inst.adj.outbuf_high_watermark = 3 + inst.current_outbuf_count = 4 + wrote = inst.write_soon(b"xyz") + self.assertEqual(wrote, 3) + self.assertEqual(len(inst.outbufs), 2) + self.assertEqual(inst.outbufs[0].get(), b"") + self.assertEqual(inst.outbufs[1].get(), b"xyz") + + def test_write_soon_waits_on_backpressure(self): + inst, sock, map = self._makeOneWithMap() + inst.adj.outbuf_high_watermark = 3 + inst.total_outbufs_len = 4 + inst.current_outbuf_count = 4 + + class Lock(DummyLock): + def wait(self): + inst.total_outbufs_len = 0 + super(Lock, self).wait() + + inst.outbuf_lock = Lock() + wrote = inst.write_soon(b"xyz") + self.assertEqual(wrote, 3) + self.assertEqual(len(inst.outbufs), 2) + self.assertEqual(inst.outbufs[0].get(), b"") + self.assertEqual(inst.outbufs[1].get(), b"xyz") + self.assertTrue(inst.outbuf_lock.waited) + + def test_handle_write_notify_after_flush(self): + inst, sock, map = self._makeOneWithMap() + inst.requests = [True] + inst.outbufs = [DummyBuffer(b"abc")] + inst.total_outbufs_len = len(inst.outbufs[0]) + inst.adj.send_bytes = 1 + inst.adj.outbuf_high_watermark = 5 + inst.will_close = False + inst.last_activity = 0 + result = inst.handle_write() + self.assertEqual(result, None) + self.assertEqual(inst.will_close, False) + self.assertTrue(inst.outbuf_lock.acquired) + self.assertTrue(inst.outbuf_lock.notified) + self.assertEqual(sock.sent, b"abc") + + def test_handle_write_no_notify_after_flush(self): + inst, sock, map = self._makeOneWithMap() + inst.requests = [True] + inst.outbufs = [DummyBuffer(b"abc")] + inst.total_outbufs_len = len(inst.outbufs[0]) + inst.adj.send_bytes = 1 + inst.adj.outbuf_high_watermark = 2 + sock.send = lambda x: False + inst.will_close = False + inst.last_activity = 0 + result = inst.handle_write() + self.assertEqual(result, None) + self.assertEqual(inst.will_close, False) + self.assertTrue(inst.outbuf_lock.acquired) + self.assertFalse(inst.outbuf_lock.notified) + self.assertEqual(sock.sent, b"") + + def test__flush_some_empty_outbuf(self): + inst, sock, map = self._makeOneWithMap() + result = inst._flush_some() + self.assertEqual(result, False) + + def test__flush_some_full_outbuf_socket_returns_nonzero(self): + inst, sock, map = self._makeOneWithMap() + inst.outbufs[0].append(b"abc") + inst.total_outbufs_len = sum(len(x) for x in inst.outbufs) + result = inst._flush_some() + self.assertEqual(result, True) + + def test__flush_some_full_outbuf_socket_returns_zero(self): + inst, sock, map = self._makeOneWithMap() + sock.send = lambda x: False + inst.outbufs[0].append(b"abc") + inst.total_outbufs_len = sum(len(x) for x in inst.outbufs) + result = inst._flush_some() + self.assertEqual(result, False) + + def test_flush_some_multiple_buffers_first_empty(self): + inst, sock, map = self._makeOneWithMap() + sock.send = lambda x: len(x) + buffer = DummyBuffer(b"abc") + inst.outbufs.append(buffer) + inst.total_outbufs_len = sum(len(x) for x in inst.outbufs) + result = inst._flush_some() + self.assertEqual(result, True) + self.assertEqual(buffer.skipped, 3) + self.assertEqual(inst.outbufs, [buffer]) + + def test_flush_some_multiple_buffers_close_raises(self): + inst, sock, map = self._makeOneWithMap() + sock.send = lambda x: len(x) + buffer = DummyBuffer(b"abc") + inst.outbufs.append(buffer) + inst.total_outbufs_len = sum(len(x) for x in inst.outbufs) + inst.logger = DummyLogger() + + def doraise(): + raise NotImplementedError + + inst.outbufs[0].close = doraise + result = inst._flush_some() + self.assertEqual(result, True) + self.assertEqual(buffer.skipped, 3) + self.assertEqual(inst.outbufs, [buffer]) + self.assertEqual(len(inst.logger.exceptions), 1) + + def test__flush_some_outbuf_len_gt_sys_maxint(self): + from waitress.compat import MAXINT + + inst, sock, map = self._makeOneWithMap() + + class DummyHugeOutbuffer(object): + def __init__(self): + self.length = MAXINT + 1 + + def __len__(self): + return self.length + + def get(self, numbytes): + self.length = 0 + return b"123" + + buf = DummyHugeOutbuffer() + inst.outbufs = [buf] + inst.send = lambda *arg: 0 + result = inst._flush_some() + # we are testing that _flush_some doesn't raise an OverflowError + # when one of its outbufs has a __len__ that returns gt sys.maxint + self.assertEqual(result, False) + + def test_handle_close(self): + inst, sock, map = self._makeOneWithMap() + inst.handle_close() + self.assertEqual(inst.connected, False) + self.assertEqual(sock.closed, True) + + def test_handle_close_outbuf_raises_on_close(self): + inst, sock, map = self._makeOneWithMap() + + def doraise(): + raise NotImplementedError + + inst.outbufs[0].close = doraise + inst.logger = DummyLogger() + inst.handle_close() + self.assertEqual(inst.connected, False) + self.assertEqual(sock.closed, True) + self.assertEqual(len(inst.logger.exceptions), 1) + + def test_add_channel(self): + inst, sock, map = self._makeOneWithMap() + fileno = inst._fileno + inst.add_channel(map) + self.assertEqual(map[fileno], inst) + self.assertEqual(inst.server.active_channels[fileno], inst) + + def test_del_channel(self): + inst, sock, map = self._makeOneWithMap() + fileno = inst._fileno + inst.server.active_channels[fileno] = True + inst.del_channel(map) + self.assertEqual(map.get(fileno), None) + self.assertEqual(inst.server.active_channels.get(fileno), None) + + def test_received(self): + inst, sock, map = self._makeOneWithMap() + inst.server = DummyServer() + inst.received(b"GET / HTTP/1.1\r\n\r\n") + self.assertEqual(inst.server.tasks, [inst]) + self.assertTrue(inst.requests) + + def test_received_no_chunk(self): + inst, sock, map = self._makeOneWithMap() + self.assertEqual(inst.received(b""), False) + + def test_received_preq_not_completed(self): + inst, sock, map = self._makeOneWithMap() + inst.server = DummyServer() + preq = DummyParser() + inst.request = preq + preq.completed = False + preq.empty = True + inst.received(b"GET / HTTP/1.1\r\n\r\n") + self.assertEqual(inst.requests, ()) + self.assertEqual(inst.server.tasks, []) + + def test_received_preq_completed_empty(self): + inst, sock, map = self._makeOneWithMap() + inst.server = DummyServer() + preq = DummyParser() + inst.request = preq + preq.completed = True + preq.empty = True + inst.received(b"GET / HTTP/1.1\r\n\r\n") + self.assertEqual(inst.request, None) + self.assertEqual(inst.server.tasks, []) + + def test_received_preq_error(self): + inst, sock, map = self._makeOneWithMap() + inst.server = DummyServer() + preq = DummyParser() + inst.request = preq + preq.completed = True + preq.error = True + inst.received(b"GET / HTTP/1.1\r\n\r\n") + self.assertEqual(inst.request, None) + self.assertEqual(len(inst.server.tasks), 1) + self.assertTrue(inst.requests) + + def test_received_preq_completed_connection_close(self): + inst, sock, map = self._makeOneWithMap() + inst.server = DummyServer() + preq = DummyParser() + inst.request = preq + preq.completed = True + preq.empty = True + preq.connection_close = True + inst.received(b"GET / HTTP/1.1\r\n\r\n" + b"a" * 50000) + self.assertEqual(inst.request, None) + self.assertEqual(inst.server.tasks, []) + + def test_received_headers_finished_expect_continue_false(self): + inst, sock, map = self._makeOneWithMap() + inst.server = DummyServer() + preq = DummyParser() + inst.request = preq + preq.expect_continue = False + preq.headers_finished = True + preq.completed = False + preq.empty = False + preq.retval = 1 + inst.received(b"GET / HTTP/1.1\r\n\r\n") + self.assertEqual(inst.request, preq) + self.assertEqual(inst.server.tasks, []) + self.assertEqual(inst.outbufs[0].get(100), b"") + + def test_received_headers_finished_expect_continue_true(self): + inst, sock, map = self._makeOneWithMap() + inst.server = DummyServer() + preq = DummyParser() + inst.request = preq + preq.expect_continue = True + preq.headers_finished = True + preq.completed = False + preq.empty = False + inst.received(b"GET / HTTP/1.1\r\n\r\n") + self.assertEqual(inst.request, preq) + self.assertEqual(inst.server.tasks, []) + self.assertEqual(sock.sent, b"HTTP/1.1 100 Continue\r\n\r\n") + self.assertEqual(inst.sent_continue, True) + self.assertEqual(preq.completed, False) + + def test_received_headers_finished_expect_continue_true_sent_true(self): + inst, sock, map = self._makeOneWithMap() + inst.server = DummyServer() + preq = DummyParser() + inst.request = preq + preq.expect_continue = True + preq.headers_finished = True + preq.completed = False + preq.empty = False + inst.sent_continue = True + inst.received(b"GET / HTTP/1.1\r\n\r\n") + self.assertEqual(inst.request, preq) + self.assertEqual(inst.server.tasks, []) + self.assertEqual(sock.sent, b"") + self.assertEqual(inst.sent_continue, True) + self.assertEqual(preq.completed, False) + + def test_service_no_requests(self): + inst, sock, map = self._makeOneWithMap() + inst.requests = [] + inst.service() + self.assertEqual(inst.requests, []) + self.assertTrue(inst.server.trigger_pulled) + self.assertTrue(inst.last_activity) + + def test_service_with_one_request(self): + inst, sock, map = self._makeOneWithMap() + request = DummyRequest() + inst.task_class = DummyTaskClass() + inst.requests = [request] + inst.service() + self.assertEqual(inst.requests, []) + self.assertTrue(request.serviced) + self.assertTrue(request.closed) + + def test_service_with_one_error_request(self): + inst, sock, map = self._makeOneWithMap() + request = DummyRequest() + request.error = DummyError() + inst.error_task_class = DummyTaskClass() + inst.requests = [request] + inst.service() + self.assertEqual(inst.requests, []) + self.assertTrue(request.serviced) + self.assertTrue(request.closed) + + def test_service_with_multiple_requests(self): + inst, sock, map = self._makeOneWithMap() + request1 = DummyRequest() + request2 = DummyRequest() + inst.task_class = DummyTaskClass() + inst.requests = [request1, request2] + inst.service() + self.assertEqual(inst.requests, []) + self.assertTrue(request1.serviced) + self.assertTrue(request2.serviced) + self.assertTrue(request1.closed) + self.assertTrue(request2.closed) + + def test_service_with_request_raises(self): + inst, sock, map = self._makeOneWithMap() + inst.adj.expose_tracebacks = False + inst.server = DummyServer() + request = DummyRequest() + inst.requests = [request] + inst.task_class = DummyTaskClass(ValueError) + inst.task_class.wrote_header = False + inst.error_task_class = DummyTaskClass() + inst.logger = DummyLogger() + inst.service() + self.assertTrue(request.serviced) + self.assertEqual(inst.requests, []) + self.assertEqual(len(inst.logger.exceptions), 1) + self.assertTrue(inst.server.trigger_pulled) + self.assertTrue(inst.last_activity) + self.assertFalse(inst.will_close) + self.assertEqual(inst.error_task_class.serviced, True) + self.assertTrue(request.closed) + + def test_service_with_requests_raises_already_wrote_header(self): + inst, sock, map = self._makeOneWithMap() + inst.adj.expose_tracebacks = False + inst.server = DummyServer() + request = DummyRequest() + inst.requests = [request] + inst.task_class = DummyTaskClass(ValueError) + inst.error_task_class = DummyTaskClass() + inst.logger = DummyLogger() + inst.service() + self.assertTrue(request.serviced) + self.assertEqual(inst.requests, []) + self.assertEqual(len(inst.logger.exceptions), 1) + self.assertTrue(inst.server.trigger_pulled) + self.assertTrue(inst.last_activity) + self.assertTrue(inst.close_when_flushed) + self.assertEqual(inst.error_task_class.serviced, False) + self.assertTrue(request.closed) + + def test_service_with_requests_raises_didnt_write_header_expose_tbs(self): + inst, sock, map = self._makeOneWithMap() + inst.adj.expose_tracebacks = True + inst.server = DummyServer() + request = DummyRequest() + inst.requests = [request] + inst.task_class = DummyTaskClass(ValueError) + inst.task_class.wrote_header = False + inst.error_task_class = DummyTaskClass() + inst.logger = DummyLogger() + inst.service() + self.assertTrue(request.serviced) + self.assertFalse(inst.will_close) + self.assertEqual(inst.requests, []) + self.assertEqual(len(inst.logger.exceptions), 1) + self.assertTrue(inst.server.trigger_pulled) + self.assertTrue(inst.last_activity) + self.assertEqual(inst.error_task_class.serviced, True) + self.assertTrue(request.closed) + + def test_service_with_requests_raises_didnt_write_header(self): + inst, sock, map = self._makeOneWithMap() + inst.adj.expose_tracebacks = False + inst.server = DummyServer() + request = DummyRequest() + inst.requests = [request] + inst.task_class = DummyTaskClass(ValueError) + inst.task_class.wrote_header = False + inst.logger = DummyLogger() + inst.service() + self.assertTrue(request.serviced) + self.assertEqual(inst.requests, []) + self.assertEqual(len(inst.logger.exceptions), 1) + self.assertTrue(inst.server.trigger_pulled) + self.assertTrue(inst.last_activity) + self.assertTrue(inst.close_when_flushed) + self.assertTrue(request.closed) + + def test_service_with_request_raises_disconnect(self): + from waitress.channel import ClientDisconnected + + inst, sock, map = self._makeOneWithMap() + inst.adj.expose_tracebacks = False + inst.server = DummyServer() + request = DummyRequest() + inst.requests = [request] + inst.task_class = DummyTaskClass(ClientDisconnected) + inst.error_task_class = DummyTaskClass() + inst.logger = DummyLogger() + inst.service() + self.assertTrue(request.serviced) + self.assertEqual(inst.requests, []) + self.assertEqual(len(inst.logger.infos), 1) + self.assertTrue(inst.server.trigger_pulled) + self.assertTrue(inst.last_activity) + self.assertFalse(inst.will_close) + self.assertEqual(inst.error_task_class.serviced, False) + self.assertTrue(request.closed) + + def test_service_with_request_error_raises_disconnect(self): + from waitress.channel import ClientDisconnected + + inst, sock, map = self._makeOneWithMap() + inst.adj.expose_tracebacks = False + inst.server = DummyServer() + request = DummyRequest() + err_request = DummyRequest() + inst.requests = [request] + inst.parser_class = lambda x: err_request + inst.task_class = DummyTaskClass(RuntimeError) + inst.task_class.wrote_header = False + inst.error_task_class = DummyTaskClass(ClientDisconnected) + inst.logger = DummyLogger() + inst.service() + self.assertTrue(request.serviced) + self.assertTrue(err_request.serviced) + self.assertEqual(inst.requests, []) + self.assertEqual(len(inst.logger.exceptions), 1) + self.assertEqual(len(inst.logger.infos), 0) + self.assertTrue(inst.server.trigger_pulled) + self.assertTrue(inst.last_activity) + self.assertFalse(inst.will_close) + self.assertEqual(inst.task_class.serviced, True) + self.assertEqual(inst.error_task_class.serviced, True) + self.assertTrue(request.closed) + + def test_cancel_no_requests(self): + inst, sock, map = self._makeOneWithMap() + inst.requests = () + inst.cancel() + self.assertEqual(inst.requests, []) + + def test_cancel_with_requests(self): + inst, sock, map = self._makeOneWithMap() + inst.requests = [None] + inst.cancel() + self.assertEqual(inst.requests, []) + + +class DummySock(object): + blocking = False + closed = False + + def __init__(self): + self.sent = b"" + + def setblocking(self, *arg): + self.blocking = True + + def fileno(self): + return 100 + + def getpeername(self): + return "127.0.0.1" + + def getsockopt(self, level, option): + return 2048 + + def close(self): + self.closed = True + + def send(self, data): + self.sent += data + return len(data) + + +class DummyLock(object): + notified = False + + def __init__(self, acquirable=True): + self.acquirable = acquirable + + def acquire(self, val): + self.val = val + self.acquired = True + return self.acquirable + + def release(self): + self.released = True + + def notify(self): + self.notified = True + + def wait(self): + self.waited = True + + def __exit__(self, type, val, traceback): + self.acquire(True) + + def __enter__(self): + pass + + +class DummyBuffer(object): + closed = False + + def __init__(self, data, toraise=None): + self.data = data + self.toraise = toraise + + def get(self, *arg): + if self.toraise: + raise self.toraise + data = self.data + self.data = b"" + return data + + def skip(self, num, x): + self.skipped = num + + def __len__(self): + return len(self.data) + + def close(self): + self.closed = True + + +class DummyAdjustments(object): + outbuf_overflow = 1048576 + outbuf_high_watermark = 1048576 + inbuf_overflow = 512000 + cleanup_interval = 900 + url_scheme = "http" + channel_timeout = 300 + log_socket_errors = True + recv_bytes = 8192 + send_bytes = 1 + expose_tracebacks = True + ident = "waitress" + max_request_header_size = 10000 + + +class DummyServer(object): + trigger_pulled = False + adj = DummyAdjustments() + + def __init__(self): + self.tasks = [] + self.active_channels = {} + + def add_task(self, task): + self.tasks.append(task) + + def pull_trigger(self): + self.trigger_pulled = True + + +class DummyParser(object): + version = 1 + data = None + completed = True + empty = False + headers_finished = False + expect_continue = False + retval = None + error = None + connection_close = False + + def received(self, data): + self.data = data + if self.retval is not None: + return self.retval + return len(data) + + +class DummyRequest(object): + error = None + path = "/" + version = "1.0" + closed = False + + def __init__(self): + self.headers = {} + + def close(self): + self.closed = True + + +class DummyLogger(object): + def __init__(self): + self.exceptions = [] + self.infos = [] + self.warnings = [] + + def info(self, msg): + self.infos.append(msg) + + def exception(self, msg): + self.exceptions.append(msg) + + +class DummyError(object): + code = "431" + reason = "Bleh" + body = "My body" + + +class DummyTaskClass(object): + wrote_header = True + close_on_finish = False + serviced = False + + def __init__(self, toraise=None): + self.toraise = toraise + + def __call__(self, channel, request): + self.request = request + return self + + def service(self): + self.serviced = True + self.request.serviced = True + if self.toraise: + raise self.toraise diff --git a/tests/test_compat.py b/tests/test_compat.py new file mode 100644 index 0000000..37c2193 --- /dev/null +++ b/tests/test_compat.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- + +import unittest + + +class Test_unquote_bytes_to_wsgi(unittest.TestCase): + def _callFUT(self, v): + from waitress.compat import unquote_bytes_to_wsgi + + return unquote_bytes_to_wsgi(v) + + def test_highorder(self): + from waitress.compat import PY3 + + val = b"/a%C5%9B" + result = self._callFUT(val) + if PY3: # pragma: no cover + # PEP 3333 urlunquoted-latin1-decoded-bytes + self.assertEqual(result, "/aÅ\x9b") + else: # pragma: no cover + # sanity + self.assertEqual(result, b"/a\xc5\x9b") diff --git a/tests/test_functional.py b/tests/test_functional.py new file mode 100644 index 0000000..4b60676 --- /dev/null +++ b/tests/test_functional.py @@ -0,0 +1,1667 @@ +import errno +import logging +import multiprocessing +import os +import signal +import socket +import string +import subprocess +import sys +import time +import unittest +from waitress import server +from waitress.compat import httplib, tobytes +from waitress.utilities import cleanup_unix_socket + +dn = os.path.dirname +here = dn(__file__) + + +class NullHandler(logging.Handler): # pragma: no cover + """A logging handler that swallows all emitted messages. + """ + + def emit(self, record): + pass + + +def start_server(app, svr, queue, **kwargs): # pragma: no cover + """Run a fixture application. + """ + logging.getLogger("waitress").addHandler(NullHandler()) + try_register_coverage() + svr(app, queue, **kwargs).run() + + +def try_register_coverage(): # pragma: no cover + # Hack around multiprocessing exiting early and not triggering coverage's + # atexit handler by always registering a signal handler + + if "COVERAGE_PROCESS_START" in os.environ: + def sigterm(*args): + sys.exit(0) + + signal.signal(signal.SIGTERM, sigterm) + + +class FixtureTcpWSGIServer(server.TcpWSGIServer): + """A version of TcpWSGIServer that relays back what it's bound to. + """ + + family = socket.AF_INET # Testing + + def __init__(self, application, queue, **kw): # pragma: no cover + # Coverage doesn't see this as it's ran in a separate process. + kw["port"] = 0 # Bind to any available port. + super(FixtureTcpWSGIServer, self).__init__(application, **kw) + host, port = self.socket.getsockname() + if os.name == "nt": + host = "127.0.0.1" + queue.put((host, port)) + + +class SubprocessTests(object): + + # For nose: all tests may be ran in separate processes. + _multiprocess_can_split_ = True + + exe = sys.executable + + server = None + + def start_subprocess(self, target, **kw): + # Spawn a server process. + self.queue = multiprocessing.Queue() + + if "COVERAGE_RCFILE" in os.environ: + os.environ["COVERAGE_PROCESS_START"] = os.environ["COVERAGE_RCFILE"] + + self.proc = multiprocessing.Process( + target=start_server, args=(target, self.server, self.queue), kwargs=kw, + ) + self.proc.start() + + if self.proc.exitcode is not None: # pragma: no cover + raise RuntimeError("%s didn't start" % str(target)) + # Get the socket the server is listening on. + self.bound_to = self.queue.get(timeout=5) + self.sock = self.create_socket() + + def stop_subprocess(self): + if self.proc.exitcode is None: + self.proc.terminate() + self.sock.close() + # This give us one FD back ... + self.queue.close() + self.proc.join() + + def assertline(self, line, status, reason, version): + v, s, r = (x.strip() for x in line.split(None, 2)) + self.assertEqual(s, tobytes(status)) + self.assertEqual(r, tobytes(reason)) + self.assertEqual(v, tobytes(version)) + + def create_socket(self): + return socket.socket(self.server.family, socket.SOCK_STREAM) + + def connect(self): + self.sock.connect(self.bound_to) + + def make_http_connection(self): + raise NotImplementedError # pragma: no cover + + def send_check_error(self, to_send): + self.sock.send(to_send) + + +class TcpTests(SubprocessTests): + + server = FixtureTcpWSGIServer + + def make_http_connection(self): + return httplib.HTTPConnection(*self.bound_to) + + +class SleepyThreadTests(TcpTests, unittest.TestCase): + # test that sleepy thread doesnt block other requests + + def setUp(self): + from tests.fixtureapps import sleepy + + self.start_subprocess(sleepy.app) + + def tearDown(self): + self.stop_subprocess() + + def test_it(self): + getline = os.path.join(here, "fixtureapps", "getline.py") + cmds = ( + [self.exe, getline, "http://%s:%d/sleepy" % self.bound_to], + [self.exe, getline, "http://%s:%d/" % self.bound_to], + ) + r, w = os.pipe() + procs = [] + for cmd in cmds: + procs.append(subprocess.Popen(cmd, stdout=w)) + time.sleep(3) + for proc in procs: + if proc.returncode is not None: # pragma: no cover + proc.terminate() + proc.wait() + # the notsleepy response should always be first returned (it sleeps + # for 2 seconds, then returns; the notsleepy response should be + # processed in the meantime) + result = os.read(r, 10000) + os.close(r) + os.close(w) + self.assertEqual(result, b"notsleepy returnedsleepy returned") + + +class EchoTests(object): + def setUp(self): + from tests.fixtureapps import echo + + self.start_subprocess( + echo.app, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"x-forwarded-for", "x-forwarded-proto"}, + clear_untrusted_proxy_headers=True, + ) + + def tearDown(self): + self.stop_subprocess() + + def _read_echo(self, fp): + from tests.fixtureapps import echo + + line, headers, body = read_http(fp) + return line, headers, echo.parse_response(body) + + def test_date_and_server(self): + to_send = "GET / HTTP/1.0\r\nContent-Length: 0\r\n\r\n" + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, echo = self._read_echo(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + self.assertEqual(headers.get("server"), "waitress") + self.assertTrue(headers.get("date")) + + def test_bad_host_header(self): + # https://corte.si/posts/code/pathod/pythonservers/index.html + to_send = "GET / HTTP/1.0\r\n Host: 0\r\n\r\n" + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "400", "Bad Request", "HTTP/1.0") + self.assertEqual(headers.get("server"), "waitress") + self.assertTrue(headers.get("date")) + + def test_send_with_body(self): + to_send = "GET / HTTP/1.0\r\nContent-Length: 5\r\n\r\n" + to_send += "hello" + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, echo = self._read_echo(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + self.assertEqual(echo.content_length, "5") + self.assertEqual(echo.body, b"hello") + + def test_send_empty_body(self): + to_send = "GET / HTTP/1.0\r\nContent-Length: 0\r\n\r\n" + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, echo = self._read_echo(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + self.assertEqual(echo.content_length, "0") + self.assertEqual(echo.body, b"") + + def test_multiple_requests_with_body(self): + orig_sock = self.sock + for x in range(3): + self.sock = self.create_socket() + self.test_send_with_body() + self.sock.close() + self.sock = orig_sock + + def test_multiple_requests_without_body(self): + orig_sock = self.sock + for x in range(3): + self.sock = self.create_socket() + self.test_send_empty_body() + self.sock.close() + self.sock = orig_sock + + def test_without_crlf(self): + data = "Echo\r\nthis\r\nplease" + s = tobytes( + "GET / HTTP/1.0\r\n" + "Connection: close\r\n" + "Content-Length: %d\r\n" + "\r\n" + "%s" % (len(data), data) + ) + self.connect() + self.sock.send(s) + fp = self.sock.makefile("rb", 0) + line, headers, echo = self._read_echo(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + self.assertEqual(int(echo.content_length), len(data)) + self.assertEqual(len(echo.body), len(data)) + self.assertEqual(echo.body, tobytes(data)) + + def test_large_body(self): + # 1024 characters. + body = "This string has 32 characters.\r\n" * 32 + s = tobytes( + "GET / HTTP/1.0\r\nContent-Length: %d\r\n\r\n%s" % (len(body), body) + ) + self.connect() + self.sock.send(s) + fp = self.sock.makefile("rb", 0) + line, headers, echo = self._read_echo(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + self.assertEqual(echo.content_length, "1024") + self.assertEqual(echo.body, tobytes(body)) + + def test_many_clients(self): + conns = [] + for n in range(50): + h = self.make_http_connection() + h.request("GET", "/", headers={"Accept": "text/plain"}) + conns.append(h) + responses = [] + for h in conns: + response = h.getresponse() + self.assertEqual(response.status, 200) + responses.append(response) + for response in responses: + response.read() + for h in conns: + h.close() + + def test_chunking_request_without_content(self): + header = tobytes("GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n") + self.connect() + self.sock.send(header) + self.sock.send(b"0\r\n\r\n") + fp = self.sock.makefile("rb", 0) + line, headers, echo = self._read_echo(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + self.assertEqual(echo.body, b"") + self.assertEqual(echo.content_length, "0") + self.assertFalse("transfer-encoding" in headers) + + def test_chunking_request_with_content(self): + control_line = b"20;\r\n" # 20 hex = 32 dec + s = b"This string has 32 characters.\r\n" + expected = s * 12 + header = tobytes("GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n") + self.connect() + self.sock.send(header) + fp = self.sock.makefile("rb", 0) + for n in range(12): + self.sock.send(control_line) + self.sock.send(s) + self.sock.send(b"\r\n") # End the chunk + self.sock.send(b"0\r\n\r\n") + line, headers, echo = self._read_echo(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + self.assertEqual(echo.body, expected) + self.assertEqual(echo.content_length, str(len(expected))) + self.assertFalse("transfer-encoding" in headers) + + def test_broken_chunked_encoding(self): + control_line = "20;\r\n" # 20 hex = 32 dec + s = "This string has 32 characters.\r\n" + to_send = "GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n" + to_send += control_line + s + "\r\n" + # garbage in input + to_send += "garbage\r\n" + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + # receiver caught garbage and turned it into a 400 + self.assertline(line, "400", "Bad Request", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + self.assertEqual( + sorted(headers.keys()), ["connection", "content-length", "content-type", "date", "server"] + ) + self.assertEqual(headers["content-type"], "text/plain") + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_broken_chunked_encoding_missing_chunk_end(self): + control_line = "20;\r\n" # 20 hex = 32 dec + s = "This string has 32 characters.\r\n" + to_send = "GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n" + to_send += control_line + s + # garbage in input + to_send += "garbage" + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + # receiver caught garbage and turned it into a 400 + self.assertline(line, "400", "Bad Request", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + self.assertTrue(b"Chunk not properly terminated" in response_body) + self.assertEqual( + sorted(headers.keys()), ["connection", "content-length", "content-type", "date", "server"] + ) + self.assertEqual(headers["content-type"], "text/plain") + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_keepalive_http_10(self): + # Handling of Keep-Alive within HTTP 1.0 + data = "Default: Don't keep me alive" + s = tobytes( + "GET / HTTP/1.0\r\nContent-Length: %d\r\n\r\n%s" % (len(data), data) + ) + self.connect() + self.sock.send(s) + response = httplib.HTTPResponse(self.sock) + response.begin() + self.assertEqual(int(response.status), 200) + connection = response.getheader("Connection", "") + # We sent no Connection: Keep-Alive header + # Connection: close (or no header) is default. + self.assertTrue(connection != "Keep-Alive") + + def test_keepalive_http10_explicit(self): + # If header Connection: Keep-Alive is explicitly sent, + # we want to keept the connection open, we also need to return + # the corresponding header + data = "Keep me alive" + s = tobytes( + "GET / HTTP/1.0\r\n" + "Connection: Keep-Alive\r\n" + "Content-Length: %d\r\n" + "\r\n" + "%s" % (len(data), data) + ) + self.connect() + self.sock.send(s) + response = httplib.HTTPResponse(self.sock) + response.begin() + self.assertEqual(int(response.status), 200) + connection = response.getheader("Connection", "") + self.assertEqual(connection, "Keep-Alive") + + def test_keepalive_http_11(self): + # Handling of Keep-Alive within HTTP 1.1 + + # All connections are kept alive, unless stated otherwise + data = "Default: Keep me alive" + s = tobytes( + "GET / HTTP/1.1\r\nContent-Length: %d\r\n\r\n%s" % (len(data), data) + ) + self.connect() + self.sock.send(s) + response = httplib.HTTPResponse(self.sock) + response.begin() + self.assertEqual(int(response.status), 200) + self.assertTrue(response.getheader("connection") != "close") + + def test_keepalive_http11_explicit(self): + # Explicitly set keep-alive + data = "Default: Keep me alive" + s = tobytes( + "GET / HTTP/1.1\r\n" + "Connection: keep-alive\r\n" + "Content-Length: %d\r\n" + "\r\n" + "%s" % (len(data), data) + ) + self.connect() + self.sock.send(s) + response = httplib.HTTPResponse(self.sock) + response.begin() + self.assertEqual(int(response.status), 200) + self.assertTrue(response.getheader("connection") != "close") + + def test_keepalive_http11_connclose(self): + # specifying Connection: close explicitly + data = "Don't keep me alive" + s = tobytes( + "GET / HTTP/1.1\r\n" + "Connection: close\r\n" + "Content-Length: %d\r\n" + "\r\n" + "%s" % (len(data), data) + ) + self.connect() + self.sock.send(s) + response = httplib.HTTPResponse(self.sock) + response.begin() + self.assertEqual(int(response.status), 200) + self.assertEqual(response.getheader("connection"), "close") + + def test_proxy_headers(self): + to_send = ( + "GET / HTTP/1.0\r\n" + "Content-Length: 0\r\n" + "Host: www.google.com:8080\r\n" + "X-Forwarded-For: 192.168.1.1\r\n" + "X-Forwarded-Proto: https\r\n" + "X-Forwarded-Port: 5000\r\n\r\n" + ) + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, echo = self._read_echo(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + self.assertEqual(headers.get("server"), "waitress") + self.assertTrue(headers.get("date")) + self.assertIsNone(echo.headers.get("X_FORWARDED_PORT")) + self.assertEqual(echo.headers["HOST"], "www.google.com:8080") + self.assertEqual(echo.scheme, "https") + self.assertEqual(echo.remote_addr, "192.168.1.1") + self.assertEqual(echo.remote_host, "192.168.1.1") + + +class PipeliningTests(object): + def setUp(self): + from tests.fixtureapps import echo + + self.start_subprocess(echo.app_body_only) + + def tearDown(self): + self.stop_subprocess() + + def test_pipelining(self): + s = ( + "GET / HTTP/1.0\r\n" + "Connection: %s\r\n" + "Content-Length: %d\r\n" + "\r\n" + "%s" + ) + to_send = b"" + count = 25 + for n in range(count): + body = "Response #%d\r\n" % (n + 1) + if n + 1 < count: + conn = "keep-alive" + else: + conn = "close" + to_send += tobytes(s % (conn, len(body), body)) + + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + for n in range(count): + expect_body = tobytes("Response #%d\r\n" % (n + 1)) + line = fp.readline() # status line + version, status, reason = (x.strip() for x in line.split(None, 2)) + headers = parse_headers(fp) + length = int(headers.get("content-length")) or None + response_body = fp.read(length) + self.assertEqual(int(status), 200) + self.assertEqual(length, len(response_body)) + self.assertEqual(response_body, expect_body) + + +class ExpectContinueTests(object): + def setUp(self): + from tests.fixtureapps import echo + + self.start_subprocess(echo.app_body_only) + + def tearDown(self): + self.stop_subprocess() + + def test_expect_continue(self): + # specifying Connection: close explicitly + data = "I have expectations" + to_send = tobytes( + "GET / HTTP/1.1\r\n" + "Connection: close\r\n" + "Content-Length: %d\r\n" + "Expect: 100-continue\r\n" + "\r\n" + "%s" % (len(data), data) + ) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line = fp.readline() # continue status line + version, status, reason = (x.strip() for x in line.split(None, 2)) + self.assertEqual(int(status), 100) + self.assertEqual(reason, b"Continue") + self.assertEqual(version, b"HTTP/1.1") + fp.readline() # blank line + line = fp.readline() # next status line + version, status, reason = (x.strip() for x in line.split(None, 2)) + headers = parse_headers(fp) + length = int(headers.get("content-length")) or None + response_body = fp.read(length) + self.assertEqual(int(status), 200) + self.assertEqual(length, len(response_body)) + self.assertEqual(response_body, tobytes(data)) + + +class BadContentLengthTests(object): + def setUp(self): + from tests.fixtureapps import badcl + + self.start_subprocess(badcl.app) + + def tearDown(self): + self.stop_subprocess() + + def test_short_body(self): + # check to see if server closes connection when body is too short + # for cl header + to_send = tobytes( + "GET /short_body HTTP/1.0\r\n" + "Connection: Keep-Alive\r\n" + "Content-Length: 0\r\n" + "\r\n" + ) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line = fp.readline() # status line + version, status, reason = (x.strip() for x in line.split(None, 2)) + headers = parse_headers(fp) + content_length = int(headers.get("content-length")) + response_body = fp.read(content_length) + self.assertEqual(int(status), 200) + self.assertNotEqual(content_length, len(response_body)) + self.assertEqual(len(response_body), content_length - 1) + self.assertEqual(response_body, tobytes("abcdefghi")) + # remote closed connection (despite keepalive header); not sure why + # first send succeeds + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_long_body(self): + # check server doesnt close connection when body is too short + # for cl header + to_send = tobytes( + "GET /long_body HTTP/1.0\r\n" + "Connection: Keep-Alive\r\n" + "Content-Length: 0\r\n" + "\r\n" + ) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line = fp.readline() # status line + version, status, reason = (x.strip() for x in line.split(None, 2)) + headers = parse_headers(fp) + content_length = int(headers.get("content-length")) or None + response_body = fp.read(content_length) + self.assertEqual(int(status), 200) + self.assertEqual(content_length, len(response_body)) + self.assertEqual(response_body, tobytes("abcdefgh")) + # remote does not close connection (keepalive header) + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line = fp.readline() # status line + version, status, reason = (x.strip() for x in line.split(None, 2)) + headers = parse_headers(fp) + content_length = int(headers.get("content-length")) or None + response_body = fp.read(content_length) + self.assertEqual(int(status), 200) + + +class NoContentLengthTests(object): + def setUp(self): + from tests.fixtureapps import nocl + + self.start_subprocess(nocl.app) + + def tearDown(self): + self.stop_subprocess() + + def test_http10_generator(self): + body = string.ascii_letters + to_send = ( + "GET / HTTP/1.0\r\n" + "Connection: Keep-Alive\r\n" + "Content-Length: %d\r\n\r\n" % len(body) + ) + to_send += body + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + self.assertEqual(headers.get("content-length"), None) + self.assertEqual(headers.get("connection"), "close") + self.assertEqual(response_body, tobytes(body)) + # remote closed connection (despite keepalive header), because + # generators cannot have a content-length divined + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_http10_list(self): + body = string.ascii_letters + to_send = ( + "GET /list HTTP/1.0\r\n" + "Connection: Keep-Alive\r\n" + "Content-Length: %d\r\n\r\n" % len(body) + ) + to_send += body + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + self.assertEqual(headers["content-length"], str(len(body))) + self.assertEqual(headers.get("connection"), "Keep-Alive") + self.assertEqual(response_body, tobytes(body)) + # remote keeps connection open because it divined the content length + # from a length-1 list + self.sock.send(to_send) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + + def test_http10_listlentwo(self): + body = string.ascii_letters + to_send = ( + "GET /list_lentwo HTTP/1.0\r\n" + "Connection: Keep-Alive\r\n" + "Content-Length: %d\r\n\r\n" % len(body) + ) + to_send += body + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + self.assertEqual(headers.get("content-length"), None) + self.assertEqual(headers.get("connection"), "close") + self.assertEqual(response_body, tobytes(body)) + # remote closed connection (despite keepalive header), because + # lists of length > 1 cannot have their content length divined + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_http11_generator(self): + body = string.ascii_letters + to_send = "GET / HTTP/1.1\r\nContent-Length: %s\r\n\r\n" % len(body) + to_send += body + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb") + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + expected = b"" + for chunk in chunks(body, 10): + expected += tobytes( + "%s\r\n%s\r\n" % (str(hex(len(chunk))[2:].upper()), chunk) + ) + expected += b"0\r\n\r\n" + self.assertEqual(response_body, expected) + # connection is always closed at the end of a chunked response + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_http11_list(self): + body = string.ascii_letters + to_send = "GET /list HTTP/1.1\r\nContent-Length: %d\r\n\r\n" % len(body) + to_send += body + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + self.assertEqual(headers["content-length"], str(len(body))) + self.assertEqual(response_body, tobytes(body)) + # remote keeps connection open because it divined the content length + # from a length-1 list + self.sock.send(to_send) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + + def test_http11_listlentwo(self): + body = string.ascii_letters + to_send = "GET /list_lentwo HTTP/1.1\r\nContent-Length: %s\r\n\r\n" % len(body) + to_send += body + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb") + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + expected = b"" + for chunk in (body[0], body[1:]): + expected += tobytes( + "%s\r\n%s\r\n" % (str(hex(len(chunk))[2:].upper()), chunk) + ) + expected += b"0\r\n\r\n" + self.assertEqual(response_body, expected) + # connection is always closed at the end of a chunked response + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + +class WriteCallbackTests(object): + def setUp(self): + from tests.fixtureapps import writecb + + self.start_subprocess(writecb.app) + + def tearDown(self): + self.stop_subprocess() + + def test_short_body(self): + # check to see if server closes connection when body is too short + # for cl header + to_send = tobytes( + "GET /short_body HTTP/1.0\r\n" + "Connection: Keep-Alive\r\n" + "Content-Length: 0\r\n" + "\r\n" + ) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + # server trusts the content-length header (5) + self.assertline(line, "200", "OK", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, 9) + self.assertNotEqual(cl, len(response_body)) + self.assertEqual(len(response_body), cl - 1) + self.assertEqual(response_body, tobytes("abcdefgh")) + # remote closed connection (despite keepalive header) + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_long_body(self): + # check server doesnt close connection when body is too long + # for cl header + to_send = tobytes( + "GET /long_body HTTP/1.0\r\n" + "Connection: Keep-Alive\r\n" + "Content-Length: 0\r\n" + "\r\n" + ) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + content_length = int(headers.get("content-length")) or None + self.assertEqual(content_length, 9) + self.assertEqual(content_length, len(response_body)) + self.assertEqual(response_body, tobytes("abcdefghi")) + # remote does not close connection (keepalive header) + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + + def test_equal_body(self): + # check server doesnt close connection when body is equal to + # cl header + to_send = tobytes( + "GET /equal_body HTTP/1.0\r\n" + "Connection: Keep-Alive\r\n" + "Content-Length: 0\r\n" + "\r\n" + ) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + content_length = int(headers.get("content-length")) or None + self.assertEqual(content_length, 9) + self.assertline(line, "200", "OK", "HTTP/1.0") + self.assertEqual(content_length, len(response_body)) + self.assertEqual(response_body, tobytes("abcdefghi")) + # remote does not close connection (keepalive header) + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + + def test_no_content_length(self): + # wtf happens when there's no content-length + to_send = tobytes( + "GET /no_content_length HTTP/1.0\r\n" + "Connection: Keep-Alive\r\n" + "Content-Length: 0\r\n" + "\r\n" + ) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line = fp.readline() # status line + line, headers, response_body = read_http(fp) + content_length = headers.get("content-length") + self.assertEqual(content_length, None) + self.assertEqual(response_body, tobytes("abcdefghi")) + # remote closed connection (despite keepalive header) + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + +class TooLargeTests(object): + + toobig = 1050 + + def setUp(self): + from tests.fixtureapps import toolarge + + self.start_subprocess( + toolarge.app, max_request_header_size=1000, max_request_body_size=1000 + ) + + def tearDown(self): + self.stop_subprocess() + + def test_request_body_too_large_with_wrong_cl_http10(self): + body = "a" * self.toobig + to_send = "GET / HTTP/1.0\r\nContent-Length: 5\r\n\r\n" + to_send += body + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb") + # first request succeeds (content-length 5) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + # server trusts the content-length header; no pipelining, + # so request fulfilled, extra bytes are thrown away + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_request_body_too_large_with_wrong_cl_http10_keepalive(self): + body = "a" * self.toobig + to_send = "GET / HTTP/1.0\r\nContent-Length: 5\r\nConnection: Keep-Alive\r\n\r\n" + to_send += body + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb") + # first request succeeds (content-length 5) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + line, headers, response_body = read_http(fp) + self.assertline(line, "431", "Request Header Fields Too Large", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_request_body_too_large_with_no_cl_http10(self): + body = "a" * self.toobig + to_send = "GET / HTTP/1.0\r\n\r\n" + to_send += body + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + # extra bytes are thrown away (no pipelining), connection closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_request_body_too_large_with_no_cl_http10_keepalive(self): + body = "a" * self.toobig + to_send = "GET / HTTP/1.0\r\nConnection: Keep-Alive\r\n\r\n" + to_send += body + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + # server trusts the content-length header (assumed zero) + self.assertline(line, "200", "OK", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + line, headers, response_body = read_http(fp) + # next response overruns because the extra data appears to be + # header data + self.assertline(line, "431", "Request Header Fields Too Large", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_request_body_too_large_with_wrong_cl_http11(self): + body = "a" * self.toobig + to_send = "GET / HTTP/1.1\r\nContent-Length: 5\r\n\r\n" + to_send += body + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb") + # first request succeeds (content-length 5) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + # second response is an error response + line, headers, response_body = read_http(fp) + self.assertline(line, "431", "Request Header Fields Too Large", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_request_body_too_large_with_wrong_cl_http11_connclose(self): + body = "a" * self.toobig + to_send = "GET / HTTP/1.1\r\nContent-Length: 5\r\nConnection: close\r\n\r\n" + to_send += body + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + # server trusts the content-length header (5) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_request_body_too_large_with_no_cl_http11(self): + body = "a" * self.toobig + to_send = "GET / HTTP/1.1\r\n\r\n" + to_send += body + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb") + # server trusts the content-length header (assumed 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + # server assumes pipelined requests due to http/1.1, and the first + # request was assumed c-l 0 because it had no content-length header, + # so entire body looks like the header of the subsequent request + # second response is an error response + line, headers, response_body = read_http(fp) + self.assertline(line, "431", "Request Header Fields Too Large", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_request_body_too_large_with_no_cl_http11_connclose(self): + body = "a" * self.toobig + to_send = "GET / HTTP/1.1\r\nConnection: close\r\n\r\n" + to_send += body + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + # server trusts the content-length header (assumed 0) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_request_body_too_large_chunked_encoding(self): + control_line = "20;\r\n" # 20 hex = 32 dec + s = "This string has 32 characters.\r\n" + to_send = "GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n" + repeat = control_line + s + to_send += repeat * ((self.toobig // len(repeat)) + 1) + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + # body bytes counter caught a max_request_body_size overrun + self.assertline(line, "413", "Request Entity Too Large", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + self.assertEqual(headers["content-type"], "text/plain") + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + +class InternalServerErrorTests(object): + def setUp(self): + from tests.fixtureapps import error + + self.start_subprocess(error.app, expose_tracebacks=True) + + def tearDown(self): + self.stop_subprocess() + + def test_before_start_response_http_10(self): + to_send = "GET /before_start_response HTTP/1.0\r\n\r\n" + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "500", "Internal Server Error", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + self.assertTrue(response_body.startswith(b"Internal Server Error")) + self.assertEqual(headers["connection"], "close") + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_before_start_response_http_11(self): + to_send = "GET /before_start_response HTTP/1.1\r\n\r\n" + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "500", "Internal Server Error", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + self.assertTrue(response_body.startswith(b"Internal Server Error")) + self.assertEqual( + sorted(headers.keys()), ["connection", "content-length", "content-type", "date", "server"] + ) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_before_start_response_http_11_close(self): + to_send = tobytes( + "GET /before_start_response HTTP/1.1\r\nConnection: close\r\n\r\n" + ) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "500", "Internal Server Error", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + self.assertTrue(response_body.startswith(b"Internal Server Error")) + self.assertEqual( + sorted(headers.keys()), + ["connection", "content-length", "content-type", "date", "server"], + ) + self.assertEqual(headers["connection"], "close") + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_after_start_response_http10(self): + to_send = "GET /after_start_response HTTP/1.0\r\n\r\n" + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "500", "Internal Server Error", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + self.assertTrue(response_body.startswith(b"Internal Server Error")) + self.assertEqual( + sorted(headers.keys()), + ["connection", "content-length", "content-type", "date", "server"], + ) + self.assertEqual(headers["connection"], "close") + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_after_start_response_http11(self): + to_send = "GET /after_start_response HTTP/1.1\r\n\r\n" + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "500", "Internal Server Error", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + self.assertTrue(response_body.startswith(b"Internal Server Error")) + self.assertEqual( + sorted(headers.keys()), ["connection", "content-length", "content-type", "date", "server"] + ) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_after_start_response_http11_close(self): + to_send = tobytes( + "GET /after_start_response HTTP/1.1\r\nConnection: close\r\n\r\n" + ) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "500", "Internal Server Error", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + self.assertTrue(response_body.startswith(b"Internal Server Error")) + self.assertEqual( + sorted(headers.keys()), + ["connection", "content-length", "content-type", "date", "server"], + ) + self.assertEqual(headers["connection"], "close") + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_after_write_cb(self): + to_send = "GET /after_write_cb HTTP/1.1\r\n\r\n" + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + self.assertEqual(response_body, b"") + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_in_generator(self): + to_send = "GET /in_generator HTTP/1.1\r\n\r\n" + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + self.assertEqual(response_body, b"") + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + +class FileWrapperTests(object): + def setUp(self): + from tests.fixtureapps import filewrapper + + self.start_subprocess(filewrapper.app) + + def tearDown(self): + self.stop_subprocess() + + def test_filelike_http11(self): + to_send = "GET /filelike HTTP/1.1\r\n\r\n" + to_send = tobytes(to_send) + + self.connect() + + for t in range(0, 2): + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has not been closed + + def test_filelike_nocl_http11(self): + to_send = "GET /filelike_nocl HTTP/1.1\r\n\r\n" + to_send = tobytes(to_send) + + self.connect() + + for t in range(0, 2): + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has not been closed + + def test_filelike_shortcl_http11(self): + to_send = "GET /filelike_shortcl HTTP/1.1\r\n\r\n" + to_send = tobytes(to_send) + + self.connect() + + for t in range(0, 2): + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, 1) + self.assertEqual(cl, len(response_body)) + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377" in response_body) + # connection has not been closed + + def test_filelike_longcl_http11(self): + to_send = "GET /filelike_longcl HTTP/1.1\r\n\r\n" + to_send = tobytes(to_send) + + self.connect() + + for t in range(0, 2): + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has not been closed + + def test_notfilelike_http11(self): + to_send = "GET /notfilelike HTTP/1.1\r\n\r\n" + to_send = tobytes(to_send) + + self.connect() + + for t in range(0, 2): + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has not been closed + + def test_notfilelike_iobase_http11(self): + to_send = "GET /notfilelike_iobase HTTP/1.1\r\n\r\n" + to_send = tobytes(to_send) + + self.connect() + + for t in range(0, 2): + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has not been closed + + def test_notfilelike_nocl_http11(self): + to_send = "GET /notfilelike_nocl HTTP/1.1\r\n\r\n" + to_send = tobytes(to_send) + + self.connect() + + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has been closed (no content-length) + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_notfilelike_shortcl_http11(self): + to_send = "GET /notfilelike_shortcl HTTP/1.1\r\n\r\n" + to_send = tobytes(to_send) + + self.connect() + + for t in range(0, 2): + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, 1) + self.assertEqual(cl, len(response_body)) + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377" in response_body) + # connection has not been closed + + def test_notfilelike_longcl_http11(self): + to_send = "GET /notfilelike_longcl HTTP/1.1\r\n\r\n" + to_send = tobytes(to_send) + + self.connect() + + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body) + 10) + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_filelike_http10(self): + to_send = "GET /filelike HTTP/1.0\r\n\r\n" + to_send = tobytes(to_send) + + self.connect() + + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_filelike_nocl_http10(self): + to_send = "GET /filelike_nocl HTTP/1.0\r\n\r\n" + to_send = tobytes(to_send) + + self.connect() + + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_notfilelike_http10(self): + to_send = "GET /notfilelike HTTP/1.0\r\n\r\n" + to_send = tobytes(to_send) + + self.connect() + + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_notfilelike_nocl_http10(self): + to_send = "GET /notfilelike_nocl HTTP/1.0\r\n\r\n" + to_send = tobytes(to_send) + + self.connect() + + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has been closed (no content-length) + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + +class TcpEchoTests(EchoTests, TcpTests, unittest.TestCase): + pass + + +class TcpPipeliningTests(PipeliningTests, TcpTests, unittest.TestCase): + pass + + +class TcpExpectContinueTests(ExpectContinueTests, TcpTests, unittest.TestCase): + pass + + +class TcpBadContentLengthTests(BadContentLengthTests, TcpTests, unittest.TestCase): + pass + + +class TcpNoContentLengthTests(NoContentLengthTests, TcpTests, unittest.TestCase): + pass + + +class TcpWriteCallbackTests(WriteCallbackTests, TcpTests, unittest.TestCase): + pass + + +class TcpTooLargeTests(TooLargeTests, TcpTests, unittest.TestCase): + pass + + +class TcpInternalServerErrorTests( + InternalServerErrorTests, TcpTests, unittest.TestCase +): + pass + + +class TcpFileWrapperTests(FileWrapperTests, TcpTests, unittest.TestCase): + pass + + +if hasattr(socket, "AF_UNIX"): + + class FixtureUnixWSGIServer(server.UnixWSGIServer): + """A version of UnixWSGIServer that relays back what it's bound to. + """ + + family = socket.AF_UNIX # Testing + + def __init__(self, application, queue, **kw): # pragma: no cover + # Coverage doesn't see this as it's ran in a separate process. + # To permit parallel testing, use a PID-dependent socket. + kw["unix_socket"] = "/tmp/waitress.test-%d.sock" % os.getpid() + super(FixtureUnixWSGIServer, self).__init__(application, **kw) + queue.put(self.socket.getsockname()) + + class UnixTests(SubprocessTests): + + server = FixtureUnixWSGIServer + + def make_http_connection(self): + return UnixHTTPConnection(self.bound_to) + + def stop_subprocess(self): + super(UnixTests, self).stop_subprocess() + cleanup_unix_socket(self.bound_to) + + def send_check_error(self, to_send): + # Unlike inet domain sockets, Unix domain sockets can trigger a + # 'Broken pipe' error when the socket it closed. + try: + self.sock.send(to_send) + except socket.error as exc: + self.assertEqual(get_errno(exc), errno.EPIPE) + + class UnixEchoTests(EchoTests, UnixTests, unittest.TestCase): + pass + + class UnixPipeliningTests(PipeliningTests, UnixTests, unittest.TestCase): + pass + + class UnixExpectContinueTests(ExpectContinueTests, UnixTests, unittest.TestCase): + pass + + class UnixBadContentLengthTests( + BadContentLengthTests, UnixTests, unittest.TestCase + ): + pass + + class UnixNoContentLengthTests(NoContentLengthTests, UnixTests, unittest.TestCase): + pass + + class UnixWriteCallbackTests(WriteCallbackTests, UnixTests, unittest.TestCase): + pass + + class UnixTooLargeTests(TooLargeTests, UnixTests, unittest.TestCase): + pass + + class UnixInternalServerErrorTests( + InternalServerErrorTests, UnixTests, unittest.TestCase + ): + pass + + class UnixFileWrapperTests(FileWrapperTests, UnixTests, unittest.TestCase): + pass + + +def parse_headers(fp): + """Parses only RFC2822 headers from a file pointer. + """ + headers = {} + while True: + line = fp.readline() + if line in (b"\r\n", b"\n", b""): + break + line = line.decode("iso-8859-1") + name, value = line.strip().split(":", 1) + headers[name.lower().strip()] = value.lower().strip() + return headers + + +class UnixHTTPConnection(httplib.HTTPConnection): + """Patched version of HTTPConnection that uses Unix domain sockets. + """ + + def __init__(self, path): + httplib.HTTPConnection.__init__(self, "localhost") + self.path = path + + def connect(self): + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.connect(self.path) + self.sock = sock + + +class ConnectionClosed(Exception): + pass + + +# stolen from gevent +def read_http(fp): # pragma: no cover + try: + response_line = fp.readline() + except socket.error as exc: + fp.close() + # errno 104 is ENOTRECOVERABLE, In WinSock 10054 is ECONNRESET + if get_errno(exc) in (errno.ECONNABORTED, errno.ECONNRESET, 104, 10054): + raise ConnectionClosed + raise + if not response_line: + raise ConnectionClosed + + header_lines = [] + while True: + line = fp.readline() + if line in (b"\r\n", b"\r\n", b""): + break + else: + header_lines.append(line) + headers = dict() + for x in header_lines: + x = x.strip() + if not x: + continue + key, value = x.split(b": ", 1) + key = key.decode("iso-8859-1").lower() + value = value.decode("iso-8859-1") + assert key not in headers, "%s header duplicated" % key + headers[key] = value + + if "content-length" in headers: + num = int(headers["content-length"]) + body = b"" + left = num + while left > 0: + data = fp.read(left) + if not data: + break + body += data + left -= len(data) + else: + # read until EOF + body = fp.read() + + return response_line, headers, body + + +# stolen from gevent +def get_errno(exc): # pragma: no cover + """ Get the error code out of socket.error objects. + socket.error in <2.5 does not have errno attribute + socket.error in 3.x does not allow indexing access + e.args[0] works for all. + There are cases when args[0] is not errno. + i.e. http://bugs.python.org/issue6471 + Maybe there are cases when errno is set, but it is not the first argument? + """ + try: + if exc.errno is not None: + return exc.errno + except AttributeError: + pass + try: + return exc.args[0] + except IndexError: + return None + + +def chunks(l, n): + """ Yield successive n-sized chunks from l. + """ + for i in range(0, len(l), n): + yield l[i : i + n] diff --git a/tests/test_init.py b/tests/test_init.py new file mode 100644 index 0000000..f9b91d7 --- /dev/null +++ b/tests/test_init.py @@ -0,0 +1,51 @@ +import unittest + + +class Test_serve(unittest.TestCase): + def _callFUT(self, app, **kw): + from waitress import serve + + return serve(app, **kw) + + def test_it(self): + server = DummyServerFactory() + app = object() + result = self._callFUT(app, _server=server, _quiet=True) + self.assertEqual(server.app, app) + self.assertEqual(result, None) + self.assertEqual(server.ran, True) + + +class Test_serve_paste(unittest.TestCase): + def _callFUT(self, app, **kw): + from waitress import serve_paste + + return serve_paste(app, None, **kw) + + def test_it(self): + server = DummyServerFactory() + app = object() + result = self._callFUT(app, _server=server, _quiet=True) + self.assertEqual(server.app, app) + self.assertEqual(result, 0) + self.assertEqual(server.ran, True) + + +class DummyServerFactory(object): + ran = False + + def __call__(self, app, **kw): + self.adj = DummyAdj(kw) + self.app = app + self.kw = kw + return self + + def run(self): + self.ran = True + + +class DummyAdj(object): + verbose = False + + def __init__(self, kw): + self.__dict__.update(kw) diff --git a/tests/test_parser.py b/tests/test_parser.py new file mode 100644 index 0000000..91837c7 --- /dev/null +++ b/tests/test_parser.py @@ -0,0 +1,732 @@ +############################################################################## +# +# Copyright (c) 2002 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## +"""HTTP Request Parser tests +""" +import unittest + +from waitress.compat import text_, tobytes + + +class TestHTTPRequestParser(unittest.TestCase): + def setUp(self): + from waitress.parser import HTTPRequestParser + from waitress.adjustments import Adjustments + + my_adj = Adjustments() + self.parser = HTTPRequestParser(my_adj) + + def test_get_body_stream_None(self): + self.parser.body_recv = None + result = self.parser.get_body_stream() + self.assertEqual(result.getvalue(), b"") + + def test_get_body_stream_nonNone(self): + body_rcv = DummyBodyStream() + self.parser.body_rcv = body_rcv + result = self.parser.get_body_stream() + self.assertEqual(result, body_rcv) + + def test_received_get_no_headers(self): + data = b"HTTP/1.0 GET /foobar\r\n\r\n" + result = self.parser.received(data) + self.assertEqual(result, 24) + self.assertTrue(self.parser.completed) + self.assertEqual(self.parser.headers, {}) + + def test_received_bad_host_header(self): + from waitress.utilities import BadRequest + + data = b"HTTP/1.0 GET /foobar\r\n Host: foo\r\n\r\n" + result = self.parser.received(data) + self.assertEqual(result, 36) + self.assertTrue(self.parser.completed) + self.assertEqual(self.parser.error.__class__, BadRequest) + + def test_received_bad_transfer_encoding(self): + from waitress.utilities import ServerNotImplemented + + data = ( + b"GET /foobar HTTP/1.1\r\n" + b"Transfer-Encoding: foo\r\n" + b"\r\n" + b"1d;\r\n" + b"This string has 29 characters\r\n" + b"0\r\n\r\n" + ) + result = self.parser.received(data) + self.assertEqual(result, 48) + self.assertTrue(self.parser.completed) + self.assertEqual(self.parser.error.__class__, ServerNotImplemented) + + def test_received_nonsense_nothing(self): + data = b"\r\n\r\n" + result = self.parser.received(data) + self.assertEqual(result, 4) + self.assertTrue(self.parser.completed) + self.assertEqual(self.parser.headers, {}) + + def test_received_no_doublecr(self): + data = b"GET /foobar HTTP/8.4\r\n" + result = self.parser.received(data) + self.assertEqual(result, 22) + self.assertFalse(self.parser.completed) + self.assertEqual(self.parser.headers, {}) + + def test_received_already_completed(self): + self.parser.completed = True + result = self.parser.received(b"a") + self.assertEqual(result, 0) + + def test_received_cl_too_large(self): + from waitress.utilities import RequestEntityTooLarge + + self.parser.adj.max_request_body_size = 2 + data = b"GET /foobar HTTP/8.4\r\nContent-Length: 10\r\n\r\n" + result = self.parser.received(data) + self.assertEqual(result, 44) + self.assertTrue(self.parser.completed) + self.assertTrue(isinstance(self.parser.error, RequestEntityTooLarge)) + + def test_received_headers_too_large(self): + from waitress.utilities import RequestHeaderFieldsTooLarge + + self.parser.adj.max_request_header_size = 2 + data = b"GET /foobar HTTP/8.4\r\nX-Foo: 1\r\n\r\n" + result = self.parser.received(data) + self.assertEqual(result, 34) + self.assertTrue(self.parser.completed) + self.assertTrue(isinstance(self.parser.error, RequestHeaderFieldsTooLarge)) + + def test_received_body_too_large(self): + from waitress.utilities import RequestEntityTooLarge + + self.parser.adj.max_request_body_size = 2 + data = ( + b"GET /foobar HTTP/1.1\r\n" + b"Transfer-Encoding: chunked\r\n" + b"X-Foo: 1\r\n" + b"\r\n" + b"1d;\r\n" + b"This string has 29 characters\r\n" + b"0\r\n\r\n" + ) + + result = self.parser.received(data) + self.assertEqual(result, 62) + self.parser.received(data[result:]) + self.assertTrue(self.parser.completed) + self.assertTrue(isinstance(self.parser.error, RequestEntityTooLarge)) + + def test_received_error_from_parser(self): + from waitress.utilities import BadRequest + + data = ( + b"GET /foobar HTTP/1.1\r\n" + b"Transfer-Encoding: chunked\r\n" + b"X-Foo: 1\r\n" + b"\r\n" + b"garbage\r\n" + ) + # header + result = self.parser.received(data) + # body + result = self.parser.received(data[result:]) + self.assertEqual(result, 9) + self.assertTrue(self.parser.completed) + self.assertTrue(isinstance(self.parser.error, BadRequest)) + + def test_received_chunked_completed_sets_content_length(self): + data = ( + b"GET /foobar HTTP/1.1\r\n" + b"Transfer-Encoding: chunked\r\n" + b"X-Foo: 1\r\n" + b"\r\n" + b"1d;\r\n" + b"This string has 29 characters\r\n" + b"0\r\n\r\n" + ) + result = self.parser.received(data) + self.assertEqual(result, 62) + data = data[result:] + result = self.parser.received(data) + self.assertTrue(self.parser.completed) + self.assertTrue(self.parser.error is None) + self.assertEqual(self.parser.headers["CONTENT_LENGTH"], "29") + + def test_parse_header_gardenpath(self): + data = b"GET /foobar HTTP/8.4\r\nfoo: bar\r\n" + self.parser.parse_header(data) + self.assertEqual(self.parser.first_line, b"GET /foobar HTTP/8.4") + self.assertEqual(self.parser.headers["FOO"], "bar") + + def test_parse_header_no_cr_in_headerplus(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/8.4" + + try: + self.parser.parse_header(data) + except ParsingError: + pass + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_bad_content_length(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/8.4\r\ncontent-length: abc\r\n" + + try: + self.parser.parse_header(data) + except ParsingError as e: + self.assertIn("Content-Length is invalid", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_multiple_content_length(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/8.4\r\ncontent-length: 10\r\ncontent-length: 20\r\n" + + try: + self.parser.parse_header(data) + except ParsingError as e: + self.assertIn("Content-Length is invalid", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_11_te_chunked(self): + # NB: test that capitalization of header value is unimportant + data = b"GET /foobar HTTP/1.1\r\ntransfer-encoding: ChUnKed\r\n" + self.parser.parse_header(data) + self.assertEqual(self.parser.body_rcv.__class__.__name__, "ChunkedReceiver") + + def test_parse_header_transfer_encoding_invalid(self): + from waitress.parser import TransferEncodingNotImplemented + + data = b"GET /foobar HTTP/1.1\r\ntransfer-encoding: gzip\r\n" + + try: + self.parser.parse_header(data) + except TransferEncodingNotImplemented as e: + self.assertIn("Transfer-Encoding requested is not supported.", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_transfer_encoding_invalid_multiple(self): + from waitress.parser import TransferEncodingNotImplemented + + data = b"GET /foobar HTTP/1.1\r\ntransfer-encoding: gzip\r\ntransfer-encoding: chunked\r\n" + + try: + self.parser.parse_header(data) + except TransferEncodingNotImplemented as e: + self.assertIn("Transfer-Encoding requested is not supported.", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_transfer_encoding_invalid_whitespace(self): + from waitress.parser import TransferEncodingNotImplemented + + data = b"GET /foobar HTTP/1.1\r\nTransfer-Encoding:\x85chunked\r\n" + + try: + self.parser.parse_header(data) + except TransferEncodingNotImplemented as e: + self.assertIn("Transfer-Encoding requested is not supported.", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_transfer_encoding_invalid_unicode(self): + from waitress.parser import TransferEncodingNotImplemented + + # This is the binary encoding for the UTF-8 character + # https://www.compart.com/en/unicode/U+212A "unicode character "K"" + # which if waitress were to accidentally do the wrong thing get + # lowercased to just the ascii "k" due to unicode collisions during + # transformation + data = b"GET /foobar HTTP/1.1\r\nTransfer-Encoding: chun\xe2\x84\xaaed\r\n" + + try: + self.parser.parse_header(data) + except TransferEncodingNotImplemented as e: + self.assertIn("Transfer-Encoding requested is not supported.", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_11_expect_continue(self): + data = b"GET /foobar HTTP/1.1\r\nexpect: 100-continue\r\n" + self.parser.parse_header(data) + self.assertEqual(self.parser.expect_continue, True) + + def test_parse_header_connection_close(self): + data = b"GET /foobar HTTP/1.1\r\nConnection: close\r\n" + self.parser.parse_header(data) + self.assertEqual(self.parser.connection_close, True) + + def test_close_with_body_rcv(self): + body_rcv = DummyBodyStream() + self.parser.body_rcv = body_rcv + self.parser.close() + self.assertTrue(body_rcv.closed) + + def test_close_with_no_body_rcv(self): + self.parser.body_rcv = None + self.parser.close() # doesn't raise + + def test_parse_header_lf_only(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/8.4\nfoo: bar" + + try: + self.parser.parse_header(data) + except ParsingError: + pass + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_cr_only(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/8.4\rfoo: bar" + try: + self.parser.parse_header(data) + except ParsingError: + pass + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_extra_lf_in_header(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/8.4\r\nfoo: \nbar\r\n" + try: + self.parser.parse_header(data) + except ParsingError as e: + self.assertIn("Bare CR or LF found in header line", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_extra_lf_in_first_line(self): + from waitress.parser import ParsingError + + data = b"GET /foobar\n HTTP/8.4\r\n" + try: + self.parser.parse_header(data) + except ParsingError as e: + self.assertIn("Bare CR or LF found in HTTP message", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_invalid_whitespace(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/8.4\r\nfoo : bar\r\n" + try: + self.parser.parse_header(data) + except ParsingError as e: + self.assertIn("Invalid header", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_invalid_whitespace_vtab(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/1.1\r\nfoo:\x0bbar\r\n" + try: + self.parser.parse_header(data) + except ParsingError as e: + self.assertIn("Invalid header", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_invalid_no_colon(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/1.1\r\nfoo: bar\r\nnotvalid\r\n" + try: + self.parser.parse_header(data) + except ParsingError as e: + self.assertIn("Invalid header", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_invalid_folding_spacing(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/1.1\r\nfoo: bar\r\n\t\x0bbaz\r\n" + try: + self.parser.parse_header(data) + except ParsingError as e: + self.assertIn("Invalid header", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_invalid_chars(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/1.1\r\nfoo: bar\r\nfoo: \x0bbaz\r\n" + try: + self.parser.parse_header(data) + except ParsingError as e: + self.assertIn("Invalid header", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_empty(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/1.1\r\nfoo: bar\r\nempty:\r\n" + self.parser.parse_header(data) + + self.assertIn("EMPTY", self.parser.headers) + self.assertIn("FOO", self.parser.headers) + self.assertEqual(self.parser.headers["EMPTY"], "") + self.assertEqual(self.parser.headers["FOO"], "bar") + + def test_parse_header_multiple_values(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/1.1\r\nfoo: bar, whatever, more, please, yes\r\n" + self.parser.parse_header(data) + + self.assertIn("FOO", self.parser.headers) + self.assertEqual(self.parser.headers["FOO"], "bar, whatever, more, please, yes") + + def test_parse_header_multiple_values_header_folded(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/1.1\r\nfoo: bar, whatever,\r\n more, please, yes\r\n" + self.parser.parse_header(data) + + self.assertIn("FOO", self.parser.headers) + self.assertEqual(self.parser.headers["FOO"], "bar, whatever, more, please, yes") + + def test_parse_header_multiple_values_header_folded_multiple(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/1.1\r\nfoo: bar, whatever,\r\n more\r\nfoo: please, yes\r\n" + self.parser.parse_header(data) + + self.assertIn("FOO", self.parser.headers) + self.assertEqual(self.parser.headers["FOO"], "bar, whatever, more, please, yes") + + def test_parse_header_multiple_values_extra_space(self): + # Tests errata from: https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189 + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/1.1\r\nfoo: abrowser/0.001 (C O M M E N T)\r\n" + self.parser.parse_header(data) + + self.assertIn("FOO", self.parser.headers) + self.assertEqual(self.parser.headers["FOO"], "abrowser/0.001 (C O M M E N T)") + + def test_parse_header_invalid_backtrack_bad(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/1.1\r\nfoo: bar\r\nfoo: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\x10\r\n" + try: + self.parser.parse_header(data) + except ParsingError as e: + self.assertIn("Invalid header", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_short_values(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/1.1\r\none: 1\r\ntwo: 22\r\n" + self.parser.parse_header(data) + + self.assertIn("ONE", self.parser.headers) + self.assertIn("TWO", self.parser.headers) + self.assertEqual(self.parser.headers["ONE"], "1") + self.assertEqual(self.parser.headers["TWO"], "22") + + +class Test_split_uri(unittest.TestCase): + def _callFUT(self, uri): + from waitress.parser import split_uri + + ( + self.proxy_scheme, + self.proxy_netloc, + self.path, + self.query, + self.fragment, + ) = split_uri(uri) + + def test_split_uri_unquoting_unneeded(self): + self._callFUT(b"http://localhost:8080/abc def") + self.assertEqual(self.path, "/abc def") + + def test_split_uri_unquoting_needed(self): + self._callFUT(b"http://localhost:8080/abc%20def") + self.assertEqual(self.path, "/abc def") + + def test_split_url_with_query(self): + self._callFUT(b"http://localhost:8080/abc?a=1&b=2") + self.assertEqual(self.path, "/abc") + self.assertEqual(self.query, "a=1&b=2") + + def test_split_url_with_query_empty(self): + self._callFUT(b"http://localhost:8080/abc?") + self.assertEqual(self.path, "/abc") + self.assertEqual(self.query, "") + + def test_split_url_with_fragment(self): + self._callFUT(b"http://localhost:8080/#foo") + self.assertEqual(self.path, "/") + self.assertEqual(self.fragment, "foo") + + def test_split_url_https(self): + self._callFUT(b"https://localhost:8080/") + self.assertEqual(self.path, "/") + self.assertEqual(self.proxy_scheme, "https") + self.assertEqual(self.proxy_netloc, "localhost:8080") + + def test_split_uri_unicode_error_raises_parsing_error(self): + # See https://github.com/Pylons/waitress/issues/64 + from waitress.parser import ParsingError + + # Either pass or throw a ParsingError, just don't throw another type of + # exception as that will cause the connection to close badly: + try: + self._callFUT(b"/\xd0") + except ParsingError: + pass + + def test_split_uri_path(self): + self._callFUT(b"//testing/whatever") + self.assertEqual(self.path, "//testing/whatever") + self.assertEqual(self.proxy_scheme, "") + self.assertEqual(self.proxy_netloc, "") + self.assertEqual(self.query, "") + self.assertEqual(self.fragment, "") + + def test_split_uri_path_query(self): + self._callFUT(b"//testing/whatever?a=1&b=2") + self.assertEqual(self.path, "//testing/whatever") + self.assertEqual(self.proxy_scheme, "") + self.assertEqual(self.proxy_netloc, "") + self.assertEqual(self.query, "a=1&b=2") + self.assertEqual(self.fragment, "") + + def test_split_uri_path_query_fragment(self): + self._callFUT(b"//testing/whatever?a=1&b=2#fragment") + self.assertEqual(self.path, "//testing/whatever") + self.assertEqual(self.proxy_scheme, "") + self.assertEqual(self.proxy_netloc, "") + self.assertEqual(self.query, "a=1&b=2") + self.assertEqual(self.fragment, "fragment") + + +class Test_get_header_lines(unittest.TestCase): + def _callFUT(self, data): + from waitress.parser import get_header_lines + + return get_header_lines(data) + + def test_get_header_lines(self): + result = self._callFUT(b"slam\r\nslim") + self.assertEqual(result, [b"slam", b"slim"]) + + def test_get_header_lines_folded(self): + # From RFC2616: + # HTTP/1.1 header field values can be folded onto multiple lines if the + # continuation line begins with a space or horizontal tab. All linear + # white space, including folding, has the same semantics as SP. A + # recipient MAY replace any linear white space with a single SP before + # interpreting the field value or forwarding the message downstream. + + # We are just preserving the whitespace that indicates folding. + result = self._callFUT(b"slim\r\n slam") + self.assertEqual(result, [b"slim slam"]) + + def test_get_header_lines_tabbed(self): + result = self._callFUT(b"slam\r\n\tslim") + self.assertEqual(result, [b"slam\tslim"]) + + def test_get_header_lines_malformed(self): + # https://corte.si/posts/code/pathod/pythonservers/index.html + from waitress.parser import ParsingError + + self.assertRaises(ParsingError, self._callFUT, b" Host: localhost\r\n\r\n") + + +class Test_crack_first_line(unittest.TestCase): + def _callFUT(self, line): + from waitress.parser import crack_first_line + + return crack_first_line(line) + + def test_crack_first_line_matchok(self): + result = self._callFUT(b"GET / HTTP/1.0") + self.assertEqual(result, (b"GET", b"/", b"1.0")) + + def test_crack_first_line_lowercase_method(self): + from waitress.parser import ParsingError + + self.assertRaises(ParsingError, self._callFUT, b"get / HTTP/1.0") + + def test_crack_first_line_nomatch(self): + result = self._callFUT(b"GET / bleh") + self.assertEqual(result, (b"", b"", b"")) + + result = self._callFUT(b"GET /info?txtAirPlay&txtRAOP RTSP/1.0") + self.assertEqual(result, (b"", b"", b"")) + + def test_crack_first_line_missing_version(self): + result = self._callFUT(b"GET /") + self.assertEqual(result, (b"GET", b"/", b"")) + + +class TestHTTPRequestParserIntegration(unittest.TestCase): + def setUp(self): + from waitress.parser import HTTPRequestParser + from waitress.adjustments import Adjustments + + my_adj = Adjustments() + self.parser = HTTPRequestParser(my_adj) + + def feed(self, data): + parser = self.parser + + for n in range(100): # make sure we never loop forever + consumed = parser.received(data) + data = data[consumed:] + + if parser.completed: + return + raise ValueError("Looping") # pragma: no cover + + def testSimpleGET(self): + data = ( + b"GET /foobar HTTP/8.4\r\n" + b"FirstName: mickey\r\n" + b"lastname: Mouse\r\n" + b"content-length: 6\r\n" + b"\r\n" + b"Hello." + ) + parser = self.parser + self.feed(data) + self.assertTrue(parser.completed) + self.assertEqual(parser.version, "8.4") + self.assertFalse(parser.empty) + self.assertEqual( + parser.headers, + {"FIRSTNAME": "mickey", "LASTNAME": "Mouse", "CONTENT_LENGTH": "6",}, + ) + self.assertEqual(parser.path, "/foobar") + self.assertEqual(parser.command, "GET") + self.assertEqual(parser.query, "") + self.assertEqual(parser.proxy_scheme, "") + self.assertEqual(parser.proxy_netloc, "") + self.assertEqual(parser.get_body_stream().getvalue(), b"Hello.") + + def testComplexGET(self): + data = ( + b"GET /foo/a+%2B%2F%C3%A4%3D%26a%3Aint?d=b+%2B%2F%3D%26b%3Aint&c+%2B%2F%3D%26c%3Aint=6 HTTP/8.4\r\n" + b"FirstName: mickey\r\n" + b"lastname: Mouse\r\n" + b"content-length: 10\r\n" + b"\r\n" + b"Hello mickey." + ) + parser = self.parser + self.feed(data) + self.assertEqual(parser.command, "GET") + self.assertEqual(parser.version, "8.4") + self.assertFalse(parser.empty) + self.assertEqual( + parser.headers, + {"FIRSTNAME": "mickey", "LASTNAME": "Mouse", "CONTENT_LENGTH": "10"}, + ) + # path should be utf-8 encoded + self.assertEqual( + tobytes(parser.path).decode("utf-8"), + text_(b"/foo/a++/\xc3\xa4=&a:int", "utf-8"), + ) + self.assertEqual( + parser.query, "d=b+%2B%2F%3D%26b%3Aint&c+%2B%2F%3D%26c%3Aint=6" + ) + self.assertEqual(parser.get_body_stream().getvalue(), b"Hello mick") + + def testProxyGET(self): + data = ( + b"GET https://example.com:8080/foobar HTTP/8.4\r\n" + b"content-length: 6\r\n" + b"\r\n" + b"Hello." + ) + parser = self.parser + self.feed(data) + self.assertTrue(parser.completed) + self.assertEqual(parser.version, "8.4") + self.assertFalse(parser.empty) + self.assertEqual(parser.headers, {"CONTENT_LENGTH": "6"}) + self.assertEqual(parser.path, "/foobar") + self.assertEqual(parser.command, "GET") + self.assertEqual(parser.proxy_scheme, "https") + self.assertEqual(parser.proxy_netloc, "example.com:8080") + self.assertEqual(parser.command, "GET") + self.assertEqual(parser.query, "") + self.assertEqual(parser.get_body_stream().getvalue(), b"Hello.") + + def testDuplicateHeaders(self): + # Ensure that headers with the same key get concatenated as per + # RFC2616. + data = ( + b"GET /foobar HTTP/8.4\r\n" + b"x-forwarded-for: 10.11.12.13\r\n" + b"x-forwarded-for: unknown,127.0.0.1\r\n" + b"X-Forwarded_for: 255.255.255.255\r\n" + b"content-length: 6\r\n" + b"\r\n" + b"Hello." + ) + self.feed(data) + self.assertTrue(self.parser.completed) + self.assertEqual( + self.parser.headers, + { + "CONTENT_LENGTH": "6", + "X_FORWARDED_FOR": "10.11.12.13, unknown,127.0.0.1", + }, + ) + + def testSpoofedHeadersDropped(self): + data = ( + b"GET /foobar HTTP/8.4\r\n" + b"x-auth_user: bob\r\n" + b"content-length: 6\r\n" + b"\r\n" + b"Hello." + ) + self.feed(data) + self.assertTrue(self.parser.completed) + self.assertEqual(self.parser.headers, {"CONTENT_LENGTH": "6",}) + + +class DummyBodyStream(object): + def getfile(self): + return self + + def getbuf(self): + return self + + def close(self): + self.closed = True diff --git a/tests/test_proxy_headers.py b/tests/test_proxy_headers.py new file mode 100644 index 0000000..15b4a08 --- /dev/null +++ b/tests/test_proxy_headers.py @@ -0,0 +1,724 @@ +import unittest + +from waitress.compat import tobytes + + +class TestProxyHeadersMiddleware(unittest.TestCase): + def _makeOne(self, app, **kw): + from waitress.proxy_headers import proxy_headers_middleware + + return proxy_headers_middleware(app, **kw) + + def _callFUT(self, app, **kw): + response = DummyResponse() + environ = DummyEnviron(**kw) + + def start_response(status, response_headers): + response.status = status + response.headers = response_headers + + response.steps = list(app(environ, start_response)) + response.body = b"".join(tobytes(s) for s in response.steps) + return response + + def test_get_environment_values_w_scheme_override_untrusted(self): + inner = DummyApp() + app = self._makeOne(inner) + response = self._callFUT( + app, headers={"X_FOO": "BAR", "X_FORWARDED_PROTO": "https",} + ) + self.assertEqual(response.status, "200 OK") + self.assertEqual(inner.environ["wsgi.url_scheme"], "http") + + def test_get_environment_values_w_scheme_override_trusted(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="192.168.1.1", + trusted_proxy_headers={"x-forwarded-proto"}, + ) + response = self._callFUT( + app, + addr=["192.168.1.1", 8080], + headers={"X_FOO": "BAR", "X_FORWARDED_PROTO": "https",}, + ) + + environ = inner.environ + self.assertEqual(response.status, "200 OK") + self.assertEqual(environ["SERVER_PORT"], "443") + self.assertEqual(environ["SERVER_NAME"], "localhost") + self.assertEqual(environ["REMOTE_ADDR"], "192.168.1.1") + self.assertEqual(environ["HTTP_X_FOO"], "BAR") + + def test_get_environment_values_w_bogus_scheme_override(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="192.168.1.1", + trusted_proxy_headers={"x-forwarded-proto"}, + ) + response = self._callFUT( + app, + addr=["192.168.1.1", 80], + headers={ + "X_FOO": "BAR", + "X_FORWARDED_PROTO": "http://p02n3e.com?url=http", + }, + ) + self.assertEqual(response.status, "400 Bad Request") + self.assertIn(b'Header "X-Forwarded-Proto" malformed', response.body) + + def test_get_environment_warning_other_proxy_headers(self): + inner = DummyApp() + logger = DummyLogger() + app = self._makeOne( + inner, + trusted_proxy="192.168.1.1", + trusted_proxy_count=1, + trusted_proxy_headers={"forwarded"}, + log_untrusted=True, + logger=logger, + ) + response = self._callFUT( + app, + addr=["192.168.1.1", 80], + headers={ + "X_FORWARDED_FOR": "[2001:db8::1]", + "FORWARDED": "For=198.51.100.2;host=example.com:8080;proto=https", + }, + ) + self.assertEqual(response.status, "200 OK") + + self.assertEqual(len(logger.logged), 1) + + environ = inner.environ + self.assertNotIn("HTTP_X_FORWARDED_FOR", environ) + self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") + self.assertEqual(environ["SERVER_NAME"], "example.com") + self.assertEqual(environ["HTTP_HOST"], "example.com:8080") + self.assertEqual(environ["SERVER_PORT"], "8080") + self.assertEqual(environ["wsgi.url_scheme"], "https") + + def test_get_environment_contains_all_headers_including_untrusted(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="192.168.1.1", + trusted_proxy_count=1, + trusted_proxy_headers={"x-forwarded-by"}, + clear_untrusted=False, + ) + headers_orig = { + "X_FORWARDED_FOR": "198.51.100.2", + "X_FORWARDED_BY": "Waitress", + "X_FORWARDED_PROTO": "https", + "X_FORWARDED_HOST": "example.org", + } + response = self._callFUT( + app, addr=["192.168.1.1", 80], headers=headers_orig.copy(), + ) + self.assertEqual(response.status, "200 OK") + environ = inner.environ + for k, expected in headers_orig.items(): + result = environ["HTTP_%s" % k] + self.assertEqual(result, expected) + + def test_get_environment_contains_only_trusted_headers(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="192.168.1.1", + trusted_proxy_count=1, + trusted_proxy_headers={"x-forwarded-by"}, + clear_untrusted=True, + ) + response = self._callFUT( + app, + addr=["192.168.1.1", 80], + headers={ + "X_FORWARDED_FOR": "198.51.100.2", + "X_FORWARDED_BY": "Waitress", + "X_FORWARDED_PROTO": "https", + "X_FORWARDED_HOST": "example.org", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["HTTP_X_FORWARDED_BY"], "Waitress") + self.assertNotIn("HTTP_X_FORWARDED_FOR", environ) + self.assertNotIn("HTTP_X_FORWARDED_PROTO", environ) + self.assertNotIn("HTTP_X_FORWARDED_HOST", environ) + + def test_get_environment_clears_headers_if_untrusted_proxy(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="192.168.1.1", + trusted_proxy_count=1, + trusted_proxy_headers={"x-forwarded-by"}, + clear_untrusted=True, + ) + response = self._callFUT( + app, + addr=["192.168.1.255", 80], + headers={ + "X_FORWARDED_FOR": "198.51.100.2", + "X_FORWARDED_BY": "Waitress", + "X_FORWARDED_PROTO": "https", + "X_FORWARDED_HOST": "example.org", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertNotIn("HTTP_X_FORWARDED_BY", environ) + self.assertNotIn("HTTP_X_FORWARDED_FOR", environ) + self.assertNotIn("HTTP_X_FORWARDED_PROTO", environ) + self.assertNotIn("HTTP_X_FORWARDED_HOST", environ) + + def test_parse_proxy_headers_forwarded_for(self): + inner = DummyApp() + app = self._makeOne( + inner, trusted_proxy="*", trusted_proxy_headers={"x-forwarded-for"}, + ) + response = self._callFUT(app, headers={"X_FORWARDED_FOR": "192.0.2.1"}) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "192.0.2.1") + + def test_parse_proxy_headers_forwarded_for_v6_missing_brackets(self): + inner = DummyApp() + app = self._makeOne( + inner, trusted_proxy="*", trusted_proxy_headers={"x-forwarded-for"}, + ) + response = self._callFUT(app, headers={"X_FORWARDED_FOR": "2001:db8::0"}) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "2001:db8::0") + + def test_parse_proxy_headers_forwared_for_multiple(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=2, + trusted_proxy_headers={"x-forwarded-for"}, + ) + response = self._callFUT( + app, headers={"X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1"} + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") + + def test_parse_forwarded_multiple_proxies_trust_only_two(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=2, + trusted_proxy_headers={"forwarded"}, + ) + response = self._callFUT( + app, + headers={ + "FORWARDED": ( + "For=192.0.2.1;host=fake.com, " + "For=198.51.100.2;host=example.com:8080, " + "For=203.0.113.1" + ), + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") + self.assertEqual(environ["SERVER_NAME"], "example.com") + self.assertEqual(environ["HTTP_HOST"], "example.com:8080") + self.assertEqual(environ["SERVER_PORT"], "8080") + + def test_parse_forwarded_multiple_proxies(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=2, + trusted_proxy_headers={"forwarded"}, + ) + response = self._callFUT( + app, + headers={ + "FORWARDED": ( + 'for="[2001:db8::1]:3821";host="example.com:8443";proto="https", ' + 'for=192.0.2.1;host="example.internal:8080"' + ), + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "2001:db8::1") + self.assertEqual(environ["REMOTE_PORT"], "3821") + self.assertEqual(environ["SERVER_NAME"], "example.com") + self.assertEqual(environ["HTTP_HOST"], "example.com:8443") + self.assertEqual(environ["SERVER_PORT"], "8443") + self.assertEqual(environ["wsgi.url_scheme"], "https") + + def test_parse_forwarded_multiple_proxies_minimal(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=2, + trusted_proxy_headers={"forwarded"}, + ) + response = self._callFUT( + app, + headers={ + "FORWARDED": ( + 'for="[2001:db8::1]";proto="https", ' + 'for=192.0.2.1;host="example.org"' + ), + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "2001:db8::1") + self.assertEqual(environ["SERVER_NAME"], "example.org") + self.assertEqual(environ["HTTP_HOST"], "example.org") + self.assertEqual(environ["SERVER_PORT"], "443") + self.assertEqual(environ["wsgi.url_scheme"], "https") + + def test_parse_proxy_headers_forwarded_host_with_port(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=2, + trusted_proxy_headers={ + "x-forwarded-for", + "x-forwarded-proto", + "x-forwarded-host", + }, + ) + response = self._callFUT( + app, + headers={ + "X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1", + "X_FORWARDED_PROTO": "http", + "X_FORWARDED_HOST": "example.com:8080", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") + self.assertEqual(environ["SERVER_NAME"], "example.com") + self.assertEqual(environ["HTTP_HOST"], "example.com:8080") + self.assertEqual(environ["SERVER_PORT"], "8080") + + def test_parse_proxy_headers_forwarded_host_without_port(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=2, + trusted_proxy_headers={ + "x-forwarded-for", + "x-forwarded-proto", + "x-forwarded-host", + }, + ) + response = self._callFUT( + app, + headers={ + "X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1", + "X_FORWARDED_PROTO": "http", + "X_FORWARDED_HOST": "example.com", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") + self.assertEqual(environ["SERVER_NAME"], "example.com") + self.assertEqual(environ["HTTP_HOST"], "example.com") + self.assertEqual(environ["SERVER_PORT"], "80") + + def test_parse_proxy_headers_forwarded_host_with_forwarded_port(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=2, + trusted_proxy_headers={ + "x-forwarded-for", + "x-forwarded-proto", + "x-forwarded-host", + "x-forwarded-port", + }, + ) + response = self._callFUT( + app, + headers={ + "X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1", + "X_FORWARDED_PROTO": "http", + "X_FORWARDED_HOST": "example.com", + "X_FORWARDED_PORT": "8080", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") + self.assertEqual(environ["SERVER_NAME"], "example.com") + self.assertEqual(environ["HTTP_HOST"], "example.com:8080") + self.assertEqual(environ["SERVER_PORT"], "8080") + + def test_parse_proxy_headers_forwarded_host_multiple_with_forwarded_port(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=2, + trusted_proxy_headers={ + "x-forwarded-for", + "x-forwarded-proto", + "x-forwarded-host", + "x-forwarded-port", + }, + ) + response = self._callFUT( + app, + headers={ + "X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1", + "X_FORWARDED_PROTO": "http", + "X_FORWARDED_HOST": "example.com, example.org", + "X_FORWARDED_PORT": "8080", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") + self.assertEqual(environ["SERVER_NAME"], "example.com") + self.assertEqual(environ["HTTP_HOST"], "example.com:8080") + self.assertEqual(environ["SERVER_PORT"], "8080") + + def test_parse_proxy_headers_forwarded_host_multiple_with_forwarded_port_limit_one_trusted( + self, + ): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={ + "x-forwarded-for", + "x-forwarded-proto", + "x-forwarded-host", + "x-forwarded-port", + }, + ) + response = self._callFUT( + app, + headers={ + "X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1", + "X_FORWARDED_PROTO": "http", + "X_FORWARDED_HOST": "example.com, example.org", + "X_FORWARDED_PORT": "8080", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "203.0.113.1") + self.assertEqual(environ["SERVER_NAME"], "example.org") + self.assertEqual(environ["HTTP_HOST"], "example.org:8080") + self.assertEqual(environ["SERVER_PORT"], "8080") + + def test_parse_forwarded(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"forwarded"}, + ) + response = self._callFUT( + app, + headers={ + "FORWARDED": "For=198.51.100.2:5858;host=example.com:8080;proto=https", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") + self.assertEqual(environ["REMOTE_PORT"], "5858") + self.assertEqual(environ["SERVER_NAME"], "example.com") + self.assertEqual(environ["HTTP_HOST"], "example.com:8080") + self.assertEqual(environ["SERVER_PORT"], "8080") + self.assertEqual(environ["wsgi.url_scheme"], "https") + + def test_parse_forwarded_empty_pair(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"forwarded"}, + ) + response = self._callFUT( + app, headers={"FORWARDED": "For=198.51.100.2;;proto=https;by=_unused",} + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") + + def test_parse_forwarded_pair_token_whitespace(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"forwarded"}, + ) + response = self._callFUT( + app, headers={"FORWARDED": "For=198.51.100.2; proto =https",} + ) + self.assertEqual(response.status, "400 Bad Request") + self.assertIn(b'Header "Forwarded" malformed', response.body) + + def test_parse_forwarded_pair_value_whitespace(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"forwarded"}, + ) + response = self._callFUT( + app, headers={"FORWARDED": 'For= "198.51.100.2"; proto =https',} + ) + self.assertEqual(response.status, "400 Bad Request") + self.assertIn(b'Header "Forwarded" malformed', response.body) + + def test_parse_forwarded_pair_no_equals(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"forwarded"}, + ) + response = self._callFUT(app, headers={"FORWARDED": "For"}) + self.assertEqual(response.status, "400 Bad Request") + self.assertIn(b'Header "Forwarded" malformed', response.body) + + def test_parse_forwarded_warning_unknown_token(self): + inner = DummyApp() + logger = DummyLogger() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"forwarded"}, + logger=logger, + ) + response = self._callFUT( + app, + headers={ + "FORWARDED": ( + "For=198.51.100.2;host=example.com:8080;proto=https;" + 'unknown="yolo"' + ), + }, + ) + self.assertEqual(response.status, "200 OK") + + self.assertEqual(len(logger.logged), 1) + self.assertIn("Unknown Forwarded token", logger.logged[0]) + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") + self.assertEqual(environ["SERVER_NAME"], "example.com") + self.assertEqual(environ["HTTP_HOST"], "example.com:8080") + self.assertEqual(environ["SERVER_PORT"], "8080") + self.assertEqual(environ["wsgi.url_scheme"], "https") + + def test_parse_no_valid_proxy_headers(self): + inner = DummyApp() + app = self._makeOne(inner, trusted_proxy="*", trusted_proxy_count=1,) + response = self._callFUT( + app, + headers={ + "X_FORWARDED_FOR": "198.51.100.2", + "FORWARDED": "For=198.51.100.2;host=example.com:8080;proto=https", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "127.0.0.1") + self.assertEqual(environ["SERVER_NAME"], "localhost") + self.assertEqual(environ["HTTP_HOST"], "192.168.1.1:80") + self.assertEqual(environ["SERVER_PORT"], "8080") + self.assertEqual(environ["wsgi.url_scheme"], "http") + + def test_parse_multiple_x_forwarded_proto(self): + inner = DummyApp() + logger = DummyLogger() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"x-forwarded-proto"}, + logger=logger, + ) + response = self._callFUT(app, headers={"X_FORWARDED_PROTO": "http, https",}) + self.assertEqual(response.status, "400 Bad Request") + self.assertIn(b'Header "X-Forwarded-Proto" malformed', response.body) + + def test_parse_multiple_x_forwarded_port(self): + inner = DummyApp() + logger = DummyLogger() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"x-forwarded-port"}, + logger=logger, + ) + response = self._callFUT(app, headers={"X_FORWARDED_PORT": "443, 80",}) + self.assertEqual(response.status, "400 Bad Request") + self.assertIn(b'Header "X-Forwarded-Port" malformed', response.body) + + def test_parse_forwarded_port_wrong_proto_port_80(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={ + "x-forwarded-port", + "x-forwarded-host", + "x-forwarded-proto", + }, + ) + response = self._callFUT( + app, + headers={ + "X_FORWARDED_PORT": "80", + "X_FORWARDED_PROTO": "https", + "X_FORWARDED_HOST": "example.com", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["SERVER_NAME"], "example.com") + self.assertEqual(environ["HTTP_HOST"], "example.com:80") + self.assertEqual(environ["SERVER_PORT"], "80") + self.assertEqual(environ["wsgi.url_scheme"], "https") + + def test_parse_forwarded_port_wrong_proto_port_443(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={ + "x-forwarded-port", + "x-forwarded-host", + "x-forwarded-proto", + }, + ) + response = self._callFUT( + app, + headers={ + "X_FORWARDED_PORT": "443", + "X_FORWARDED_PROTO": "http", + "X_FORWARDED_HOST": "example.com", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["SERVER_NAME"], "example.com") + self.assertEqual(environ["HTTP_HOST"], "example.com:443") + self.assertEqual(environ["SERVER_PORT"], "443") + self.assertEqual(environ["wsgi.url_scheme"], "http") + + def test_parse_forwarded_for_bad_quote(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"x-forwarded-for"}, + ) + response = self._callFUT(app, headers={"X_FORWARDED_FOR": '"foo'}) + self.assertEqual(response.status, "400 Bad Request") + self.assertIn(b'Header "X-Forwarded-For" malformed', response.body) + + def test_parse_forwarded_host_bad_quote(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"x-forwarded-host"}, + ) + response = self._callFUT(app, headers={"X_FORWARDED_HOST": '"foo'}) + self.assertEqual(response.status, "400 Bad Request") + self.assertIn(b'Header "X-Forwarded-Host" malformed', response.body) + + +class DummyLogger(object): + def __init__(self): + self.logged = [] + + def warning(self, msg, *args): + self.logged.append(msg % args) + + +class DummyApp(object): + def __call__(self, environ, start_response): + self.environ = environ + start_response("200 OK", [("Content-Type", "text/plain")]) + yield "hello" + + +class DummyResponse(object): + status = None + headers = None + body = None + + +def DummyEnviron( + addr=("127.0.0.1", 8080), scheme="http", server="localhost", headers=None, +): + environ = { + "REMOTE_ADDR": addr[0], + "REMOTE_HOST": addr[0], + "REMOTE_PORT": addr[1], + "SERVER_PORT": str(addr[1]), + "SERVER_NAME": server, + "wsgi.url_scheme": scheme, + "HTTP_HOST": "192.168.1.1:80", + } + if headers: + environ.update( + { + "HTTP_" + key.upper().replace("-", "_"): value + for key, value in headers.items() + } + ) + return environ diff --git a/tests/test_receiver.py b/tests/test_receiver.py new file mode 100644 index 0000000..b4910bb --- /dev/null +++ b/tests/test_receiver.py @@ -0,0 +1,242 @@ +import unittest + + +class TestFixedStreamReceiver(unittest.TestCase): + def _makeOne(self, cl, buf): + from waitress.receiver import FixedStreamReceiver + + return FixedStreamReceiver(cl, buf) + + def test_received_remain_lt_1(self): + buf = DummyBuffer() + inst = self._makeOne(0, buf) + result = inst.received("a") + self.assertEqual(result, 0) + self.assertEqual(inst.completed, True) + + def test_received_remain_lte_datalen(self): + buf = DummyBuffer() + inst = self._makeOne(1, buf) + result = inst.received("aa") + self.assertEqual(result, 1) + self.assertEqual(inst.completed, True) + self.assertEqual(inst.completed, 1) + self.assertEqual(inst.remain, 0) + self.assertEqual(buf.data, ["a"]) + + def test_received_remain_gt_datalen(self): + buf = DummyBuffer() + inst = self._makeOne(10, buf) + result = inst.received("aa") + self.assertEqual(result, 2) + self.assertEqual(inst.completed, False) + self.assertEqual(inst.remain, 8) + self.assertEqual(buf.data, ["aa"]) + + def test_getfile(self): + buf = DummyBuffer() + inst = self._makeOne(10, buf) + self.assertEqual(inst.getfile(), buf) + + def test_getbuf(self): + buf = DummyBuffer() + inst = self._makeOne(10, buf) + self.assertEqual(inst.getbuf(), buf) + + def test___len__(self): + buf = DummyBuffer(["1", "2"]) + inst = self._makeOne(10, buf) + self.assertEqual(inst.__len__(), 2) + + +class TestChunkedReceiver(unittest.TestCase): + def _makeOne(self, buf): + from waitress.receiver import ChunkedReceiver + + return ChunkedReceiver(buf) + + def test_alreadycompleted(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + inst.completed = True + result = inst.received(b"a") + self.assertEqual(result, 0) + self.assertEqual(inst.completed, True) + + def test_received_remain_gt_zero(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + inst.chunk_remainder = 100 + result = inst.received(b"a") + self.assertEqual(inst.chunk_remainder, 99) + self.assertEqual(result, 1) + self.assertEqual(inst.completed, False) + + def test_received_control_line_notfinished(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + result = inst.received(b"a") + self.assertEqual(inst.control_line, b"a") + self.assertEqual(result, 1) + self.assertEqual(inst.completed, False) + + def test_received_control_line_finished_garbage_in_input(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + result = inst.received(b"garbage\r\n") + self.assertEqual(result, 9) + self.assertTrue(inst.error) + + def test_received_control_line_finished_all_chunks_not_received(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + result = inst.received(b"a;discard\r\n") + self.assertEqual(inst.control_line, b"") + self.assertEqual(inst.chunk_remainder, 10) + self.assertEqual(inst.all_chunks_received, False) + self.assertEqual(result, 11) + self.assertEqual(inst.completed, False) + + def test_received_control_line_finished_all_chunks_received(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + result = inst.received(b"0;discard\r\n") + self.assertEqual(inst.control_line, b"") + self.assertEqual(inst.all_chunks_received, True) + self.assertEqual(result, 11) + self.assertEqual(inst.completed, False) + + def test_received_trailer_startswith_crlf(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + inst.all_chunks_received = True + result = inst.received(b"\r\n") + self.assertEqual(result, 2) + self.assertEqual(inst.completed, True) + + def test_received_trailer_startswith_lf(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + inst.all_chunks_received = True + result = inst.received(b"\n") + self.assertEqual(result, 1) + self.assertEqual(inst.completed, False) + + def test_received_trailer_not_finished(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + inst.all_chunks_received = True + result = inst.received(b"a") + self.assertEqual(result, 1) + self.assertEqual(inst.completed, False) + + def test_received_trailer_finished(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + inst.all_chunks_received = True + result = inst.received(b"abc\r\n\r\n") + self.assertEqual(inst.trailer, b"abc\r\n\r\n") + self.assertEqual(result, 7) + self.assertEqual(inst.completed, True) + + def test_getfile(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + self.assertEqual(inst.getfile(), buf) + + def test_getbuf(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + self.assertEqual(inst.getbuf(), buf) + + def test___len__(self): + buf = DummyBuffer(["1", "2"]) + inst = self._makeOne(buf) + self.assertEqual(inst.__len__(), 2) + + def test_received_chunk_is_properly_terminated(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + data = b"4\r\nWiki\r\n" + result = inst.received(data) + self.assertEqual(result, len(data)) + self.assertEqual(inst.completed, False) + self.assertEqual(buf.data[0], b"Wiki") + + def test_received_chunk_not_properly_terminated(self): + from waitress.utilities import BadRequest + + buf = DummyBuffer() + inst = self._makeOne(buf) + data = b"4\r\nWikibadchunk\r\n" + result = inst.received(data) + self.assertEqual(result, len(data)) + self.assertEqual(inst.completed, False) + self.assertEqual(buf.data[0], b"Wiki") + self.assertEqual(inst.error.__class__, BadRequest) + + def test_received_multiple_chunks(self): + from waitress.utilities import BadRequest + + buf = DummyBuffer() + inst = self._makeOne(buf) + data = ( + b"4\r\n" + b"Wiki\r\n" + b"5\r\n" + b"pedia\r\n" + b"E\r\n" + b" in\r\n" + b"\r\n" + b"chunks.\r\n" + b"0\r\n" + b"\r\n" + ) + result = inst.received(data) + self.assertEqual(result, len(data)) + self.assertEqual(inst.completed, True) + self.assertEqual(b"".join(buf.data), b"Wikipedia in\r\n\r\nchunks.") + self.assertEqual(inst.error, None) + + def test_received_multiple_chunks_split(self): + from waitress.utilities import BadRequest + + buf = DummyBuffer() + inst = self._makeOne(buf) + data1 = b"4\r\nWiki\r" + result = inst.received(data1) + self.assertEqual(result, len(data1)) + + data2 = ( + b"\n5\r\n" + b"pedia\r\n" + b"E\r\n" + b" in\r\n" + b"\r\n" + b"chunks.\r\n" + b"0\r\n" + b"\r\n" + ) + + result = inst.received(data2) + self.assertEqual(result, len(data2)) + + self.assertEqual(inst.completed, True) + self.assertEqual(b"".join(buf.data), b"Wikipedia in\r\n\r\nchunks.") + self.assertEqual(inst.error, None) + + +class DummyBuffer(object): + def __init__(self, data=None): + if data is None: + data = [] + self.data = data + + def append(self, s): + self.data.append(s) + + def getfile(self): + return self + + def __len__(self): + return len(self.data) diff --git a/tests/test_regression.py b/tests/test_regression.py new file mode 100644 index 0000000..3c4c6c2 --- /dev/null +++ b/tests/test_regression.py @@ -0,0 +1,147 @@ +############################################################################## +# +# Copyright (c) 2005 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## +"""Tests for waitress.channel maintenance logic +""" +import doctest + + +class FakeSocket: # pragma: no cover + data = "" + setblocking = lambda *_: None + close = lambda *_: None + + def __init__(self, no): + self.no = no + + def fileno(self): + return self.no + + def getpeername(self): + return ("localhost", self.no) + + def send(self, data): + self.data += data + return len(data) + + def recv(self, data): + return "data" + + +def zombies_test(): + """Regression test for HTTPChannel.maintenance method + + Bug: This method checks for channels that have been "inactive" for a + configured time. The bug was that last_activity is set at creation time + but never updated during async channel activity (reads and writes), so + any channel older than the configured timeout will be closed when a new + channel is created, regardless of activity. + + >>> import time + >>> import waitress.adjustments + >>> config = waitress.adjustments.Adjustments() + + >>> from waitress.server import HTTPServer + >>> class TestServer(HTTPServer): + ... def bind(self, (ip, port)): + ... print "Listening on %s:%d" % (ip or '*', port) + >>> sb = TestServer('127.0.0.1', 80, start=False, verbose=True) + Listening on 127.0.0.1:80 + + First we confirm the correct behavior, where a channel with no activity + for the timeout duration gets closed. + + >>> from waitress.channel import HTTPChannel + >>> socket = FakeSocket(42) + >>> channel = HTTPChannel(sb, socket, ('localhost', 42)) + + >>> channel.connected + True + + >>> channel.last_activity -= int(config.channel_timeout) + 1 + + >>> channel.next_channel_cleanup[0] = channel.creation_time - int( + ... config.cleanup_interval) - 1 + + >>> socket2 = FakeSocket(7) + >>> channel2 = HTTPChannel(sb, socket2, ('localhost', 7)) + + >>> channel.connected + False + + Write Activity + -------------- + + Now we make sure that if there is activity the channel doesn't get closed + incorrectly. + + >>> channel2.connected + True + + >>> channel2.last_activity -= int(config.channel_timeout) + 1 + + >>> channel2.handle_write() + + >>> channel2.next_channel_cleanup[0] = channel2.creation_time - int( + ... config.cleanup_interval) - 1 + + >>> socket3 = FakeSocket(3) + >>> channel3 = HTTPChannel(sb, socket3, ('localhost', 3)) + + >>> channel2.connected + True + + Read Activity + -------------- + + We should test to see that read activity will update a channel as well. + + >>> channel3.connected + True + + >>> channel3.last_activity -= int(config.channel_timeout) + 1 + + >>> import waitress.parser + >>> channel3.parser_class = ( + ... waitress.parser.HTTPRequestParser) + >>> channel3.handle_read() + + >>> channel3.next_channel_cleanup[0] = channel3.creation_time - int( + ... config.cleanup_interval) - 1 + + >>> socket4 = FakeSocket(4) + >>> channel4 = HTTPChannel(sb, socket4, ('localhost', 4)) + + >>> channel3.connected + True + + Main loop window + ---------------- + + There is also a corner case we'll do a shallow test for where a + channel can be closed waiting for the main loop. + + >>> channel4.last_activity -= 1 + + >>> last_active = channel4.last_activity + + >>> channel4.set_async() + + >>> channel4.last_activity != last_active + True + +""" + + +def test_suite(): + return doctest.DocTestSuite() diff --git a/tests/test_runner.py b/tests/test_runner.py new file mode 100644 index 0000000..e53018b --- /dev/null +++ b/tests/test_runner.py @@ -0,0 +1,191 @@ +import contextlib +import os +import sys + +if sys.version_info[:2] == (2, 6): # pragma: no cover + import unittest2 as unittest +else: # pragma: no cover + import unittest + +from waitress import runner + + +class Test_match(unittest.TestCase): + def test_empty(self): + self.assertRaisesRegexp( + ValueError, "^Malformed application ''$", runner.match, "" + ) + + def test_module_only(self): + self.assertRaisesRegexp( + ValueError, r"^Malformed application 'foo\.bar'$", runner.match, "foo.bar" + ) + + def test_bad_module(self): + self.assertRaisesRegexp( + ValueError, + r"^Malformed application 'foo#bar:barney'$", + runner.match, + "foo#bar:barney", + ) + + def test_module_obj(self): + self.assertTupleEqual( + runner.match("foo.bar:fred.barney"), ("foo.bar", "fred.barney") + ) + + +class Test_resolve(unittest.TestCase): + def test_bad_module(self): + self.assertRaises( + ImportError, runner.resolve, "nonexistent", "nonexistent_function" + ) + + def test_nonexistent_function(self): + self.assertRaisesRegexp( + AttributeError, + r"has no attribute 'nonexistent_function'", + runner.resolve, + "os.path", + "nonexistent_function", + ) + + def test_simple_happy_path(self): + from os.path import exists + + self.assertIs(runner.resolve("os.path", "exists"), exists) + + def test_complex_happy_path(self): + # Ensure we can recursively resolve object attributes if necessary. + self.assertEquals(runner.resolve("os.path", "exists.__name__"), "exists") + + +class Test_run(unittest.TestCase): + def match_output(self, argv, code, regex): + argv = ["waitress-serve"] + argv + with capture() as captured: + self.assertEqual(runner.run(argv=argv), code) + self.assertRegexpMatches(captured.getvalue(), regex) + captured.close() + + def test_bad(self): + self.match_output(["--bad-opt"], 1, "^Error: option --bad-opt not recognized") + + def test_help(self): + self.match_output(["--help"], 0, "^Usage:\n\n waitress-serve") + + def test_no_app(self): + self.match_output([], 1, "^Error: Specify one application only") + + def test_multiple_apps_app(self): + self.match_output(["a:a", "b:b"], 1, "^Error: Specify one application only") + + def test_bad_apps_app(self): + self.match_output(["a"], 1, "^Error: Malformed application 'a'") + + def test_bad_app_module(self): + self.match_output(["nonexistent:a"], 1, "^Error: Bad module 'nonexistent'") + + self.match_output( + ["nonexistent:a"], + 1, + ( + r"There was an exception \((ImportError|ModuleNotFoundError)\) " + "importing your module.\n\nIt had these arguments: \n" + "1. No module named '?nonexistent'?" + ), + ) + + def test_cwd_added_to_path(self): + def null_serve(app, **kw): + pass + + sys_path = sys.path + current_dir = os.getcwd() + try: + os.chdir(os.path.dirname(__file__)) + argv = [ + "waitress-serve", + "fixtureapps.runner:app", + ] + self.assertEqual(runner.run(argv=argv, _serve=null_serve), 0) + finally: + sys.path = sys_path + os.chdir(current_dir) + + def test_bad_app_object(self): + self.match_output( + ["tests.fixtureapps.runner:a"], 1, "^Error: Bad object name 'a'" + ) + + def test_simple_call(self): + import tests.fixtureapps.runner as _apps + + def check_server(app, **kw): + self.assertIs(app, _apps.app) + self.assertDictEqual(kw, {"port": "80"}) + + argv = [ + "waitress-serve", + "--port=80", + "tests.fixtureapps.runner:app", + ] + self.assertEqual(runner.run(argv=argv, _serve=check_server), 0) + + def test_returned_app(self): + import tests.fixtureapps.runner as _apps + + def check_server(app, **kw): + self.assertIs(app, _apps.app) + self.assertDictEqual(kw, {"port": "80"}) + + argv = [ + "waitress-serve", + "--port=80", + "--call", + "tests.fixtureapps.runner:returns_app", + ] + self.assertEqual(runner.run(argv=argv, _serve=check_server), 0) + + +class Test_helper(unittest.TestCase): + def test_exception_logging(self): + from waitress.runner import show_exception + + regex = ( + r"There was an exception \(ImportError\) importing your module." + r"\n\nIt had these arguments: \n1. My reason" + ) + + with capture() as captured: + try: + raise ImportError("My reason") + except ImportError: + self.assertEqual(show_exception(sys.stderr), None) + self.assertRegexpMatches(captured.getvalue(), regex) + captured.close() + + regex = ( + r"There was an exception \(ImportError\) importing your module." + r"\n\nIt had no arguments." + ) + + with capture() as captured: + try: + raise ImportError + except ImportError: + self.assertEqual(show_exception(sys.stderr), None) + self.assertRegexpMatches(captured.getvalue(), regex) + captured.close() + + +@contextlib.contextmanager +def capture(): + from waitress.compat import NativeIO + + fd = NativeIO() + sys.stdout = fd + sys.stderr = fd + yield fd + sys.stdout = sys.__stdout__ + sys.stderr = sys.__stderr__ diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 0000000..9134fb8 --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,533 @@ +import errno +import socket +import unittest + +dummy_app = object() + + +class TestWSGIServer(unittest.TestCase): + def _makeOne( + self, + application=dummy_app, + host="127.0.0.1", + port=0, + _dispatcher=None, + adj=None, + map=None, + _start=True, + _sock=None, + _server=None, + ): + from waitress.server import create_server + + self.inst = create_server( + application, + host=host, + port=port, + map=map, + _dispatcher=_dispatcher, + _start=_start, + _sock=_sock, + ) + return self.inst + + def _makeOneWithMap( + self, adj=None, _start=True, host="127.0.0.1", port=0, app=dummy_app + ): + sock = DummySock() + task_dispatcher = DummyTaskDispatcher() + map = {} + return self._makeOne( + app, + host=host, + port=port, + map=map, + _sock=sock, + _dispatcher=task_dispatcher, + _start=_start, + ) + + def _makeOneWithMulti( + self, adj=None, _start=True, app=dummy_app, listen="127.0.0.1:0 127.0.0.1:0" + ): + sock = DummySock() + task_dispatcher = DummyTaskDispatcher() + map = {} + from waitress.server import create_server + + self.inst = create_server( + app, + listen=listen, + map=map, + _dispatcher=task_dispatcher, + _start=_start, + _sock=sock, + ) + return self.inst + + def _makeWithSockets( + self, + application=dummy_app, + _dispatcher=None, + map=None, + _start=True, + _sock=None, + _server=None, + sockets=None, + ): + from waitress.server import create_server + + _sockets = [] + if sockets is not None: + _sockets = sockets + self.inst = create_server( + application, + map=map, + _dispatcher=_dispatcher, + _start=_start, + _sock=_sock, + sockets=_sockets, + ) + return self.inst + + def tearDown(self): + if self.inst is not None: + self.inst.close() + + def test_ctor_app_is_None(self): + self.inst = None + self.assertRaises(ValueError, self._makeOneWithMap, app=None) + + def test_ctor_start_true(self): + inst = self._makeOneWithMap(_start=True) + self.assertEqual(inst.accepting, True) + self.assertEqual(inst.socket.listened, 1024) + + def test_ctor_makes_dispatcher(self): + inst = self._makeOne(_start=False, map={}) + self.assertEqual( + inst.task_dispatcher.__class__.__name__, "ThreadedTaskDispatcher" + ) + + def test_ctor_start_false(self): + inst = self._makeOneWithMap(_start=False) + self.assertEqual(inst.accepting, False) + + def test_get_server_name_empty(self): + inst = self._makeOneWithMap(_start=False) + self.assertRaises(ValueError, inst.get_server_name, "") + + def test_get_server_name_with_ip(self): + inst = self._makeOneWithMap(_start=False) + result = inst.get_server_name("127.0.0.1") + self.assertTrue(result) + + def test_get_server_name_with_hostname(self): + inst = self._makeOneWithMap(_start=False) + result = inst.get_server_name("fred.flintstone.com") + self.assertEqual(result, "fred.flintstone.com") + + def test_get_server_name_0000(self): + inst = self._makeOneWithMap(_start=False) + result = inst.get_server_name("0.0.0.0") + self.assertTrue(len(result) != 0) + + def test_get_server_name_double_colon(self): + inst = self._makeOneWithMap(_start=False) + result = inst.get_server_name("::") + self.assertTrue(len(result) != 0) + + def test_get_server_name_ipv6(self): + inst = self._makeOneWithMap(_start=False) + result = inst.get_server_name("2001:DB8::ffff") + self.assertEqual("[2001:DB8::ffff]", result) + + def test_get_server_multi(self): + inst = self._makeOneWithMulti() + self.assertEqual(inst.__class__.__name__, "MultiSocketServer") + + def test_run(self): + inst = self._makeOneWithMap(_start=False) + inst.asyncore = DummyAsyncore() + inst.task_dispatcher = DummyTaskDispatcher() + inst.run() + self.assertTrue(inst.task_dispatcher.was_shutdown) + + def test_run_base_server(self): + inst = self._makeOneWithMulti(_start=False) + inst.asyncore = DummyAsyncore() + inst.task_dispatcher = DummyTaskDispatcher() + inst.run() + self.assertTrue(inst.task_dispatcher.was_shutdown) + + def test_pull_trigger(self): + inst = self._makeOneWithMap(_start=False) + inst.trigger.close() + inst.trigger = DummyTrigger() + inst.pull_trigger() + self.assertEqual(inst.trigger.pulled, True) + + def test_add_task(self): + task = DummyTask() + inst = self._makeOneWithMap() + inst.add_task(task) + self.assertEqual(inst.task_dispatcher.tasks, [task]) + self.assertFalse(task.serviced) + + def test_readable_not_accepting(self): + inst = self._makeOneWithMap() + inst.accepting = False + self.assertFalse(inst.readable()) + + def test_readable_maplen_gt_connection_limit(self): + inst = self._makeOneWithMap() + inst.accepting = True + inst.adj = DummyAdj + inst._map = {"a": 1, "b": 2} + self.assertFalse(inst.readable()) + + def test_readable_maplen_lt_connection_limit(self): + inst = self._makeOneWithMap() + inst.accepting = True + inst.adj = DummyAdj + inst._map = {} + self.assertTrue(inst.readable()) + + def test_readable_maintenance_false(self): + import time + + inst = self._makeOneWithMap() + then = time.time() + 1000 + inst.next_channel_cleanup = then + L = [] + inst.maintenance = lambda t: L.append(t) + inst.readable() + self.assertEqual(L, []) + self.assertEqual(inst.next_channel_cleanup, then) + + def test_readable_maintenance_true(self): + inst = self._makeOneWithMap() + inst.next_channel_cleanup = 0 + L = [] + inst.maintenance = lambda t: L.append(t) + inst.readable() + self.assertEqual(len(L), 1) + self.assertNotEqual(inst.next_channel_cleanup, 0) + + def test_writable(self): + inst = self._makeOneWithMap() + self.assertFalse(inst.writable()) + + def test_handle_read(self): + inst = self._makeOneWithMap() + self.assertEqual(inst.handle_read(), None) + + def test_handle_connect(self): + inst = self._makeOneWithMap() + self.assertEqual(inst.handle_connect(), None) + + def test_handle_accept_wouldblock_socket_error(self): + inst = self._makeOneWithMap() + ewouldblock = socket.error(errno.EWOULDBLOCK) + inst.socket = DummySock(toraise=ewouldblock) + inst.handle_accept() + self.assertEqual(inst.socket.accepted, False) + + def test_handle_accept_other_socket_error(self): + inst = self._makeOneWithMap() + eaborted = socket.error(errno.ECONNABORTED) + inst.socket = DummySock(toraise=eaborted) + inst.adj = DummyAdj + + def foo(): + raise socket.error + + inst.accept = foo + inst.logger = DummyLogger() + inst.handle_accept() + self.assertEqual(inst.socket.accepted, False) + self.assertEqual(len(inst.logger.logged), 1) + + def test_handle_accept_noerror(self): + inst = self._makeOneWithMap() + innersock = DummySock() + inst.socket = DummySock(acceptresult=(innersock, None)) + inst.adj = DummyAdj + L = [] + inst.channel_class = lambda *arg, **kw: L.append(arg) + inst.handle_accept() + self.assertEqual(inst.socket.accepted, True) + self.assertEqual(innersock.opts, [("level", "optname", "value")]) + self.assertEqual(L, [(inst, innersock, None, inst.adj)]) + + def test_maintenance(self): + inst = self._makeOneWithMap() + + class DummyChannel(object): + requests = [] + + zombie = DummyChannel() + zombie.last_activity = 0 + zombie.running_tasks = False + inst.active_channels[100] = zombie + inst.maintenance(10000) + self.assertEqual(zombie.will_close, True) + + def test_backward_compatibility(self): + from waitress.server import WSGIServer, TcpWSGIServer + from waitress.adjustments import Adjustments + + self.assertTrue(WSGIServer is TcpWSGIServer) + self.inst = WSGIServer(None, _start=False, port=1234) + # Ensure the adjustment was actually applied. + self.assertNotEqual(Adjustments.port, 1234) + self.assertEqual(self.inst.adj.port, 1234) + + def test_create_with_one_tcp_socket(self): + from waitress.server import TcpWSGIServer + + sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] + sockets[0].bind(("127.0.0.1", 0)) + inst = self._makeWithSockets(_start=False, sockets=sockets) + self.assertTrue(isinstance(inst, TcpWSGIServer)) + + def test_create_with_multiple_tcp_sockets(self): + from waitress.server import MultiSocketServer + + sockets = [ + socket.socket(socket.AF_INET, socket.SOCK_STREAM), + socket.socket(socket.AF_INET, socket.SOCK_STREAM), + ] + sockets[0].bind(("127.0.0.1", 0)) + sockets[1].bind(("127.0.0.1", 0)) + inst = self._makeWithSockets(_start=False, sockets=sockets) + self.assertTrue(isinstance(inst, MultiSocketServer)) + self.assertEqual(len(inst.effective_listen), 2) + + def test_create_with_one_socket_should_not_bind_socket(self): + innersock = DummySock() + sockets = [DummySock(acceptresult=(innersock, None))] + sockets[0].bind(("127.0.0.1", 80)) + sockets[0].bind_called = False + inst = self._makeWithSockets(_start=False, sockets=sockets) + self.assertEqual(inst.socket.bound, ("127.0.0.1", 80)) + self.assertFalse(inst.socket.bind_called) + + def test_create_with_one_socket_handle_accept_noerror(self): + innersock = DummySock() + sockets = [DummySock(acceptresult=(innersock, None))] + sockets[0].bind(("127.0.0.1", 80)) + inst = self._makeWithSockets(sockets=sockets) + L = [] + inst.channel_class = lambda *arg, **kw: L.append(arg) + inst.adj = DummyAdj + inst.handle_accept() + self.assertEqual(sockets[0].accepted, True) + self.assertEqual(innersock.opts, [("level", "optname", "value")]) + self.assertEqual(L, [(inst, innersock, None, inst.adj)]) + + +if hasattr(socket, "AF_UNIX"): + + class TestUnixWSGIServer(unittest.TestCase): + unix_socket = "/tmp/waitress.test.sock" + + def _makeOne(self, _start=True, _sock=None): + from waitress.server import create_server + + self.inst = create_server( + dummy_app, + map={}, + _start=_start, + _sock=_sock, + _dispatcher=DummyTaskDispatcher(), + unix_socket=self.unix_socket, + unix_socket_perms="600", + ) + return self.inst + + def _makeWithSockets( + self, + application=dummy_app, + _dispatcher=None, + map=None, + _start=True, + _sock=None, + _server=None, + sockets=None, + ): + from waitress.server import create_server + + _sockets = [] + if sockets is not None: + _sockets = sockets + self.inst = create_server( + application, + map=map, + _dispatcher=_dispatcher, + _start=_start, + _sock=_sock, + sockets=_sockets, + ) + return self.inst + + def tearDown(self): + self.inst.close() + + def _makeDummy(self, *args, **kwargs): + sock = DummySock(*args, **kwargs) + sock.family = socket.AF_UNIX + return sock + + def test_unix(self): + inst = self._makeOne(_start=False) + self.assertEqual(inst.socket.family, socket.AF_UNIX) + self.assertEqual(inst.socket.getsockname(), self.unix_socket) + + def test_handle_accept(self): + # Working on the assumption that we only have to test the happy path + # for Unix domain sockets as the other paths should've been covered + # by inet sockets. + client = self._makeDummy() + listen = self._makeDummy(acceptresult=(client, None)) + inst = self._makeOne(_sock=listen) + self.assertEqual(inst.accepting, True) + self.assertEqual(inst.socket.listened, 1024) + L = [] + inst.channel_class = lambda *arg, **kw: L.append(arg) + inst.handle_accept() + self.assertEqual(inst.socket.accepted, True) + self.assertEqual(client.opts, []) + self.assertEqual(L, [(inst, client, ("localhost", None), inst.adj)]) + + def test_creates_new_sockinfo(self): + from waitress.server import UnixWSGIServer + + self.inst = UnixWSGIServer( + dummy_app, unix_socket=self.unix_socket, unix_socket_perms="600" + ) + + self.assertEqual(self.inst.sockinfo[0], socket.AF_UNIX) + + def test_create_with_unix_socket(self): + from waitress.server import ( + MultiSocketServer, + BaseWSGIServer, + TcpWSGIServer, + UnixWSGIServer, + ) + + sockets = [ + socket.socket(socket.AF_UNIX, socket.SOCK_STREAM), + socket.socket(socket.AF_UNIX, socket.SOCK_STREAM), + ] + inst = self._makeWithSockets(sockets=sockets, _start=False) + self.assertTrue(isinstance(inst, MultiSocketServer)) + server = list( + filter(lambda s: isinstance(s, BaseWSGIServer), inst.map.values()) + ) + self.assertTrue(isinstance(server[0], UnixWSGIServer)) + self.assertTrue(isinstance(server[1], UnixWSGIServer)) + + +class DummySock(socket.socket): + accepted = False + blocking = False + family = socket.AF_INET + type = socket.SOCK_STREAM + proto = 0 + + def __init__(self, toraise=None, acceptresult=(None, None)): + self.toraise = toraise + self.acceptresult = acceptresult + self.bound = None + self.opts = [] + self.bind_called = False + + def bind(self, addr): + self.bind_called = True + self.bound = addr + + def accept(self): + if self.toraise: + raise self.toraise + self.accepted = True + return self.acceptresult + + def setblocking(self, x): + self.blocking = True + + def fileno(self): + return 10 + + def getpeername(self): + return "127.0.0.1" + + def setsockopt(self, *arg): + self.opts.append(arg) + + def getsockopt(self, *arg): + return 1 + + def listen(self, num): + self.listened = num + + def getsockname(self): + return self.bound + + def close(self): + pass + + +class DummyTaskDispatcher(object): + def __init__(self): + self.tasks = [] + + def add_task(self, task): + self.tasks.append(task) + + def shutdown(self): + self.was_shutdown = True + + +class DummyTask(object): + serviced = False + start_response_called = False + wrote_header = False + status = "200 OK" + + def __init__(self): + self.response_headers = {} + self.written = "" + + def service(self): # pragma: no cover + self.serviced = True + + +class DummyAdj: + connection_limit = 1 + log_socket_errors = True + socket_options = [("level", "optname", "value")] + cleanup_interval = 900 + channel_timeout = 300 + + +class DummyAsyncore(object): + def loop(self, timeout=30.0, use_poll=False, map=None, count=None): + raise SystemExit + + +class DummyTrigger(object): + def pull_trigger(self): + self.pulled = True + + def close(self): + pass + + +class DummyLogger(object): + def __init__(self): + self.logged = [] + + def warning(self, msg, **kw): + self.logged.append(msg) diff --git a/tests/test_task.py b/tests/test_task.py new file mode 100644 index 0000000..1a86245 --- /dev/null +++ b/tests/test_task.py @@ -0,0 +1,1001 @@ +import unittest +import io + + +class TestThreadedTaskDispatcher(unittest.TestCase): + def _makeOne(self): + from waitress.task import ThreadedTaskDispatcher + + return ThreadedTaskDispatcher() + + def test_handler_thread_task_raises(self): + inst = self._makeOne() + inst.threads.add(0) + inst.logger = DummyLogger() + + class BadDummyTask(DummyTask): + def service(self): + super(BadDummyTask, self).service() + inst.stop_count += 1 + raise Exception + + task = BadDummyTask() + inst.logger = DummyLogger() + inst.queue.append(task) + inst.active_count += 1 + inst.handler_thread(0) + self.assertEqual(inst.stop_count, 0) + self.assertEqual(inst.active_count, 0) + self.assertEqual(inst.threads, set()) + self.assertEqual(len(inst.logger.logged), 1) + + def test_set_thread_count_increase(self): + inst = self._makeOne() + L = [] + inst.start_new_thread = lambda *x: L.append(x) + inst.set_thread_count(1) + self.assertEqual(L, [(inst.handler_thread, (0,))]) + + def test_set_thread_count_increase_with_existing(self): + inst = self._makeOne() + L = [] + inst.threads = {0} + inst.start_new_thread = lambda *x: L.append(x) + inst.set_thread_count(2) + self.assertEqual(L, [(inst.handler_thread, (1,))]) + + def test_set_thread_count_decrease(self): + inst = self._makeOne() + inst.threads = {0, 1} + inst.set_thread_count(1) + self.assertEqual(inst.stop_count, 1) + + def test_set_thread_count_same(self): + inst = self._makeOne() + L = [] + inst.start_new_thread = lambda *x: L.append(x) + inst.threads = {0} + inst.set_thread_count(1) + self.assertEqual(L, []) + + def test_add_task_with_idle_threads(self): + task = DummyTask() + inst = self._makeOne() + inst.threads.add(0) + inst.queue_logger = DummyLogger() + inst.add_task(task) + self.assertEqual(len(inst.queue), 1) + self.assertEqual(len(inst.queue_logger.logged), 0) + + def test_add_task_with_all_busy_threads(self): + task = DummyTask() + inst = self._makeOne() + inst.queue_logger = DummyLogger() + inst.add_task(task) + self.assertEqual(len(inst.queue_logger.logged), 1) + inst.add_task(task) + self.assertEqual(len(inst.queue_logger.logged), 2) + + def test_shutdown_one_thread(self): + inst = self._makeOne() + inst.threads.add(0) + inst.logger = DummyLogger() + task = DummyTask() + inst.queue.append(task) + self.assertEqual(inst.shutdown(timeout=0.01), True) + self.assertEqual( + inst.logger.logged, + ["1 thread(s) still running", "Canceling 1 pending task(s)",], + ) + self.assertEqual(task.cancelled, True) + + def test_shutdown_no_threads(self): + inst = self._makeOne() + self.assertEqual(inst.shutdown(timeout=0.01), True) + + def test_shutdown_no_cancel_pending(self): + inst = self._makeOne() + self.assertEqual(inst.shutdown(cancel_pending=False, timeout=0.01), False) + + +class TestTask(unittest.TestCase): + def _makeOne(self, channel=None, request=None): + if channel is None: + channel = DummyChannel() + if request is None: + request = DummyParser() + from waitress.task import Task + + return Task(channel, request) + + def test_ctor_version_not_in_known(self): + request = DummyParser() + request.version = "8.4" + inst = self._makeOne(request=request) + self.assertEqual(inst.version, "1.0") + + def test_build_response_header_bad_http_version(self): + inst = self._makeOne() + inst.request = DummyParser() + inst.version = "8.4" + self.assertRaises(AssertionError, inst.build_response_header) + + def test_build_response_header_v10_keepalive_no_content_length(self): + inst = self._makeOne() + inst.request = DummyParser() + inst.request.headers["CONNECTION"] = "keep-alive" + inst.version = "1.0" + result = inst.build_response_header() + lines = filter_lines(result) + self.assertEqual(len(lines), 4) + self.assertEqual(lines[0], b"HTTP/1.0 200 OK") + self.assertEqual(lines[1], b"Connection: close") + self.assertTrue(lines[2].startswith(b"Date:")) + self.assertEqual(lines[3], b"Server: waitress") + self.assertEqual(inst.close_on_finish, True) + self.assertTrue(("Connection", "close") in inst.response_headers) + + def test_build_response_header_v10_keepalive_with_content_length(self): + inst = self._makeOne() + inst.request = DummyParser() + inst.request.headers["CONNECTION"] = "keep-alive" + inst.response_headers = [("Content-Length", "10")] + inst.version = "1.0" + inst.content_length = 0 + result = inst.build_response_header() + lines = filter_lines(result) + self.assertEqual(len(lines), 5) + self.assertEqual(lines[0], b"HTTP/1.0 200 OK") + self.assertEqual(lines[1], b"Connection: Keep-Alive") + self.assertEqual(lines[2], b"Content-Length: 10") + self.assertTrue(lines[3].startswith(b"Date:")) + self.assertEqual(lines[4], b"Server: waitress") + self.assertEqual(inst.close_on_finish, False) + + def test_build_response_header_v11_connection_closed_by_client(self): + inst = self._makeOne() + inst.request = DummyParser() + inst.version = "1.1" + inst.request.headers["CONNECTION"] = "close" + result = inst.build_response_header() + lines = filter_lines(result) + self.assertEqual(len(lines), 5) + self.assertEqual(lines[0], b"HTTP/1.1 200 OK") + self.assertEqual(lines[1], b"Connection: close") + self.assertTrue(lines[2].startswith(b"Date:")) + self.assertEqual(lines[3], b"Server: waitress") + self.assertEqual(lines[4], b"Transfer-Encoding: chunked") + self.assertTrue(("Connection", "close") in inst.response_headers) + self.assertEqual(inst.close_on_finish, True) + + def test_build_response_header_v11_connection_keepalive_by_client(self): + inst = self._makeOne() + inst.request = DummyParser() + inst.request.headers["CONNECTION"] = "keep-alive" + inst.version = "1.1" + result = inst.build_response_header() + lines = filter_lines(result) + self.assertEqual(len(lines), 5) + self.assertEqual(lines[0], b"HTTP/1.1 200 OK") + self.assertEqual(lines[1], b"Connection: close") + self.assertTrue(lines[2].startswith(b"Date:")) + self.assertEqual(lines[3], b"Server: waitress") + self.assertEqual(lines[4], b"Transfer-Encoding: chunked") + self.assertTrue(("Connection", "close") in inst.response_headers) + self.assertEqual(inst.close_on_finish, True) + + def test_build_response_header_v11_200_no_content_length(self): + inst = self._makeOne() + inst.request = DummyParser() + inst.version = "1.1" + result = inst.build_response_header() + lines = filter_lines(result) + self.assertEqual(len(lines), 5) + self.assertEqual(lines[0], b"HTTP/1.1 200 OK") + self.assertEqual(lines[1], b"Connection: close") + self.assertTrue(lines[2].startswith(b"Date:")) + self.assertEqual(lines[3], b"Server: waitress") + self.assertEqual(lines[4], b"Transfer-Encoding: chunked") + self.assertEqual(inst.close_on_finish, True) + self.assertTrue(("Connection", "close") in inst.response_headers) + + def test_build_response_header_v11_204_no_content_length_or_transfer_encoding(self): + # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length + # for any response with a status code of 1xx or 204. + inst = self._makeOne() + inst.request = DummyParser() + inst.version = "1.1" + inst.status = "204 No Content" + result = inst.build_response_header() + lines = filter_lines(result) + self.assertEqual(len(lines), 4) + self.assertEqual(lines[0], b"HTTP/1.1 204 No Content") + self.assertEqual(lines[1], b"Connection: close") + self.assertTrue(lines[2].startswith(b"Date:")) + self.assertEqual(lines[3], b"Server: waitress") + self.assertEqual(inst.close_on_finish, True) + self.assertTrue(("Connection", "close") in inst.response_headers) + + def test_build_response_header_v11_1xx_no_content_length_or_transfer_encoding(self): + # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length + # for any response with a status code of 1xx or 204. + inst = self._makeOne() + inst.request = DummyParser() + inst.version = "1.1" + inst.status = "100 Continue" + result = inst.build_response_header() + lines = filter_lines(result) + self.assertEqual(len(lines), 4) + self.assertEqual(lines[0], b"HTTP/1.1 100 Continue") + self.assertEqual(lines[1], b"Connection: close") + self.assertTrue(lines[2].startswith(b"Date:")) + self.assertEqual(lines[3], b"Server: waitress") + self.assertEqual(inst.close_on_finish, True) + self.assertTrue(("Connection", "close") in inst.response_headers) + + def test_build_response_header_v11_304_no_content_length_or_transfer_encoding(self): + # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length + # for any response with a status code of 1xx, 204 or 304. + inst = self._makeOne() + inst.request = DummyParser() + inst.version = "1.1" + inst.status = "304 Not Modified" + result = inst.build_response_header() + lines = filter_lines(result) + self.assertEqual(len(lines), 4) + self.assertEqual(lines[0], b"HTTP/1.1 304 Not Modified") + self.assertEqual(lines[1], b"Connection: close") + self.assertTrue(lines[2].startswith(b"Date:")) + self.assertEqual(lines[3], b"Server: waitress") + self.assertEqual(inst.close_on_finish, True) + self.assertTrue(("Connection", "close") in inst.response_headers) + + def test_build_response_header_via_added(self): + inst = self._makeOne() + inst.request = DummyParser() + inst.version = "1.0" + inst.response_headers = [("Server", "abc")] + result = inst.build_response_header() + lines = filter_lines(result) + self.assertEqual(len(lines), 5) + self.assertEqual(lines[0], b"HTTP/1.0 200 OK") + self.assertEqual(lines[1], b"Connection: close") + self.assertTrue(lines[2].startswith(b"Date:")) + self.assertEqual(lines[3], b"Server: abc") + self.assertEqual(lines[4], b"Via: waitress") + + def test_build_response_header_date_exists(self): + inst = self._makeOne() + inst.request = DummyParser() + inst.version = "1.0" + inst.response_headers = [("Date", "date")] + result = inst.build_response_header() + lines = filter_lines(result) + self.assertEqual(len(lines), 4) + self.assertEqual(lines[0], b"HTTP/1.0 200 OK") + self.assertEqual(lines[1], b"Connection: close") + self.assertTrue(lines[2].startswith(b"Date:")) + self.assertEqual(lines[3], b"Server: waitress") + + def test_build_response_header_preexisting_content_length(self): + inst = self._makeOne() + inst.request = DummyParser() + inst.version = "1.1" + inst.content_length = 100 + result = inst.build_response_header() + lines = filter_lines(result) + self.assertEqual(len(lines), 4) + self.assertEqual(lines[0], b"HTTP/1.1 200 OK") + self.assertEqual(lines[1], b"Content-Length: 100") + self.assertTrue(lines[2].startswith(b"Date:")) + self.assertEqual(lines[3], b"Server: waitress") + + def test_remove_content_length_header(self): + inst = self._makeOne() + inst.response_headers = [("Content-Length", "70")] + inst.remove_content_length_header() + self.assertEqual(inst.response_headers, []) + + def test_remove_content_length_header_with_other(self): + inst = self._makeOne() + inst.response_headers = [ + ("Content-Length", "70"), + ("Content-Type", "text/html"), + ] + inst.remove_content_length_header() + self.assertEqual(inst.response_headers, [("Content-Type", "text/html")]) + + def test_start(self): + inst = self._makeOne() + inst.start() + self.assertTrue(inst.start_time) + + def test_finish_didnt_write_header(self): + inst = self._makeOne() + inst.wrote_header = False + inst.complete = True + inst.finish() + self.assertTrue(inst.channel.written) + + def test_finish_wrote_header(self): + inst = self._makeOne() + inst.wrote_header = True + inst.finish() + self.assertFalse(inst.channel.written) + + def test_finish_chunked_response(self): + inst = self._makeOne() + inst.wrote_header = True + inst.chunked_response = True + inst.finish() + self.assertEqual(inst.channel.written, b"0\r\n\r\n") + + def test_write_wrote_header(self): + inst = self._makeOne() + inst.wrote_header = True + inst.complete = True + inst.content_length = 3 + inst.write(b"abc") + self.assertEqual(inst.channel.written, b"abc") + + def test_write_header_not_written(self): + inst = self._makeOne() + inst.wrote_header = False + inst.complete = True + inst.write(b"abc") + self.assertTrue(inst.channel.written) + self.assertEqual(inst.wrote_header, True) + + def test_write_start_response_uncalled(self): + inst = self._makeOne() + self.assertRaises(RuntimeError, inst.write, b"") + + def test_write_chunked_response(self): + inst = self._makeOne() + inst.wrote_header = True + inst.chunked_response = True + inst.complete = True + inst.write(b"abc") + self.assertEqual(inst.channel.written, b"3\r\nabc\r\n") + + def test_write_preexisting_content_length(self): + inst = self._makeOne() + inst.wrote_header = True + inst.complete = True + inst.content_length = 1 + inst.logger = DummyLogger() + inst.write(b"abc") + self.assertTrue(inst.channel.written) + self.assertEqual(inst.logged_write_excess, True) + self.assertEqual(len(inst.logger.logged), 1) + + +class TestWSGITask(unittest.TestCase): + def _makeOne(self, channel=None, request=None): + if channel is None: + channel = DummyChannel() + if request is None: + request = DummyParser() + from waitress.task import WSGITask + + return WSGITask(channel, request) + + def test_service(self): + inst = self._makeOne() + + def execute(): + inst.executed = True + + inst.execute = execute + inst.complete = True + inst.service() + self.assertTrue(inst.start_time) + self.assertTrue(inst.close_on_finish) + self.assertTrue(inst.channel.written) + self.assertEqual(inst.executed, True) + + def test_service_server_raises_socket_error(self): + import socket + + inst = self._makeOne() + + def execute(): + raise socket.error + + inst.execute = execute + self.assertRaises(socket.error, inst.service) + self.assertTrue(inst.start_time) + self.assertTrue(inst.close_on_finish) + self.assertFalse(inst.channel.written) + + def test_execute_app_calls_start_response_twice_wo_exc_info(self): + def app(environ, start_response): + start_response("200 OK", []) + start_response("200 OK", []) + + inst = self._makeOne() + inst.channel.server.application = app + self.assertRaises(AssertionError, inst.execute) + + def test_execute_app_calls_start_response_w_exc_info_complete(self): + def app(environ, start_response): + start_response("200 OK", [], [ValueError, ValueError(), None]) + return [b"a"] + + inst = self._makeOne() + inst.complete = True + inst.channel.server.application = app + inst.execute() + self.assertTrue(inst.complete) + self.assertEqual(inst.status, "200 OK") + self.assertTrue(inst.channel.written) + + def test_execute_app_calls_start_response_w_excinf_headers_unwritten(self): + def app(environ, start_response): + start_response("200 OK", [], [ValueError, None, None]) + return [b"a"] + + inst = self._makeOne() + inst.wrote_header = False + inst.channel.server.application = app + inst.response_headers = [("a", "b")] + inst.execute() + self.assertTrue(inst.complete) + self.assertEqual(inst.status, "200 OK") + self.assertTrue(inst.channel.written) + self.assertFalse(("a", "b") in inst.response_headers) + + def test_execute_app_calls_start_response_w_excinf_headers_written(self): + def app(environ, start_response): + start_response("200 OK", [], [ValueError, ValueError(), None]) + + inst = self._makeOne() + inst.complete = True + inst.wrote_header = True + inst.channel.server.application = app + self.assertRaises(ValueError, inst.execute) + + def test_execute_bad_header_key(self): + def app(environ, start_response): + start_response("200 OK", [(None, "a")]) + + inst = self._makeOne() + inst.channel.server.application = app + self.assertRaises(AssertionError, inst.execute) + + def test_execute_bad_header_value(self): + def app(environ, start_response): + start_response("200 OK", [("a", None)]) + + inst = self._makeOne() + inst.channel.server.application = app + self.assertRaises(AssertionError, inst.execute) + + def test_execute_hopbyhop_header(self): + def app(environ, start_response): + start_response("200 OK", [("Connection", "close")]) + + inst = self._makeOne() + inst.channel.server.application = app + self.assertRaises(AssertionError, inst.execute) + + def test_execute_bad_header_value_control_characters(self): + def app(environ, start_response): + start_response("200 OK", [("a", "\n")]) + + inst = self._makeOne() + inst.channel.server.application = app + self.assertRaises(ValueError, inst.execute) + + def test_execute_bad_header_name_control_characters(self): + def app(environ, start_response): + start_response("200 OK", [("a\r", "value")]) + + inst = self._makeOne() + inst.channel.server.application = app + self.assertRaises(ValueError, inst.execute) + + def test_execute_bad_status_control_characters(self): + def app(environ, start_response): + start_response("200 OK\r", []) + + inst = self._makeOne() + inst.channel.server.application = app + self.assertRaises(ValueError, inst.execute) + + def test_preserve_header_value_order(self): + def app(environ, start_response): + write = start_response("200 OK", [("C", "b"), ("A", "b"), ("A", "a")]) + write(b"abc") + return [] + + inst = self._makeOne() + inst.channel.server.application = app + inst.execute() + self.assertTrue(b"A: b\r\nA: a\r\nC: b\r\n" in inst.channel.written) + + def test_execute_bad_status_value(self): + def app(environ, start_response): + start_response(None, []) + + inst = self._makeOne() + inst.channel.server.application = app + self.assertRaises(AssertionError, inst.execute) + + def test_execute_with_content_length_header(self): + def app(environ, start_response): + start_response("200 OK", [("Content-Length", "1")]) + return [b"a"] + + inst = self._makeOne() + inst.channel.server.application = app + inst.execute() + self.assertEqual(inst.content_length, 1) + + def test_execute_app_calls_write(self): + def app(environ, start_response): + write = start_response("200 OK", [("Content-Length", "3")]) + write(b"abc") + return [] + + inst = self._makeOne() + inst.channel.server.application = app + inst.execute() + self.assertEqual(inst.channel.written[-3:], b"abc") + + def test_execute_app_returns_len1_chunk_without_cl(self): + def app(environ, start_response): + start_response("200 OK", []) + return [b"abc"] + + inst = self._makeOne() + inst.channel.server.application = app + inst.execute() + self.assertEqual(inst.content_length, 3) + + def test_execute_app_returns_empty_chunk_as_first(self): + def app(environ, start_response): + start_response("200 OK", []) + return ["", b"abc"] + + inst = self._makeOne() + inst.channel.server.application = app + inst.execute() + self.assertEqual(inst.content_length, None) + + def test_execute_app_returns_too_many_bytes(self): + def app(environ, start_response): + start_response("200 OK", [("Content-Length", "1")]) + return [b"abc"] + + inst = self._makeOne() + inst.channel.server.application = app + inst.logger = DummyLogger() + inst.execute() + self.assertEqual(inst.close_on_finish, True) + self.assertEqual(len(inst.logger.logged), 1) + + def test_execute_app_returns_too_few_bytes(self): + def app(environ, start_response): + start_response("200 OK", [("Content-Length", "3")]) + return [b"a"] + + inst = self._makeOne() + inst.channel.server.application = app + inst.logger = DummyLogger() + inst.execute() + self.assertEqual(inst.close_on_finish, True) + self.assertEqual(len(inst.logger.logged), 1) + + def test_execute_app_do_not_warn_on_head(self): + def app(environ, start_response): + start_response("200 OK", [("Content-Length", "3")]) + return [b""] + + inst = self._makeOne() + inst.request.command = "HEAD" + inst.channel.server.application = app + inst.logger = DummyLogger() + inst.execute() + self.assertEqual(inst.close_on_finish, True) + self.assertEqual(len(inst.logger.logged), 0) + + def test_execute_app_without_body_204_logged(self): + def app(environ, start_response): + start_response("204 No Content", [("Content-Length", "3")]) + return [b"abc"] + + inst = self._makeOne() + inst.channel.server.application = app + inst.logger = DummyLogger() + inst.execute() + self.assertEqual(inst.close_on_finish, True) + self.assertNotIn(b"abc", inst.channel.written) + self.assertNotIn(b"Content-Length", inst.channel.written) + self.assertNotIn(b"Transfer-Encoding", inst.channel.written) + self.assertEqual(len(inst.logger.logged), 1) + + def test_execute_app_without_body_304_logged(self): + def app(environ, start_response): + start_response("304 Not Modified", [("Content-Length", "3")]) + return [b"abc"] + + inst = self._makeOne() + inst.channel.server.application = app + inst.logger = DummyLogger() + inst.execute() + self.assertEqual(inst.close_on_finish, True) + self.assertNotIn(b"abc", inst.channel.written) + self.assertNotIn(b"Content-Length", inst.channel.written) + self.assertNotIn(b"Transfer-Encoding", inst.channel.written) + self.assertEqual(len(inst.logger.logged), 1) + + def test_execute_app_returns_closeable(self): + class closeable(list): + def close(self): + self.closed = True + + foo = closeable([b"abc"]) + + def app(environ, start_response): + start_response("200 OK", [("Content-Length", "3")]) + return foo + + inst = self._makeOne() + inst.channel.server.application = app + inst.execute() + self.assertEqual(foo.closed, True) + + def test_execute_app_returns_filewrapper_prepare_returns_True(self): + from waitress.buffers import ReadOnlyFileBasedBuffer + + f = io.BytesIO(b"abc") + app_iter = ReadOnlyFileBasedBuffer(f, 8192) + + def app(environ, start_response): + start_response("200 OK", [("Content-Length", "3")]) + return app_iter + + inst = self._makeOne() + inst.channel.server.application = app + inst.execute() + self.assertTrue(inst.channel.written) # header + self.assertEqual(inst.channel.otherdata, [app_iter]) + + def test_execute_app_returns_filewrapper_prepare_returns_True_nocl(self): + from waitress.buffers import ReadOnlyFileBasedBuffer + + f = io.BytesIO(b"abc") + app_iter = ReadOnlyFileBasedBuffer(f, 8192) + + def app(environ, start_response): + start_response("200 OK", []) + return app_iter + + inst = self._makeOne() + inst.channel.server.application = app + inst.execute() + self.assertTrue(inst.channel.written) # header + self.assertEqual(inst.channel.otherdata, [app_iter]) + self.assertEqual(inst.content_length, 3) + + def test_execute_app_returns_filewrapper_prepare_returns_True_badcl(self): + from waitress.buffers import ReadOnlyFileBasedBuffer + + f = io.BytesIO(b"abc") + app_iter = ReadOnlyFileBasedBuffer(f, 8192) + + def app(environ, start_response): + start_response("200 OK", []) + return app_iter + + inst = self._makeOne() + inst.channel.server.application = app + inst.content_length = 10 + inst.response_headers = [("Content-Length", "10")] + inst.execute() + self.assertTrue(inst.channel.written) # header + self.assertEqual(inst.channel.otherdata, [app_iter]) + self.assertEqual(inst.content_length, 3) + self.assertEqual(dict(inst.response_headers)["Content-Length"], "3") + + def test_get_environment_already_cached(self): + inst = self._makeOne() + inst.environ = object() + self.assertEqual(inst.get_environment(), inst.environ) + + def test_get_environment_path_startswith_more_than_one_slash(self): + inst = self._makeOne() + request = DummyParser() + request.path = "///abc" + inst.request = request + environ = inst.get_environment() + self.assertEqual(environ["PATH_INFO"], "/abc") + + def test_get_environment_path_empty(self): + inst = self._makeOne() + request = DummyParser() + request.path = "" + inst.request = request + environ = inst.get_environment() + self.assertEqual(environ["PATH_INFO"], "") + + def test_get_environment_no_query(self): + inst = self._makeOne() + request = DummyParser() + inst.request = request + environ = inst.get_environment() + self.assertEqual(environ["QUERY_STRING"], "") + + def test_get_environment_with_query(self): + inst = self._makeOne() + request = DummyParser() + request.query = "abc" + inst.request = request + environ = inst.get_environment() + self.assertEqual(environ["QUERY_STRING"], "abc") + + def test_get_environ_with_url_prefix_miss(self): + inst = self._makeOne() + inst.channel.server.adj.url_prefix = "/foo" + request = DummyParser() + request.path = "/bar" + inst.request = request + environ = inst.get_environment() + self.assertEqual(environ["PATH_INFO"], "/bar") + self.assertEqual(environ["SCRIPT_NAME"], "/foo") + + def test_get_environ_with_url_prefix_hit(self): + inst = self._makeOne() + inst.channel.server.adj.url_prefix = "/foo" + request = DummyParser() + request.path = "/foo/fuz" + inst.request = request + environ = inst.get_environment() + self.assertEqual(environ["PATH_INFO"], "/fuz") + self.assertEqual(environ["SCRIPT_NAME"], "/foo") + + def test_get_environ_with_url_prefix_empty_path(self): + inst = self._makeOne() + inst.channel.server.adj.url_prefix = "/foo" + request = DummyParser() + request.path = "/foo" + inst.request = request + environ = inst.get_environment() + self.assertEqual(environ["PATH_INFO"], "") + self.assertEqual(environ["SCRIPT_NAME"], "/foo") + + def test_get_environment_values(self): + import sys + + inst = self._makeOne() + request = DummyParser() + request.headers = { + "CONTENT_TYPE": "abc", + "CONTENT_LENGTH": "10", + "X_FOO": "BAR", + "CONNECTION": "close", + } + request.query = "abc" + inst.request = request + environ = inst.get_environment() + + # nail the keys of environ + self.assertEqual( + sorted(environ.keys()), + [ + "CONTENT_LENGTH", + "CONTENT_TYPE", + "HTTP_CONNECTION", + "HTTP_X_FOO", + "PATH_INFO", + "QUERY_STRING", + "REMOTE_ADDR", + "REMOTE_HOST", + "REMOTE_PORT", + "REQUEST_METHOD", + "SCRIPT_NAME", + "SERVER_NAME", + "SERVER_PORT", + "SERVER_PROTOCOL", + "SERVER_SOFTWARE", + "wsgi.errors", + "wsgi.file_wrapper", + "wsgi.input", + "wsgi.input_terminated", + "wsgi.multiprocess", + "wsgi.multithread", + "wsgi.run_once", + "wsgi.url_scheme", + "wsgi.version", + ], + ) + + self.assertEqual(environ["REQUEST_METHOD"], "GET") + self.assertEqual(environ["SERVER_PORT"], "80") + self.assertEqual(environ["SERVER_NAME"], "localhost") + self.assertEqual(environ["SERVER_SOFTWARE"], "waitress") + self.assertEqual(environ["SERVER_PROTOCOL"], "HTTP/1.0") + self.assertEqual(environ["SCRIPT_NAME"], "") + self.assertEqual(environ["HTTP_CONNECTION"], "close") + self.assertEqual(environ["PATH_INFO"], "/") + self.assertEqual(environ["QUERY_STRING"], "abc") + self.assertEqual(environ["REMOTE_ADDR"], "127.0.0.1") + self.assertEqual(environ["REMOTE_HOST"], "127.0.0.1") + self.assertEqual(environ["REMOTE_PORT"], "39830") + self.assertEqual(environ["CONTENT_TYPE"], "abc") + self.assertEqual(environ["CONTENT_LENGTH"], "10") + self.assertEqual(environ["HTTP_X_FOO"], "BAR") + self.assertEqual(environ["wsgi.version"], (1, 0)) + self.assertEqual(environ["wsgi.url_scheme"], "http") + self.assertEqual(environ["wsgi.errors"], sys.stderr) + self.assertEqual(environ["wsgi.multithread"], True) + self.assertEqual(environ["wsgi.multiprocess"], False) + self.assertEqual(environ["wsgi.run_once"], False) + self.assertEqual(environ["wsgi.input"], "stream") + self.assertEqual(environ["wsgi.input_terminated"], True) + self.assertEqual(inst.environ, environ) + + +class TestErrorTask(unittest.TestCase): + def _makeOne(self, channel=None, request=None): + if channel is None: + channel = DummyChannel() + if request is None: + request = DummyParser() + request.error = self._makeDummyError() + from waitress.task import ErrorTask + + return ErrorTask(channel, request) + + def _makeDummyError(self): + from waitress.utilities import Error + + e = Error("body") + e.code = 432 + e.reason = "Too Ugly" + return e + + def test_execute_http_10(self): + inst = self._makeOne() + inst.execute() + lines = filter_lines(inst.channel.written) + self.assertEqual(len(lines), 9) + self.assertEqual(lines[0], b"HTTP/1.0 432 Too Ugly") + self.assertEqual(lines[1], b"Connection: close") + self.assertEqual(lines[2], b"Content-Length: 43") + self.assertEqual(lines[3], b"Content-Type: text/plain") + self.assertTrue(lines[4]) + self.assertEqual(lines[5], b"Server: waitress") + self.assertEqual(lines[6], b"Too Ugly") + self.assertEqual(lines[7], b"body") + self.assertEqual(lines[8], b"(generated by waitress)") + + def test_execute_http_11(self): + inst = self._makeOne() + inst.version = "1.1" + inst.execute() + lines = filter_lines(inst.channel.written) + self.assertEqual(len(lines), 9) + self.assertEqual(lines[0], b"HTTP/1.1 432 Too Ugly") + self.assertEqual(lines[1], b"Connection: close") + self.assertEqual(lines[2], b"Content-Length: 43") + self.assertEqual(lines[3], b"Content-Type: text/plain") + self.assertTrue(lines[4]) + self.assertEqual(lines[5], b"Server: waitress") + self.assertEqual(lines[6], b"Too Ugly") + self.assertEqual(lines[7], b"body") + self.assertEqual(lines[8], b"(generated by waitress)") + + def test_execute_http_11_close(self): + inst = self._makeOne() + inst.version = "1.1" + inst.request.headers["CONNECTION"] = "close" + inst.execute() + lines = filter_lines(inst.channel.written) + self.assertEqual(len(lines), 9) + self.assertEqual(lines[0], b"HTTP/1.1 432 Too Ugly") + self.assertEqual(lines[1], b"Connection: close") + self.assertEqual(lines[2], b"Content-Length: 43") + self.assertEqual(lines[3], b"Content-Type: text/plain") + self.assertTrue(lines[4]) + self.assertEqual(lines[5], b"Server: waitress") + self.assertEqual(lines[6], b"Too Ugly") + self.assertEqual(lines[7], b"body") + self.assertEqual(lines[8], b"(generated by waitress)") + + def test_execute_http_11_keep_forces_close(self): + inst = self._makeOne() + inst.version = "1.1" + inst.request.headers["CONNECTION"] = "keep-alive" + inst.execute() + lines = filter_lines(inst.channel.written) + self.assertEqual(len(lines), 9) + self.assertEqual(lines[0], b"HTTP/1.1 432 Too Ugly") + self.assertEqual(lines[1], b"Connection: close") + self.assertEqual(lines[2], b"Content-Length: 43") + self.assertEqual(lines[3], b"Content-Type: text/plain") + self.assertTrue(lines[4]) + self.assertEqual(lines[5], b"Server: waitress") + self.assertEqual(lines[6], b"Too Ugly") + self.assertEqual(lines[7], b"body") + self.assertEqual(lines[8], b"(generated by waitress)") + + +class DummyTask(object): + serviced = False + cancelled = False + + def service(self): + self.serviced = True + + def cancel(self): + self.cancelled = True + + +class DummyAdj(object): + log_socket_errors = True + ident = "waitress" + host = "127.0.0.1" + port = 80 + url_prefix = "" + + +class DummyServer(object): + server_name = "localhost" + effective_port = 80 + + def __init__(self): + self.adj = DummyAdj() + + +class DummyChannel(object): + closed_when_done = False + adj = DummyAdj() + creation_time = 0 + addr = ("127.0.0.1", 39830) + + def __init__(self, server=None): + if server is None: + server = DummyServer() + self.server = server + self.written = b"" + self.otherdata = [] + + def write_soon(self, data): + if isinstance(data, bytes): + self.written += data + else: + self.otherdata.append(data) + return len(data) + + +class DummyParser(object): + version = "1.0" + command = "GET" + path = "/" + query = "" + url_scheme = "http" + expect_continue = False + headers_finished = False + + def __init__(self): + self.headers = {} + + def get_body_stream(self): + return "stream" + + +def filter_lines(s): + return list(filter(None, s.split(b"\r\n"))) + + +class DummyLogger(object): + def __init__(self): + self.logged = [] + + def warning(self, msg, *args): + self.logged.append(msg % args) + + def exception(self, msg, *args): + self.logged.append(msg % args) diff --git a/tests/test_trigger.py b/tests/test_trigger.py new file mode 100644 index 0000000..af740f6 --- /dev/null +++ b/tests/test_trigger.py @@ -0,0 +1,111 @@ +import unittest +import os +import sys + +if not sys.platform.startswith("win"): + + class Test_trigger(unittest.TestCase): + def _makeOne(self, map): + from waitress.trigger import trigger + + self.inst = trigger(map) + return self.inst + + def tearDown(self): + self.inst.close() # prevent __del__ warning from file_dispatcher + + def test__close(self): + map = {} + inst = self._makeOne(map) + fd1, fd2 = inst._fds + inst.close() + self.assertRaises(OSError, os.read, fd1, 1) + self.assertRaises(OSError, os.read, fd2, 1) + + def test__physical_pull(self): + map = {} + inst = self._makeOne(map) + inst._physical_pull() + r = os.read(inst._fds[0], 1) + self.assertEqual(r, b"x") + + def test_readable(self): + map = {} + inst = self._makeOne(map) + self.assertEqual(inst.readable(), True) + + def test_writable(self): + map = {} + inst = self._makeOne(map) + self.assertEqual(inst.writable(), False) + + def test_handle_connect(self): + map = {} + inst = self._makeOne(map) + self.assertEqual(inst.handle_connect(), None) + + def test_close(self): + map = {} + inst = self._makeOne(map) + self.assertEqual(inst.close(), None) + self.assertEqual(inst._closed, True) + + def test_handle_close(self): + map = {} + inst = self._makeOne(map) + self.assertEqual(inst.handle_close(), None) + self.assertEqual(inst._closed, True) + + def test_pull_trigger_nothunk(self): + map = {} + inst = self._makeOne(map) + self.assertEqual(inst.pull_trigger(), None) + r = os.read(inst._fds[0], 1) + self.assertEqual(r, b"x") + + def test_pull_trigger_thunk(self): + map = {} + inst = self._makeOne(map) + self.assertEqual(inst.pull_trigger(True), None) + self.assertEqual(len(inst.thunks), 1) + r = os.read(inst._fds[0], 1) + self.assertEqual(r, b"x") + + def test_handle_read_socket_error(self): + map = {} + inst = self._makeOne(map) + result = inst.handle_read() + self.assertEqual(result, None) + + def test_handle_read_no_socket_error(self): + map = {} + inst = self._makeOne(map) + inst.pull_trigger() + result = inst.handle_read() + self.assertEqual(result, None) + + def test_handle_read_thunk(self): + map = {} + inst = self._makeOne(map) + inst.pull_trigger() + L = [] + inst.thunks = [lambda: L.append(True)] + result = inst.handle_read() + self.assertEqual(result, None) + self.assertEqual(L, [True]) + self.assertEqual(inst.thunks, []) + + def test_handle_read_thunk_error(self): + map = {} + inst = self._makeOne(map) + + def errorthunk(): + raise ValueError + + inst.pull_trigger(errorthunk) + L = [] + inst.log_info = lambda *arg: L.append(arg) + result = inst.handle_read() + self.assertEqual(result, None) + self.assertEqual(len(L), 1) + self.assertEqual(inst.thunks, []) diff --git a/tests/test_utilities.py b/tests/test_utilities.py new file mode 100644 index 0000000..15cd24f --- /dev/null +++ b/tests/test_utilities.py @@ -0,0 +1,140 @@ +############################################################################## +# +# Copyright (c) 2002 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## + +import unittest + + +class Test_parse_http_date(unittest.TestCase): + def _callFUT(self, v): + from waitress.utilities import parse_http_date + + return parse_http_date(v) + + def test_rfc850(self): + val = "Tuesday, 08-Feb-94 14:15:29 GMT" + result = self._callFUT(val) + self.assertEqual(result, 760716929) + + def test_rfc822(self): + val = "Sun, 08 Feb 1994 14:15:29 GMT" + result = self._callFUT(val) + self.assertEqual(result, 760716929) + + def test_neither(self): + val = "" + result = self._callFUT(val) + self.assertEqual(result, 0) + + +class Test_build_http_date(unittest.TestCase): + def test_rountdrip(self): + from waitress.utilities import build_http_date, parse_http_date + from time import time + + t = int(time()) + self.assertEqual(t, parse_http_date(build_http_date(t))) + + +class Test_unpack_rfc850(unittest.TestCase): + def _callFUT(self, val): + from waitress.utilities import unpack_rfc850, rfc850_reg + + return unpack_rfc850(rfc850_reg.match(val.lower())) + + def test_it(self): + val = "Tuesday, 08-Feb-94 14:15:29 GMT" + result = self._callFUT(val) + self.assertEqual(result, (1994, 2, 8, 14, 15, 29, 0, 0, 0)) + + +class Test_unpack_rfc_822(unittest.TestCase): + def _callFUT(self, val): + from waitress.utilities import unpack_rfc822, rfc822_reg + + return unpack_rfc822(rfc822_reg.match(val.lower())) + + def test_it(self): + val = "Sun, 08 Feb 1994 14:15:29 GMT" + result = self._callFUT(val) + self.assertEqual(result, (1994, 2, 8, 14, 15, 29, 0, 0, 0)) + + +class Test_find_double_newline(unittest.TestCase): + def _callFUT(self, val): + from waitress.utilities import find_double_newline + + return find_double_newline(val) + + def test_empty(self): + self.assertEqual(self._callFUT(b""), -1) + + def test_one_linefeed(self): + self.assertEqual(self._callFUT(b"\n"), -1) + + def test_double_linefeed(self): + self.assertEqual(self._callFUT(b"\n\n"), -1) + + def test_one_crlf(self): + self.assertEqual(self._callFUT(b"\r\n"), -1) + + def test_double_crfl(self): + self.assertEqual(self._callFUT(b"\r\n\r\n"), 4) + + def test_mixed(self): + self.assertEqual(self._callFUT(b"\n\n00\r\n\r\n"), 8) + + +class TestBadRequest(unittest.TestCase): + def _makeOne(self): + from waitress.utilities import BadRequest + + return BadRequest(1) + + def test_it(self): + inst = self._makeOne() + self.assertEqual(inst.body, 1) + + +class Test_undquote(unittest.TestCase): + def _callFUT(self, value): + from waitress.utilities import undquote + + return undquote(value) + + def test_empty(self): + self.assertEqual(self._callFUT(""), "") + + def test_quoted(self): + self.assertEqual(self._callFUT('"test"'), "test") + + def test_unquoted(self): + self.assertEqual(self._callFUT("test"), "test") + + def test_quoted_backslash_quote(self): + self.assertEqual(self._callFUT('"\\""'), '"') + + def test_quoted_htab(self): + self.assertEqual(self._callFUT('"\t"'), "\t") + + def test_quoted_backslash_htab(self): + self.assertEqual(self._callFUT('"\\\t"'), "\t") + + def test_quoted_backslash_invalid(self): + self.assertRaises(ValueError, self._callFUT, '"\\"') + + def test_invalid_quoting(self): + self.assertRaises(ValueError, self._callFUT, '"test') + + def test_invalid_quoting_single_quote(self): + self.assertRaises(ValueError, self._callFUT, '"') diff --git a/tests/test_wasyncore.py b/tests/test_wasyncore.py new file mode 100644 index 0000000..9c23509 --- /dev/null +++ b/tests/test_wasyncore.py @@ -0,0 +1,1761 @@ +from waitress import wasyncore as asyncore +from waitress import compat +import contextlib +import functools +import gc +import unittest +import select +import os +import socket +import sys +import time +import errno +import re +import struct +import threading +import warnings + +from io import BytesIO + +TIMEOUT = 3 +HAS_UNIX_SOCKETS = hasattr(socket, "AF_UNIX") +HOST = "localhost" +HOSTv4 = "127.0.0.1" +HOSTv6 = "::1" + +# Filename used for testing +if os.name == "java": # pragma: no cover + # Jython disallows @ in module names + TESTFN = "$test" +else: + TESTFN = "@test" + +TESTFN = "{}_{}_tmp".format(TESTFN, os.getpid()) + + +class DummyLogger(object): # pragma: no cover + def __init__(self): + self.messages = [] + + def log(self, severity, message): + self.messages.append((severity, message)) + + +class WarningsRecorder(object): # pragma: no cover + """Convenience wrapper for the warnings list returned on + entry to the warnings.catch_warnings() context manager. + """ + + def __init__(self, warnings_list): + self._warnings = warnings_list + self._last = 0 + + @property + def warnings(self): + return self._warnings[self._last :] + + def reset(self): + self._last = len(self._warnings) + + +def _filterwarnings(filters, quiet=False): # pragma: no cover + """Catch the warnings, then check if all the expected + warnings have been raised and re-raise unexpected warnings. + If 'quiet' is True, only re-raise the unexpected warnings. + """ + # Clear the warning registry of the calling module + # in order to re-raise the warnings. + frame = sys._getframe(2) + registry = frame.f_globals.get("__warningregistry__") + if registry: + registry.clear() + with warnings.catch_warnings(record=True) as w: + # Set filter "always" to record all warnings. Because + # test_warnings swap the module, we need to look up in + # the sys.modules dictionary. + sys.modules["warnings"].simplefilter("always") + yield WarningsRecorder(w) + # Filter the recorded warnings + reraise = list(w) + missing = [] + for msg, cat in filters: + seen = False + for w in reraise[:]: + warning = w.message + # Filter out the matching messages + if re.match(msg, str(warning), re.I) and issubclass(warning.__class__, cat): + seen = True + reraise.remove(w) + if not seen and not quiet: + # This filter caught nothing + missing.append((msg, cat.__name__)) + if reraise: + raise AssertionError("unhandled warning %s" % reraise[0]) + if missing: + raise AssertionError("filter (%r, %s) did not catch any warning" % missing[0]) + + +@contextlib.contextmanager +def check_warnings(*filters, **kwargs): # pragma: no cover + """Context manager to silence warnings. + + Accept 2-tuples as positional arguments: + ("message regexp", WarningCategory) + + Optional argument: + - if 'quiet' is True, it does not fail if a filter catches nothing + (default True without argument, + default False if some filters are defined) + + Without argument, it defaults to: + check_warnings(("", Warning), quiet=True) + """ + quiet = kwargs.get("quiet") + if not filters: + filters = (("", Warning),) + # Preserve backward compatibility + if quiet is None: + quiet = True + return _filterwarnings(filters, quiet) + + +def gc_collect(): # pragma: no cover + """Force as many objects as possible to be collected. + + In non-CPython implementations of Python, this is needed because timely + deallocation is not guaranteed by the garbage collector. (Even in CPython + this can be the case in case of reference cycles.) This means that __del__ + methods may be called later than expected and weakrefs may remain alive for + longer than expected. This function tries its best to force all garbage + objects to disappear. + """ + gc.collect() + if sys.platform.startswith("java"): + time.sleep(0.1) + gc.collect() + gc.collect() + + +def threading_setup(): # pragma: no cover + return (compat.thread._count(), None) + + +def threading_cleanup(*original_values): # pragma: no cover + global environment_altered + + _MAX_COUNT = 100 + + for count in range(_MAX_COUNT): + values = (compat.thread._count(), None) + if values == original_values: + break + + if not count: + # Display a warning at the first iteration + environment_altered = True + sys.stderr.write( + "Warning -- threading_cleanup() failed to cleanup " + "%s threads" % (values[0] - original_values[0]) + ) + sys.stderr.flush() + + values = None + + time.sleep(0.01) + gc_collect() + + +def reap_threads(func): # pragma: no cover + """Use this function when threads are being used. This will + ensure that the threads are cleaned up even when the test fails. + """ + + @functools.wraps(func) + def decorator(*args): + key = threading_setup() + try: + return func(*args) + finally: + threading_cleanup(*key) + + return decorator + + +def join_thread(thread, timeout=30.0): # pragma: no cover + """Join a thread. Raise an AssertionError if the thread is still alive + after timeout seconds. + """ + thread.join(timeout) + if thread.is_alive(): + msg = "failed to join the thread in %.1f seconds" % timeout + raise AssertionError(msg) + + +def bind_port(sock, host=HOST): # pragma: no cover + """Bind the socket to a free port and return the port number. Relies on + ephemeral ports in order to ensure we are using an unbound port. This is + important as many tests may be running simultaneously, especially in a + buildbot environment. This method raises an exception if the sock.family + is AF_INET and sock.type is SOCK_STREAM, *and* the socket has SO_REUSEADDR + or SO_REUSEPORT set on it. Tests should *never* set these socket options + for TCP/IP sockets. The only case for setting these options is testing + multicasting via multiple UDP sockets. + + Additionally, if the SO_EXCLUSIVEADDRUSE socket option is available (i.e. + on Windows), it will be set on the socket. This will prevent anyone else + from bind()'ing to our host/port for the duration of the test. + """ + + if sock.family == socket.AF_INET and sock.type == socket.SOCK_STREAM: + if hasattr(socket, "SO_REUSEADDR"): + if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) == 1: + raise RuntimeError( + "tests should never set the SO_REUSEADDR " + "socket option on TCP/IP sockets!" + ) + if hasattr(socket, "SO_REUSEPORT"): + try: + if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 1: + raise RuntimeError( + "tests should never set the SO_REUSEPORT " + "socket option on TCP/IP sockets!" + ) + except OSError: + # Python's socket module was compiled using modern headers + # thus defining SO_REUSEPORT but this process is running + # under an older kernel that does not support SO_REUSEPORT. + pass + if hasattr(socket, "SO_EXCLUSIVEADDRUSE"): + sock.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1) + + sock.bind((host, 0)) + port = sock.getsockname()[1] + return port + + +@contextlib.contextmanager +def closewrapper(sock): # pragma: no cover + try: + yield sock + finally: + sock.close() + + +class dummysocket: # pragma: no cover + def __init__(self): + self.closed = False + + def close(self): + self.closed = True + + def fileno(self): + return 42 + + def setblocking(self, yesno): + self.isblocking = yesno + + def getpeername(self): + return "peername" + + +class dummychannel: # pragma: no cover + def __init__(self): + self.socket = dummysocket() + + def close(self): + self.socket.close() + + +class exitingdummy: # pragma: no cover + def __init__(self): + pass + + def handle_read_event(self): + raise asyncore.ExitNow() + + handle_write_event = handle_read_event + handle_close = handle_read_event + handle_expt_event = handle_read_event + + +class crashingdummy: + def __init__(self): + self.error_handled = False + + def handle_read_event(self): + raise Exception() + + handle_write_event = handle_read_event + handle_close = handle_read_event + handle_expt_event = handle_read_event + + def handle_error(self): + self.error_handled = True + + +# used when testing senders; just collects what it gets until newline is sent +def capture_server(evt, buf, serv): # pragma no cover + try: + serv.listen(0) + conn, addr = serv.accept() + except socket.timeout: + pass + else: + n = 200 + start = time.time() + while n > 0 and time.time() - start < 3.0: + r, w, e = select.select([conn], [], [], 0.1) + if r: + n -= 1 + data = conn.recv(10) + # keep everything except for the newline terminator + buf.write(data.replace(b"\n", b"")) + if b"\n" in data: + break + time.sleep(0.01) + + conn.close() + finally: + serv.close() + evt.set() + + +def bind_unix_socket(sock, addr): # pragma: no cover + """Bind a unix socket, raising SkipTest if PermissionError is raised.""" + assert sock.family == socket.AF_UNIX + try: + sock.bind(addr) + except PermissionError: + sock.close() + raise unittest.SkipTest("cannot bind AF_UNIX sockets") + + +def bind_af_aware(sock, addr): + """Helper function to bind a socket according to its family.""" + if HAS_UNIX_SOCKETS and sock.family == socket.AF_UNIX: + # Make sure the path doesn't exist. + unlink(addr) + bind_unix_socket(sock, addr) + else: + sock.bind(addr) + + +if sys.platform.startswith("win"): # pragma: no cover + + def _waitfor(func, pathname, waitall=False): + # Perform the operation + func(pathname) + # Now setup the wait loop + if waitall: + dirname = pathname + else: + dirname, name = os.path.split(pathname) + dirname = dirname or "." + # Check for `pathname` to be removed from the filesystem. + # The exponential backoff of the timeout amounts to a total + # of ~1 second after which the deletion is probably an error + # anyway. + # Testing on an i7@4.3GHz shows that usually only 1 iteration is + # required when contention occurs. + timeout = 0.001 + while timeout < 1.0: + # Note we are only testing for the existence of the file(s) in + # the contents of the directory regardless of any security or + # access rights. If we have made it this far, we have sufficient + # permissions to do that much using Python's equivalent of the + # Windows API FindFirstFile. + # Other Windows APIs can fail or give incorrect results when + # dealing with files that are pending deletion. + L = os.listdir(dirname) + if not (L if waitall else name in L): + return + # Increase the timeout and try again + time.sleep(timeout) + timeout *= 2 + warnings.warn( + "tests may fail, delete still pending for " + pathname, + RuntimeWarning, + stacklevel=4, + ) + + def _unlink(filename): + _waitfor(os.unlink, filename) + + +else: + _unlink = os.unlink + + +def unlink(filename): + try: + _unlink(filename) + except OSError: + pass + + +def _is_ipv6_enabled(): # pragma: no cover + """Check whether IPv6 is enabled on this host.""" + if compat.HAS_IPV6: + sock = None + try: + sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + sock.bind(("::1", 0)) + return True + except socket.error: + pass + finally: + if sock: + sock.close() + return False + + +IPV6_ENABLED = _is_ipv6_enabled() + + +class HelperFunctionTests(unittest.TestCase): + def test_readwriteexc(self): + # Check exception handling behavior of read, write and _exception + + # check that ExitNow exceptions in the object handler method + # bubbles all the way up through asyncore read/write/_exception calls + tr1 = exitingdummy() + self.assertRaises(asyncore.ExitNow, asyncore.read, tr1) + self.assertRaises(asyncore.ExitNow, asyncore.write, tr1) + self.assertRaises(asyncore.ExitNow, asyncore._exception, tr1) + + # check that an exception other than ExitNow in the object handler + # method causes the handle_error method to get called + tr2 = crashingdummy() + asyncore.read(tr2) + self.assertEqual(tr2.error_handled, True) + + tr2 = crashingdummy() + asyncore.write(tr2) + self.assertEqual(tr2.error_handled, True) + + tr2 = crashingdummy() + asyncore._exception(tr2) + self.assertEqual(tr2.error_handled, True) + + # asyncore.readwrite uses constants in the select module that + # are not present in Windows systems (see this thread: + # http://mail.python.org/pipermail/python-list/2001-October/109973.html) + # These constants should be present as long as poll is available + + @unittest.skipUnless(hasattr(select, "poll"), "select.poll required") + def test_readwrite(self): + # Check that correct methods are called by readwrite() + + attributes = ("read", "expt", "write", "closed", "error_handled") + + expected = ( + (select.POLLIN, "read"), + (select.POLLPRI, "expt"), + (select.POLLOUT, "write"), + (select.POLLERR, "closed"), + (select.POLLHUP, "closed"), + (select.POLLNVAL, "closed"), + ) + + class testobj: + def __init__(self): + self.read = False + self.write = False + self.closed = False + self.expt = False + self.error_handled = False + + def handle_read_event(self): + self.read = True + + def handle_write_event(self): + self.write = True + + def handle_close(self): + self.closed = True + + def handle_expt_event(self): + self.expt = True + + # def handle_error(self): + # self.error_handled = True + + for flag, expectedattr in expected: + tobj = testobj() + self.assertEqual(getattr(tobj, expectedattr), False) + asyncore.readwrite(tobj, flag) + + # Only the attribute modified by the routine we expect to be + # called should be True. + for attr in attributes: + self.assertEqual(getattr(tobj, attr), attr == expectedattr) + + # check that ExitNow exceptions in the object handler method + # bubbles all the way up through asyncore readwrite call + tr1 = exitingdummy() + self.assertRaises(asyncore.ExitNow, asyncore.readwrite, tr1, flag) + + # check that an exception other than ExitNow in the object handler + # method causes the handle_error method to get called + tr2 = crashingdummy() + self.assertEqual(tr2.error_handled, False) + asyncore.readwrite(tr2, flag) + self.assertEqual(tr2.error_handled, True) + + def test_closeall(self): + self.closeall_check(False) + + def test_closeall_default(self): + self.closeall_check(True) + + def closeall_check(self, usedefault): + # Check that close_all() closes everything in a given map + + l = [] + testmap = {} + for i in range(10): + c = dummychannel() + l.append(c) + self.assertEqual(c.socket.closed, False) + testmap[i] = c + + if usedefault: + socketmap = asyncore.socket_map + try: + asyncore.socket_map = testmap + asyncore.close_all() + finally: + testmap, asyncore.socket_map = asyncore.socket_map, socketmap + else: + asyncore.close_all(testmap) + + self.assertEqual(len(testmap), 0) + + for c in l: + self.assertEqual(c.socket.closed, True) + + def test_compact_traceback(self): + try: + raise Exception("I don't like spam!") + except: + real_t, real_v, real_tb = sys.exc_info() + r = asyncore.compact_traceback() + + (f, function, line), t, v, info = r + self.assertEqual(os.path.split(f)[-1], "test_wasyncore.py") + self.assertEqual(function, "test_compact_traceback") + self.assertEqual(t, real_t) + self.assertEqual(v, real_v) + self.assertEqual(info, "[%s|%s|%s]" % (f, function, line)) + + +class DispatcherTests(unittest.TestCase): + def setUp(self): + pass + + def tearDown(self): + asyncore.close_all() + + def test_basic(self): + d = asyncore.dispatcher() + self.assertEqual(d.readable(), True) + self.assertEqual(d.writable(), True) + + def test_repr(self): + d = asyncore.dispatcher() + self.assertEqual(repr(d), "" % id(d)) + + def test_log_info(self): + import logging + + inst = asyncore.dispatcher(map={}) + logger = DummyLogger() + inst.logger = logger + inst.log_info("message", "warning") + self.assertEqual(logger.messages, [(logging.WARN, "message")]) + + def test_log(self): + import logging + + inst = asyncore.dispatcher() + logger = DummyLogger() + inst.logger = logger + inst.log("message") + self.assertEqual(logger.messages, [(logging.DEBUG, "message")]) + + def test_unhandled(self): + import logging + + inst = asyncore.dispatcher() + logger = DummyLogger() + inst.logger = logger + + inst.handle_expt() + inst.handle_read() + inst.handle_write() + inst.handle_connect() + + expected = [ + (logging.WARN, "unhandled incoming priority event"), + (logging.WARN, "unhandled read event"), + (logging.WARN, "unhandled write event"), + (logging.WARN, "unhandled connect event"), + ] + self.assertEqual(logger.messages, expected) + + def test_strerror(self): + # refers to bug #8573 + err = asyncore._strerror(errno.EPERM) + if hasattr(os, "strerror"): + self.assertEqual(err, os.strerror(errno.EPERM)) + err = asyncore._strerror(-1) + self.assertTrue(err != "") + + +class dispatcherwithsend_noread(asyncore.dispatcher_with_send): # pragma: no cover + def readable(self): + return False + + def handle_connect(self): + pass + + +class DispatcherWithSendTests(unittest.TestCase): + def setUp(self): + pass + + def tearDown(self): + asyncore.close_all() + + @reap_threads + def test_send(self): + evt = threading.Event() + sock = socket.socket() + sock.settimeout(3) + port = bind_port(sock) + + cap = BytesIO() + args = (evt, cap, sock) + t = threading.Thread(target=capture_server, args=args) + t.start() + try: + # wait a little longer for the server to initialize (it sometimes + # refuses connections on slow machines without this wait) + time.sleep(0.2) + + data = b"Suppose there isn't a 16-ton weight?" + d = dispatcherwithsend_noread() + d.create_socket() + d.connect((HOST, port)) + + # give time for socket to connect + time.sleep(0.1) + + d.send(data) + d.send(data) + d.send(b"\n") + + n = 1000 + while d.out_buffer and n > 0: # pragma: no cover + asyncore.poll() + n -= 1 + + evt.wait() + + self.assertEqual(cap.getvalue(), data * 2) + finally: + join_thread(t, timeout=TIMEOUT) + + +@unittest.skipUnless( + hasattr(asyncore, "file_wrapper"), "asyncore.file_wrapper required" +) +class FileWrapperTest(unittest.TestCase): + def setUp(self): + self.d = b"It's not dead, it's sleeping!" + with open(TESTFN, "wb") as file: + file.write(self.d) + + def tearDown(self): + unlink(TESTFN) + + def test_recv(self): + fd = os.open(TESTFN, os.O_RDONLY) + w = asyncore.file_wrapper(fd) + os.close(fd) + + self.assertNotEqual(w.fd, fd) + self.assertNotEqual(w.fileno(), fd) + self.assertEqual(w.recv(13), b"It's not dead") + self.assertEqual(w.read(6), b", it's") + w.close() + self.assertRaises(OSError, w.read, 1) + + def test_send(self): + d1 = b"Come again?" + d2 = b"I want to buy some cheese." + fd = os.open(TESTFN, os.O_WRONLY | os.O_APPEND) + w = asyncore.file_wrapper(fd) + os.close(fd) + + w.write(d1) + w.send(d2) + w.close() + with open(TESTFN, "rb") as file: + self.assertEqual(file.read(), self.d + d1 + d2) + + @unittest.skipUnless( + hasattr(asyncore, "file_dispatcher"), "asyncore.file_dispatcher required" + ) + def test_dispatcher(self): + fd = os.open(TESTFN, os.O_RDONLY) + data = [] + + class FileDispatcher(asyncore.file_dispatcher): + def handle_read(self): + data.append(self.recv(29)) + + FileDispatcher(fd) + os.close(fd) + asyncore.loop(timeout=0.01, use_poll=True, count=2) + self.assertEqual(b"".join(data), self.d) + + def test_resource_warning(self): + # Issue #11453 + got_warning = False + while got_warning is False: + # we try until we get the outcome we want because this + # test is not deterministic (gc_collect() may not + fd = os.open(TESTFN, os.O_RDONLY) + f = asyncore.file_wrapper(fd) + + os.close(fd) + + try: + with check_warnings(("", compat.ResourceWarning)): + f = None + gc_collect() + except AssertionError: # pragma: no cover + pass + else: + got_warning = True + + def test_close_twice(self): + fd = os.open(TESTFN, os.O_RDONLY) + f = asyncore.file_wrapper(fd) + os.close(fd) + + os.close(f.fd) # file_wrapper dupped fd + with self.assertRaises(OSError): + f.close() + + self.assertEqual(f.fd, -1) + # calling close twice should not fail + f.close() + + +class BaseTestHandler(asyncore.dispatcher): # pragma: no cover + def __init__(self, sock=None): + asyncore.dispatcher.__init__(self, sock) + self.flag = False + + def handle_accept(self): + raise Exception("handle_accept not supposed to be called") + + def handle_accepted(self): + raise Exception("handle_accepted not supposed to be called") + + def handle_connect(self): + raise Exception("handle_connect not supposed to be called") + + def handle_expt(self): + raise Exception("handle_expt not supposed to be called") + + def handle_close(self): + raise Exception("handle_close not supposed to be called") + + def handle_error(self): + raise + + +class BaseServer(asyncore.dispatcher): + """A server which listens on an address and dispatches the + connection to a handler. + """ + + def __init__(self, family, addr, handler=BaseTestHandler): + asyncore.dispatcher.__init__(self) + self.create_socket(family) + self.set_reuse_addr() + bind_af_aware(self.socket, addr) + self.listen(5) + self.handler = handler + + @property + def address(self): + return self.socket.getsockname() + + def handle_accepted(self, sock, addr): + self.handler(sock) + + def handle_error(self): # pragma: no cover + raise + + +class BaseClient(BaseTestHandler): + def __init__(self, family, address): + BaseTestHandler.__init__(self) + self.create_socket(family) + self.connect(address) + + def handle_connect(self): + pass + + +class BaseTestAPI: + def tearDown(self): + asyncore.close_all(ignore_all=True) + + def loop_waiting_for_flag(self, instance, timeout=5): # pragma: no cover + timeout = float(timeout) / 100 + count = 100 + while asyncore.socket_map and count > 0: + asyncore.loop(timeout=0.01, count=1, use_poll=self.use_poll) + if instance.flag: + return + count -= 1 + time.sleep(timeout) + self.fail("flag not set") + + def test_handle_connect(self): + # make sure handle_connect is called on connect() + + class TestClient(BaseClient): + def handle_connect(self): + self.flag = True + + server = BaseServer(self.family, self.addr) + client = TestClient(self.family, server.address) + self.loop_waiting_for_flag(client) + + def test_handle_accept(self): + # make sure handle_accept() is called when a client connects + + class TestListener(BaseTestHandler): + def __init__(self, family, addr): + BaseTestHandler.__init__(self) + self.create_socket(family) + bind_af_aware(self.socket, addr) + self.listen(5) + self.address = self.socket.getsockname() + + def handle_accept(self): + self.flag = True + + server = TestListener(self.family, self.addr) + client = BaseClient(self.family, server.address) + self.loop_waiting_for_flag(server) + + def test_handle_accepted(self): + # make sure handle_accepted() is called when a client connects + + class TestListener(BaseTestHandler): + def __init__(self, family, addr): + BaseTestHandler.__init__(self) + self.create_socket(family) + bind_af_aware(self.socket, addr) + self.listen(5) + self.address = self.socket.getsockname() + + def handle_accept(self): + asyncore.dispatcher.handle_accept(self) + + def handle_accepted(self, sock, addr): + sock.close() + self.flag = True + + server = TestListener(self.family, self.addr) + client = BaseClient(self.family, server.address) + self.loop_waiting_for_flag(server) + + def test_handle_read(self): + # make sure handle_read is called on data received + + class TestClient(BaseClient): + def handle_read(self): + self.flag = True + + class TestHandler(BaseTestHandler): + def __init__(self, conn): + BaseTestHandler.__init__(self, conn) + self.send(b"x" * 1024) + + server = BaseServer(self.family, self.addr, TestHandler) + client = TestClient(self.family, server.address) + self.loop_waiting_for_flag(client) + + def test_handle_write(self): + # make sure handle_write is called + + class TestClient(BaseClient): + def handle_write(self): + self.flag = True + + server = BaseServer(self.family, self.addr) + client = TestClient(self.family, server.address) + self.loop_waiting_for_flag(client) + + def test_handle_close(self): + # make sure handle_close is called when the other end closes + # the connection + + class TestClient(BaseClient): + def handle_read(self): + # in order to make handle_close be called we are supposed + # to make at least one recv() call + self.recv(1024) + + def handle_close(self): + self.flag = True + self.close() + + class TestHandler(BaseTestHandler): + def __init__(self, conn): + BaseTestHandler.__init__(self, conn) + self.close() + + server = BaseServer(self.family, self.addr, TestHandler) + client = TestClient(self.family, server.address) + self.loop_waiting_for_flag(client) + + def test_handle_close_after_conn_broken(self): + # Check that ECONNRESET/EPIPE is correctly handled (issues #5661 and + # #11265). + + data = b"\0" * 128 + + class TestClient(BaseClient): + def handle_write(self): + self.send(data) + + def handle_close(self): + self.flag = True + self.close() + + def handle_expt(self): # pragma: no cover + # needs to exist for MacOS testing + self.flag = True + self.close() + + class TestHandler(BaseTestHandler): + def handle_read(self): + self.recv(len(data)) + self.close() + + def writable(self): + return False + + server = BaseServer(self.family, self.addr, TestHandler) + client = TestClient(self.family, server.address) + self.loop_waiting_for_flag(client) + + @unittest.skipIf( + sys.platform.startswith("sunos"), "OOB support is broken on Solaris" + ) + def test_handle_expt(self): + # Make sure handle_expt is called on OOB data received. + # Note: this might fail on some platforms as OOB data is + # tenuously supported and rarely used. + if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX: + self.skipTest("Not applicable to AF_UNIX sockets.") + + if sys.platform == "darwin" and self.use_poll: # pragma: no cover + self.skipTest("poll may fail on macOS; see issue #28087") + + class TestClient(BaseClient): + def handle_expt(self): + self.socket.recv(1024, socket.MSG_OOB) + self.flag = True + + class TestHandler(BaseTestHandler): + def __init__(self, conn): + BaseTestHandler.__init__(self, conn) + self.socket.send(compat.tobytes(chr(244)), socket.MSG_OOB) + + server = BaseServer(self.family, self.addr, TestHandler) + client = TestClient(self.family, server.address) + self.loop_waiting_for_flag(client) + + def test_handle_error(self): + class TestClient(BaseClient): + def handle_write(self): + 1.0 / 0 + + def handle_error(self): + self.flag = True + try: + raise + except ZeroDivisionError: + pass + else: # pragma: no cover + raise Exception("exception not raised") + + server = BaseServer(self.family, self.addr) + client = TestClient(self.family, server.address) + self.loop_waiting_for_flag(client) + + def test_connection_attributes(self): + server = BaseServer(self.family, self.addr) + client = BaseClient(self.family, server.address) + + # we start disconnected + self.assertFalse(server.connected) + self.assertTrue(server.accepting) + # this can't be taken for granted across all platforms + # self.assertFalse(client.connected) + self.assertFalse(client.accepting) + + # execute some loops so that client connects to server + asyncore.loop(timeout=0.01, use_poll=self.use_poll, count=100) + self.assertFalse(server.connected) + self.assertTrue(server.accepting) + self.assertTrue(client.connected) + self.assertFalse(client.accepting) + + # disconnect the client + client.close() + self.assertFalse(server.connected) + self.assertTrue(server.accepting) + self.assertFalse(client.connected) + self.assertFalse(client.accepting) + + # stop serving + server.close() + self.assertFalse(server.connected) + self.assertFalse(server.accepting) + + def test_create_socket(self): + s = asyncore.dispatcher() + s.create_socket(self.family) + # self.assertEqual(s.socket.type, socket.SOCK_STREAM) + self.assertEqual(s.socket.family, self.family) + self.assertEqual(s.socket.gettimeout(), 0) + # self.assertFalse(s.socket.get_inheritable()) + + def test_bind(self): + if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX: + self.skipTest("Not applicable to AF_UNIX sockets.") + s1 = asyncore.dispatcher() + s1.create_socket(self.family) + s1.bind(self.addr) + s1.listen(5) + port = s1.socket.getsockname()[1] + + s2 = asyncore.dispatcher() + s2.create_socket(self.family) + # EADDRINUSE indicates the socket was correctly bound + self.assertRaises(socket.error, s2.bind, (self.addr[0], port)) + + def test_set_reuse_addr(self): # pragma: no cover + if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX: + self.skipTest("Not applicable to AF_UNIX sockets.") + + with closewrapper(socket.socket(self.family)) as sock: + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + except OSError: + unittest.skip("SO_REUSEADDR not supported on this platform") + else: + # if SO_REUSEADDR succeeded for sock we expect asyncore + # to do the same + s = asyncore.dispatcher(socket.socket(self.family)) + self.assertFalse( + s.socket.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) + ) + s.socket.close() + s.create_socket(self.family) + s.set_reuse_addr() + self.assertTrue( + s.socket.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) + ) + + @reap_threads + def test_quick_connect(self): # pragma: no cover + # see: http://bugs.python.org/issue10340 + if self.family not in (socket.AF_INET, getattr(socket, "AF_INET6", object())): + self.skipTest("test specific to AF_INET and AF_INET6") + + server = BaseServer(self.family, self.addr) + # run the thread 500 ms: the socket should be connected in 200 ms + t = threading.Thread(target=lambda: asyncore.loop(timeout=0.1, count=5)) + t.start() + try: + sock = socket.socket(self.family, socket.SOCK_STREAM) + with closewrapper(sock) as s: + s.settimeout(0.2) + s.setsockopt( + socket.SOL_SOCKET, socket.SO_LINGER, struct.pack("ii", 1, 0) + ) + + try: + s.connect(server.address) + except OSError: + pass + finally: + join_thread(t, timeout=TIMEOUT) + + +class TestAPI_UseIPv4Sockets(BaseTestAPI): + family = socket.AF_INET + addr = (HOST, 0) + + +@unittest.skipUnless(IPV6_ENABLED, "IPv6 support required") +class TestAPI_UseIPv6Sockets(BaseTestAPI): + family = socket.AF_INET6 + addr = (HOSTv6, 0) + + +@unittest.skipUnless(HAS_UNIX_SOCKETS, "Unix sockets required") +class TestAPI_UseUnixSockets(BaseTestAPI): + if HAS_UNIX_SOCKETS: + family = socket.AF_UNIX + addr = TESTFN + + def tearDown(self): + unlink(self.addr) + BaseTestAPI.tearDown(self) + + +class TestAPI_UseIPv4Select(TestAPI_UseIPv4Sockets, unittest.TestCase): + use_poll = False + + +@unittest.skipUnless(hasattr(select, "poll"), "select.poll required") +class TestAPI_UseIPv4Poll(TestAPI_UseIPv4Sockets, unittest.TestCase): + use_poll = True + + +class TestAPI_UseIPv6Select(TestAPI_UseIPv6Sockets, unittest.TestCase): + use_poll = False + + +@unittest.skipUnless(hasattr(select, "poll"), "select.poll required") +class TestAPI_UseIPv6Poll(TestAPI_UseIPv6Sockets, unittest.TestCase): + use_poll = True + + +class TestAPI_UseUnixSocketsSelect(TestAPI_UseUnixSockets, unittest.TestCase): + use_poll = False + + +@unittest.skipUnless(hasattr(select, "poll"), "select.poll required") +class TestAPI_UseUnixSocketsPoll(TestAPI_UseUnixSockets, unittest.TestCase): + use_poll = True + + +class Test__strerror(unittest.TestCase): + def _callFUT(self, err): + from waitress.wasyncore import _strerror + + return _strerror(err) + + def test_gardenpath(self): + self.assertEqual(self._callFUT(1), "Operation not permitted") + + def test_unknown(self): + self.assertEqual(self._callFUT("wut"), "Unknown error wut") + + +class Test_read(unittest.TestCase): + def _callFUT(self, dispatcher): + from waitress.wasyncore import read + + return read(dispatcher) + + def test_gardenpath(self): + inst = DummyDispatcher() + self._callFUT(inst) + self.assertTrue(inst.read_event_handled) + self.assertFalse(inst.error_handled) + + def test_reraised(self): + from waitress.wasyncore import ExitNow + + inst = DummyDispatcher(ExitNow) + self.assertRaises(ExitNow, self._callFUT, inst) + self.assertTrue(inst.read_event_handled) + self.assertFalse(inst.error_handled) + + def test_non_reraised(self): + inst = DummyDispatcher(OSError) + self._callFUT(inst) + self.assertTrue(inst.read_event_handled) + self.assertTrue(inst.error_handled) + + +class Test_write(unittest.TestCase): + def _callFUT(self, dispatcher): + from waitress.wasyncore import write + + return write(dispatcher) + + def test_gardenpath(self): + inst = DummyDispatcher() + self._callFUT(inst) + self.assertTrue(inst.write_event_handled) + self.assertFalse(inst.error_handled) + + def test_reraised(self): + from waitress.wasyncore import ExitNow + + inst = DummyDispatcher(ExitNow) + self.assertRaises(ExitNow, self._callFUT, inst) + self.assertTrue(inst.write_event_handled) + self.assertFalse(inst.error_handled) + + def test_non_reraised(self): + inst = DummyDispatcher(OSError) + self._callFUT(inst) + self.assertTrue(inst.write_event_handled) + self.assertTrue(inst.error_handled) + + +class Test__exception(unittest.TestCase): + def _callFUT(self, dispatcher): + from waitress.wasyncore import _exception + + return _exception(dispatcher) + + def test_gardenpath(self): + inst = DummyDispatcher() + self._callFUT(inst) + self.assertTrue(inst.expt_event_handled) + self.assertFalse(inst.error_handled) + + def test_reraised(self): + from waitress.wasyncore import ExitNow + + inst = DummyDispatcher(ExitNow) + self.assertRaises(ExitNow, self._callFUT, inst) + self.assertTrue(inst.expt_event_handled) + self.assertFalse(inst.error_handled) + + def test_non_reraised(self): + inst = DummyDispatcher(OSError) + self._callFUT(inst) + self.assertTrue(inst.expt_event_handled) + self.assertTrue(inst.error_handled) + + +@unittest.skipUnless(hasattr(select, "poll"), "select.poll required") +class Test_readwrite(unittest.TestCase): + def _callFUT(self, obj, flags): + from waitress.wasyncore import readwrite + + return readwrite(obj, flags) + + def test_handle_read_event(self): + flags = 0 + flags |= select.POLLIN + inst = DummyDispatcher() + self._callFUT(inst, flags) + self.assertTrue(inst.read_event_handled) + + def test_handle_write_event(self): + flags = 0 + flags |= select.POLLOUT + inst = DummyDispatcher() + self._callFUT(inst, flags) + self.assertTrue(inst.write_event_handled) + + def test_handle_expt_event(self): + flags = 0 + flags |= select.POLLPRI + inst = DummyDispatcher() + self._callFUT(inst, flags) + self.assertTrue(inst.expt_event_handled) + + def test_handle_close(self): + flags = 0 + flags |= select.POLLHUP + inst = DummyDispatcher() + self._callFUT(inst, flags) + self.assertTrue(inst.close_handled) + + def test_socketerror_not_in_disconnected(self): + flags = 0 + flags |= select.POLLIN + inst = DummyDispatcher(socket.error(errno.EALREADY, "EALREADY")) + self._callFUT(inst, flags) + self.assertTrue(inst.read_event_handled) + self.assertTrue(inst.error_handled) + + def test_socketerror_in_disconnected(self): + flags = 0 + flags |= select.POLLIN + inst = DummyDispatcher(socket.error(errno.ECONNRESET, "ECONNRESET")) + self._callFUT(inst, flags) + self.assertTrue(inst.read_event_handled) + self.assertTrue(inst.close_handled) + + def test_exception_in_reraised(self): + from waitress import wasyncore + + flags = 0 + flags |= select.POLLIN + inst = DummyDispatcher(wasyncore.ExitNow) + self.assertRaises(wasyncore.ExitNow, self._callFUT, inst, flags) + self.assertTrue(inst.read_event_handled) + + def test_exception_not_in_reraised(self): + flags = 0 + flags |= select.POLLIN + inst = DummyDispatcher(ValueError) + self._callFUT(inst, flags) + self.assertTrue(inst.error_handled) + + +class Test_poll(unittest.TestCase): + def _callFUT(self, timeout=0.0, map=None): + from waitress.wasyncore import poll + + return poll(timeout, map) + + def test_nothing_writable_nothing_readable_but_map_not_empty(self): + # i read the mock.patch docs. nerp. + dummy_time = DummyTime() + map = {0: DummyDispatcher()} + try: + from waitress import wasyncore + + old_time = wasyncore.time + wasyncore.time = dummy_time + result = self._callFUT(map=map) + finally: + wasyncore.time = old_time + self.assertEqual(result, None) + self.assertEqual(dummy_time.sleepvals, [0.0]) + + def test_select_raises_EINTR(self): + # i read the mock.patch docs. nerp. + dummy_select = DummySelect(select.error(errno.EINTR)) + disp = DummyDispatcher() + disp.readable = lambda: True + map = {0: disp} + try: + from waitress import wasyncore + + old_select = wasyncore.select + wasyncore.select = dummy_select + result = self._callFUT(map=map) + finally: + wasyncore.select = old_select + self.assertEqual(result, None) + self.assertEqual(dummy_select.selected, [([0], [], [0], 0.0)]) + + def test_select_raises_non_EINTR(self): + # i read the mock.patch docs. nerp. + dummy_select = DummySelect(select.error(errno.EBADF)) + disp = DummyDispatcher() + disp.readable = lambda: True + map = {0: disp} + try: + from waitress import wasyncore + + old_select = wasyncore.select + wasyncore.select = dummy_select + self.assertRaises(select.error, self._callFUT, map=map) + finally: + wasyncore.select = old_select + self.assertEqual(dummy_select.selected, [([0], [], [0], 0.0)]) + + +class Test_poll2(unittest.TestCase): + def _callFUT(self, timeout=0.0, map=None): + from waitress.wasyncore import poll2 + + return poll2(timeout, map) + + def test_select_raises_EINTR(self): + # i read the mock.patch docs. nerp. + pollster = DummyPollster(exc=select.error(errno.EINTR)) + dummy_select = DummySelect(pollster=pollster) + disp = DummyDispatcher() + map = {0: disp} + try: + from waitress import wasyncore + + old_select = wasyncore.select + wasyncore.select = dummy_select + self._callFUT(map=map) + finally: + wasyncore.select = old_select + self.assertEqual(pollster.polled, [0.0]) + + def test_select_raises_non_EINTR(self): + # i read the mock.patch docs. nerp. + pollster = DummyPollster(exc=select.error(errno.EBADF)) + dummy_select = DummySelect(pollster=pollster) + disp = DummyDispatcher() + map = {0: disp} + try: + from waitress import wasyncore + + old_select = wasyncore.select + wasyncore.select = dummy_select + self.assertRaises(select.error, self._callFUT, map=map) + finally: + wasyncore.select = old_select + self.assertEqual(pollster.polled, [0.0]) + + +class Test_dispatcher(unittest.TestCase): + def _makeOne(self, sock=None, map=None): + from waitress.wasyncore import dispatcher + + return dispatcher(sock=sock, map=map) + + def test_unexpected_getpeername_exc(self): + sock = dummysocket() + + def getpeername(): + raise socket.error(errno.EBADF) + + map = {} + sock.getpeername = getpeername + self.assertRaises(socket.error, self._makeOne, sock=sock, map=map) + self.assertEqual(map, {}) + + def test___repr__accepting(self): + sock = dummysocket() + map = {} + inst = self._makeOne(sock=sock, map=map) + inst.accepting = True + inst.addr = ("localhost", 8080) + result = repr(inst) + expected = "= 1) - - def test_ipv6_no_port(self): # pragma: nocover - if not self._hasIPv6(): - return - - inst = self._makeOne(listen="[::1]") - - bind_pairs = [sockaddr[:2] for (_, _, _, sockaddr) in inst.listen] - - self.assertEqual(bind_pairs, [("::1", 8080)]) - - def test_bad_port(self): - self.assertRaises(ValueError, self._makeOne, listen="127.0.0.1:test") - - def test_service_port(self): - if WIN and PY2: # pragma: no cover - # On Windows and Python 2 this is broken, so we raise a ValueError - self.assertRaises( - ValueError, self._makeOne, listen="127.0.0.1:http", - ) - return - - inst = self._makeOne(listen="127.0.0.1:http 0.0.0.0:https") - - bind_pairs = [sockaddr[:2] for (_, _, _, sockaddr) in inst.listen] - - self.assertEqual(bind_pairs, [("127.0.0.1", 80), ("0.0.0.0", 443)]) - - def test_dont_mix_host_port_listen(self): - self.assertRaises( - ValueError, - self._makeOne, - host="localhost", - port="8080", - listen="127.0.0.1:8080", - ) - - def test_good_sockets(self): - sockets = [ - socket.socket(socket.AF_INET6, socket.SOCK_STREAM), - socket.socket(socket.AF_INET, socket.SOCK_STREAM), - ] - inst = self._makeOne(sockets=sockets) - self.assertEqual(inst.sockets, sockets) - sockets[0].close() - sockets[1].close() - - def test_dont_mix_sockets_and_listen(self): - sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] - self.assertRaises( - ValueError, self._makeOne, listen="127.0.0.1:8080", sockets=sockets - ) - sockets[0].close() - - def test_dont_mix_sockets_and_host_port(self): - sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] - self.assertRaises( - ValueError, self._makeOne, host="localhost", port="8080", sockets=sockets - ) - sockets[0].close() - - def test_dont_mix_sockets_and_unix_socket(self): - sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] - self.assertRaises( - ValueError, self._makeOne, unix_socket="./tmp/test", sockets=sockets - ) - sockets[0].close() - - def test_dont_mix_unix_socket_and_host_port(self): - self.assertRaises( - ValueError, - self._makeOne, - unix_socket="./tmp/test", - host="localhost", - port="8080", - ) - - def test_dont_mix_unix_socket_and_listen(self): - self.assertRaises( - ValueError, self._makeOne, unix_socket="./tmp/test", listen="127.0.0.1:8080" - ) - - def test_dont_use_unsupported_socket_types(self): - sockets = [socket.socket(socket.AF_INET, socket.SOCK_DGRAM)] - self.assertRaises(ValueError, self._makeOne, sockets=sockets) - sockets[0].close() - - def test_dont_mix_forwarded_with_x_forwarded(self): - with self.assertRaises(ValueError) as cm: - self._makeOne( - trusted_proxy="localhost", - trusted_proxy_headers={"forwarded", "x-forwarded-for"}, - ) - - self.assertIn("The Forwarded proxy header", str(cm.exception)) - - def test_unknown_trusted_proxy_header(self): - with self.assertRaises(ValueError) as cm: - self._makeOne( - trusted_proxy="localhost", - trusted_proxy_headers={"forwarded", "x-forwarded-unknown"}, - ) - - self.assertIn( - "unknown trusted_proxy_headers value (x-forwarded-unknown)", - str(cm.exception), - ) - - def test_trusted_proxy_count_no_trusted_proxy(self): - with self.assertRaises(ValueError) as cm: - self._makeOne(trusted_proxy_count=1) - - self.assertIn("trusted_proxy_count has no meaning", str(cm.exception)) - - def test_trusted_proxy_headers_no_trusted_proxy(self): - with self.assertRaises(ValueError) as cm: - self._makeOne(trusted_proxy_headers={"forwarded"}) - - self.assertIn("trusted_proxy_headers has no meaning", str(cm.exception)) - - def test_trusted_proxy_headers_string_list(self): - inst = self._makeOne( - trusted_proxy="localhost", - trusted_proxy_headers="x-forwarded-for x-forwarded-by", - ) - self.assertEqual( - inst.trusted_proxy_headers, {"x-forwarded-for", "x-forwarded-by"} - ) - - def test_trusted_proxy_headers_string_list_newlines(self): - inst = self._makeOne( - trusted_proxy="localhost", - trusted_proxy_headers="x-forwarded-for\nx-forwarded-by\nx-forwarded-host", - ) - self.assertEqual( - inst.trusted_proxy_headers, - {"x-forwarded-for", "x-forwarded-by", "x-forwarded-host"}, - ) - - def test_no_trusted_proxy_headers_trusted_proxy(self): - with warnings.catch_warnings(record=True) as w: - warnings.resetwarnings() - warnings.simplefilter("always") - self._makeOne(trusted_proxy="localhost") - - self.assertGreaterEqual(len(w), 1) - self.assertTrue(issubclass(w[0].category, DeprecationWarning)) - self.assertIn("Implicitly trusting X-Forwarded-Proto", str(w[0])) - - def test_clear_untrusted_proxy_headers(self): - with warnings.catch_warnings(record=True) as w: - warnings.resetwarnings() - warnings.simplefilter("always") - self._makeOne( - trusted_proxy="localhost", trusted_proxy_headers={"x-forwarded-for"} - ) - - self.assertGreaterEqual(len(w), 1) - self.assertTrue(issubclass(w[0].category, DeprecationWarning)) - self.assertIn( - "clear_untrusted_proxy_headers will be set to True", str(w[0]) - ) - - def test_deprecated_send_bytes(self): - with warnings.catch_warnings(record=True) as w: - warnings.resetwarnings() - warnings.simplefilter("always") - self._makeOne(send_bytes=1) - - self.assertGreaterEqual(len(w), 1) - self.assertTrue(issubclass(w[0].category, DeprecationWarning)) - self.assertIn("send_bytes", str(w[0])) - - def test_badvar(self): - self.assertRaises(ValueError, self._makeOne, nope=True) - - def test_ipv4_disabled(self): - self.assertRaises( - ValueError, self._makeOne, ipv4=False, listen="127.0.0.1:8080" - ) - - def test_ipv6_disabled(self): - self.assertRaises(ValueError, self._makeOne, ipv6=False, listen="[::]:8080") - - def test_server_header_removable(self): - inst = self._makeOne(ident=None) - self.assertEqual(inst.ident, None) - - inst = self._makeOne(ident="") - self.assertEqual(inst.ident, None) - - inst = self._makeOne(ident="specific_header") - self.assertEqual(inst.ident, "specific_header") - - -class TestCLI(unittest.TestCase): - def parse(self, argv): - from waitress.adjustments import Adjustments - - return Adjustments.parse_args(argv) - - def test_noargs(self): - opts, args = self.parse([]) - self.assertDictEqual(opts, {"call": False, "help": False}) - self.assertSequenceEqual(args, []) - - def test_help(self): - opts, args = self.parse(["--help"]) - self.assertDictEqual(opts, {"call": False, "help": True}) - self.assertSequenceEqual(args, []) - - def test_call(self): - opts, args = self.parse(["--call"]) - self.assertDictEqual(opts, {"call": True, "help": False}) - self.assertSequenceEqual(args, []) - - def test_both(self): - opts, args = self.parse(["--call", "--help"]) - self.assertDictEqual(opts, {"call": True, "help": True}) - self.assertSequenceEqual(args, []) - - def test_positive_boolean(self): - opts, args = self.parse(["--expose-tracebacks"]) - self.assertDictContainsSubset({"expose_tracebacks": "true"}, opts) - self.assertSequenceEqual(args, []) - - def test_negative_boolean(self): - opts, args = self.parse(["--no-expose-tracebacks"]) - self.assertDictContainsSubset({"expose_tracebacks": "false"}, opts) - self.assertSequenceEqual(args, []) - - def test_cast_params(self): - opts, args = self.parse( - ["--host=localhost", "--port=80", "--unix-socket-perms=777"] - ) - self.assertDictContainsSubset( - {"host": "localhost", "port": "80", "unix_socket_perms": "777",}, opts - ) - self.assertSequenceEqual(args, []) - - def test_listen_params(self): - opts, args = self.parse(["--listen=test:80",]) - - self.assertDictContainsSubset({"listen": " test:80"}, opts) - self.assertSequenceEqual(args, []) - - def test_multiple_listen_params(self): - opts, args = self.parse(["--listen=test:80", "--listen=test:8080",]) - - self.assertDictContainsSubset({"listen": " test:80 test:8080"}, opts) - self.assertSequenceEqual(args, []) - - def test_bad_param(self): - import getopt - - self.assertRaises(getopt.GetoptError, self.parse, ["--no-host"]) - - -if hasattr(socket, "AF_UNIX"): - - class TestUnixSocket(unittest.TestCase): - def _makeOne(self, **kw): - from waitress.adjustments import Adjustments - - return Adjustments(**kw) - - def test_dont_mix_internet_and_unix_sockets(self): - sockets = [ - socket.socket(socket.AF_INET, socket.SOCK_STREAM), - socket.socket(socket.AF_UNIX, socket.SOCK_STREAM), - ] - self.assertRaises(ValueError, self._makeOne, sockets=sockets) - sockets[0].close() - sockets[1].close() diff --git a/waitress/tests/test_buffers.py b/waitress/tests/test_buffers.py deleted file mode 100644 index a1330ac..0000000 --- a/waitress/tests/test_buffers.py +++ /dev/null @@ -1,523 +0,0 @@ -import unittest -import io - - -class TestFileBasedBuffer(unittest.TestCase): - def _makeOne(self, file=None, from_buffer=None): - from waitress.buffers import FileBasedBuffer - - buf = FileBasedBuffer(file, from_buffer=from_buffer) - self.buffers_to_close.append(buf) - return buf - - def setUp(self): - self.buffers_to_close = [] - - def tearDown(self): - for buf in self.buffers_to_close: - buf.close() - - def test_ctor_from_buffer_None(self): - inst = self._makeOne("file") - self.assertEqual(inst.file, "file") - - def test_ctor_from_buffer(self): - from_buffer = io.BytesIO(b"data") - from_buffer.getfile = lambda *x: from_buffer - f = io.BytesIO() - inst = self._makeOne(f, from_buffer) - self.assertEqual(inst.file, f) - del from_buffer.getfile - self.assertEqual(inst.remain, 4) - from_buffer.close() - - def test___len__(self): - inst = self._makeOne() - inst.remain = 10 - self.assertEqual(len(inst), 10) - - def test___nonzero__(self): - inst = self._makeOne() - inst.remain = 10 - self.assertEqual(bool(inst), True) - inst.remain = 0 - self.assertEqual(bool(inst), True) - - def test_append(self): - f = io.BytesIO(b"data") - inst = self._makeOne(f) - inst.append(b"data2") - self.assertEqual(f.getvalue(), b"datadata2") - self.assertEqual(inst.remain, 5) - - def test_get_skip_true(self): - f = io.BytesIO(b"data") - inst = self._makeOne(f) - result = inst.get(100, skip=True) - self.assertEqual(result, b"data") - self.assertEqual(inst.remain, -4) - - def test_get_skip_false(self): - f = io.BytesIO(b"data") - inst = self._makeOne(f) - result = inst.get(100, skip=False) - self.assertEqual(result, b"data") - self.assertEqual(inst.remain, 0) - - def test_get_skip_bytes_less_than_zero(self): - f = io.BytesIO(b"data") - inst = self._makeOne(f) - result = inst.get(-1, skip=False) - self.assertEqual(result, b"data") - self.assertEqual(inst.remain, 0) - - def test_skip_remain_gt_bytes(self): - f = io.BytesIO(b"d") - inst = self._makeOne(f) - inst.remain = 1 - inst.skip(1) - self.assertEqual(inst.remain, 0) - - def test_skip_remain_lt_bytes(self): - f = io.BytesIO(b"d") - inst = self._makeOne(f) - inst.remain = 1 - self.assertRaises(ValueError, inst.skip, 2) - - def test_newfile(self): - inst = self._makeOne() - self.assertRaises(NotImplementedError, inst.newfile) - - def test_prune_remain_notzero(self): - f = io.BytesIO(b"d") - inst = self._makeOne(f) - inst.remain = 1 - nf = io.BytesIO() - inst.newfile = lambda *x: nf - inst.prune() - self.assertTrue(inst.file is not f) - self.assertEqual(nf.getvalue(), b"d") - - def test_prune_remain_zero_tell_notzero(self): - f = io.BytesIO(b"d") - inst = self._makeOne(f) - nf = io.BytesIO(b"d") - inst.newfile = lambda *x: nf - inst.remain = 0 - inst.prune() - self.assertTrue(inst.file is not f) - self.assertEqual(nf.getvalue(), b"d") - - def test_prune_remain_zero_tell_zero(self): - f = io.BytesIO() - inst = self._makeOne(f) - inst.remain = 0 - inst.prune() - self.assertTrue(inst.file is f) - - def test_close(self): - f = io.BytesIO() - inst = self._makeOne(f) - inst.close() - self.assertTrue(f.closed) - self.buffers_to_close.remove(inst) - - -class TestTempfileBasedBuffer(unittest.TestCase): - def _makeOne(self, from_buffer=None): - from waitress.buffers import TempfileBasedBuffer - - buf = TempfileBasedBuffer(from_buffer=from_buffer) - self.buffers_to_close.append(buf) - return buf - - def setUp(self): - self.buffers_to_close = [] - - def tearDown(self): - for buf in self.buffers_to_close: - buf.close() - - def test_newfile(self): - inst = self._makeOne() - r = inst.newfile() - self.assertTrue(hasattr(r, "fileno")) # file - r.close() - - -class TestBytesIOBasedBuffer(unittest.TestCase): - def _makeOne(self, from_buffer=None): - from waitress.buffers import BytesIOBasedBuffer - - return BytesIOBasedBuffer(from_buffer=from_buffer) - - def test_ctor_from_buffer_not_None(self): - f = io.BytesIO() - f.getfile = lambda *x: f - inst = self._makeOne(f) - self.assertTrue(hasattr(inst.file, "read")) - - def test_ctor_from_buffer_None(self): - inst = self._makeOne() - self.assertTrue(hasattr(inst.file, "read")) - - def test_newfile(self): - inst = self._makeOne() - r = inst.newfile() - self.assertTrue(hasattr(r, "read")) - - -class TestReadOnlyFileBasedBuffer(unittest.TestCase): - def _makeOne(self, file, block_size=8192): - from waitress.buffers import ReadOnlyFileBasedBuffer - - buf = ReadOnlyFileBasedBuffer(file, block_size) - self.buffers_to_close.append(buf) - return buf - - def setUp(self): - self.buffers_to_close = [] - - def tearDown(self): - for buf in self.buffers_to_close: - buf.close() - - def test_prepare_not_seekable(self): - f = KindaFilelike(b"abc") - inst = self._makeOne(f) - result = inst.prepare() - self.assertEqual(result, False) - self.assertEqual(inst.remain, 0) - - def test_prepare_not_seekable_closeable(self): - f = KindaFilelike(b"abc", close=1) - inst = self._makeOne(f) - result = inst.prepare() - self.assertEqual(result, False) - self.assertEqual(inst.remain, 0) - self.assertTrue(hasattr(inst, "close")) - - def test_prepare_seekable_closeable(self): - f = Filelike(b"abc", close=1, tellresults=[0, 10]) - inst = self._makeOne(f) - result = inst.prepare() - self.assertEqual(result, 10) - self.assertEqual(inst.remain, 10) - self.assertEqual(inst.file.seeked, 0) - self.assertTrue(hasattr(inst, "close")) - - def test_get_numbytes_neg_one(self): - f = io.BytesIO(b"abcdef") - inst = self._makeOne(f) - inst.remain = 2 - result = inst.get(-1) - self.assertEqual(result, b"ab") - self.assertEqual(inst.remain, 2) - self.assertEqual(f.tell(), 0) - - def test_get_numbytes_gt_remain(self): - f = io.BytesIO(b"abcdef") - inst = self._makeOne(f) - inst.remain = 2 - result = inst.get(3) - self.assertEqual(result, b"ab") - self.assertEqual(inst.remain, 2) - self.assertEqual(f.tell(), 0) - - def test_get_numbytes_lt_remain(self): - f = io.BytesIO(b"abcdef") - inst = self._makeOne(f) - inst.remain = 2 - result = inst.get(1) - self.assertEqual(result, b"a") - self.assertEqual(inst.remain, 2) - self.assertEqual(f.tell(), 0) - - def test_get_numbytes_gt_remain_withskip(self): - f = io.BytesIO(b"abcdef") - inst = self._makeOne(f) - inst.remain = 2 - result = inst.get(3, skip=True) - self.assertEqual(result, b"ab") - self.assertEqual(inst.remain, 0) - self.assertEqual(f.tell(), 2) - - def test_get_numbytes_lt_remain_withskip(self): - f = io.BytesIO(b"abcdef") - inst = self._makeOne(f) - inst.remain = 2 - result = inst.get(1, skip=True) - self.assertEqual(result, b"a") - self.assertEqual(inst.remain, 1) - self.assertEqual(f.tell(), 1) - - def test___iter__(self): - data = b"a" * 10000 - f = io.BytesIO(data) - inst = self._makeOne(f) - r = b"" - for val in inst: - r += val - self.assertEqual(r, data) - - def test_append(self): - inst = self._makeOne(None) - self.assertRaises(NotImplementedError, inst.append, "a") - - -class TestOverflowableBuffer(unittest.TestCase): - def _makeOne(self, overflow=10): - from waitress.buffers import OverflowableBuffer - - buf = OverflowableBuffer(overflow) - self.buffers_to_close.append(buf) - return buf - - def setUp(self): - self.buffers_to_close = [] - - def tearDown(self): - for buf in self.buffers_to_close: - buf.close() - - def test___len__buf_is_None(self): - inst = self._makeOne() - self.assertEqual(len(inst), 0) - - def test___len__buf_is_not_None(self): - inst = self._makeOne() - inst.buf = b"abc" - self.assertEqual(len(inst), 3) - self.buffers_to_close.remove(inst) - - def test___nonzero__(self): - inst = self._makeOne() - inst.buf = b"abc" - self.assertEqual(bool(inst), True) - inst.buf = b"" - self.assertEqual(bool(inst), False) - self.buffers_to_close.remove(inst) - - def test___nonzero___on_int_overflow_buffer(self): - inst = self._makeOne() - - class int_overflow_buf(bytes): - def __len__(self): - # maxint + 1 - return 0x7FFFFFFFFFFFFFFF + 1 - - inst.buf = int_overflow_buf() - self.assertEqual(bool(inst), True) - inst.buf = b"" - self.assertEqual(bool(inst), False) - self.buffers_to_close.remove(inst) - - def test__create_buffer_large(self): - from waitress.buffers import TempfileBasedBuffer - - inst = self._makeOne() - inst.strbuf = b"x" * 11 - inst._create_buffer() - self.assertEqual(inst.buf.__class__, TempfileBasedBuffer) - self.assertEqual(inst.buf.get(100), b"x" * 11) - self.assertEqual(inst.strbuf, b"") - - def test__create_buffer_small(self): - from waitress.buffers import BytesIOBasedBuffer - - inst = self._makeOne() - inst.strbuf = b"x" * 5 - inst._create_buffer() - self.assertEqual(inst.buf.__class__, BytesIOBasedBuffer) - self.assertEqual(inst.buf.get(100), b"x" * 5) - self.assertEqual(inst.strbuf, b"") - - def test_append_with_len_more_than_max_int(self): - from waitress.compat import MAXINT - - inst = self._makeOne() - inst.overflowed = True - buf = DummyBuffer(length=MAXINT) - inst.buf = buf - result = inst.append(b"x") - # we don't want this to throw an OverflowError on Python 2 (see - # https://github.com/Pylons/waitress/issues/47) - self.assertEqual(result, None) - self.buffers_to_close.remove(inst) - - def test_append_buf_None_not_longer_than_srtbuf_limit(self): - inst = self._makeOne() - inst.strbuf = b"x" * 5 - inst.append(b"hello") - self.assertEqual(inst.strbuf, b"xxxxxhello") - - def test_append_buf_None_longer_than_strbuf_limit(self): - inst = self._makeOne(10000) - inst.strbuf = b"x" * 8192 - inst.append(b"hello") - self.assertEqual(inst.strbuf, b"") - self.assertEqual(len(inst.buf), 8197) - - def test_append_overflow(self): - inst = self._makeOne(10) - inst.strbuf = b"x" * 8192 - inst.append(b"hello") - self.assertEqual(inst.strbuf, b"") - self.assertEqual(len(inst.buf), 8197) - - def test_append_sz_gt_overflow(self): - from waitress.buffers import BytesIOBasedBuffer - - f = io.BytesIO(b"data") - inst = self._makeOne(f) - buf = BytesIOBasedBuffer() - inst.buf = buf - inst.overflow = 2 - inst.append(b"data2") - self.assertEqual(f.getvalue(), b"data") - self.assertTrue(inst.overflowed) - self.assertNotEqual(inst.buf, buf) - - def test_get_buf_None_skip_False(self): - inst = self._makeOne() - inst.strbuf = b"x" * 5 - r = inst.get(5) - self.assertEqual(r, b"xxxxx") - - def test_get_buf_None_skip_True(self): - inst = self._makeOne() - inst.strbuf = b"x" * 5 - r = inst.get(5, skip=True) - self.assertFalse(inst.buf is None) - self.assertEqual(r, b"xxxxx") - - def test_skip_buf_None(self): - inst = self._makeOne() - inst.strbuf = b"data" - inst.skip(4) - self.assertEqual(inst.strbuf, b"") - self.assertNotEqual(inst.buf, None) - - def test_skip_buf_None_allow_prune_True(self): - inst = self._makeOne() - inst.strbuf = b"data" - inst.skip(4, True) - self.assertEqual(inst.strbuf, b"") - self.assertEqual(inst.buf, None) - - def test_prune_buf_None(self): - inst = self._makeOne() - inst.prune() - self.assertEqual(inst.strbuf, b"") - - def test_prune_with_buf(self): - inst = self._makeOne() - - class Buf(object): - def prune(self): - self.pruned = True - - inst.buf = Buf() - inst.prune() - self.assertEqual(inst.buf.pruned, True) - self.buffers_to_close.remove(inst) - - def test_prune_with_buf_overflow(self): - inst = self._makeOne() - - class DummyBuffer(io.BytesIO): - def getfile(self): - return self - - def prune(self): - return True - - def __len__(self): - return 5 - - def close(self): - pass - - buf = DummyBuffer(b"data") - inst.buf = buf - inst.overflowed = True - inst.overflow = 10 - inst.prune() - self.assertNotEqual(inst.buf, buf) - - def test_prune_with_buflen_more_than_max_int(self): - from waitress.compat import MAXINT - - inst = self._makeOne() - inst.overflowed = True - buf = DummyBuffer(length=MAXINT + 1) - inst.buf = buf - result = inst.prune() - # we don't want this to throw an OverflowError on Python 2 (see - # https://github.com/Pylons/waitress/issues/47) - self.assertEqual(result, None) - - def test_getfile_buf_None(self): - inst = self._makeOne() - f = inst.getfile() - self.assertTrue(hasattr(f, "read")) - - def test_getfile_buf_not_None(self): - inst = self._makeOne() - buf = io.BytesIO() - buf.getfile = lambda *x: buf - inst.buf = buf - f = inst.getfile() - self.assertEqual(f, buf) - - def test_close_nobuf(self): - inst = self._makeOne() - inst.buf = None - self.assertEqual(inst.close(), None) # doesnt raise - self.buffers_to_close.remove(inst) - - def test_close_withbuf(self): - class Buffer(object): - def close(self): - self.closed = True - - buf = Buffer() - inst = self._makeOne() - inst.buf = buf - inst.close() - self.assertTrue(buf.closed) - self.buffers_to_close.remove(inst) - - -class KindaFilelike(object): - def __init__(self, bytes, close=None, tellresults=None): - self.bytes = bytes - self.tellresults = tellresults - if close is not None: - self.close = lambda: close - - -class Filelike(KindaFilelike): - def seek(self, v, whence=0): - self.seeked = v - - def tell(self): - v = self.tellresults.pop(0) - return v - - -class DummyBuffer(object): - def __init__(self, length=0): - self.length = length - - def __len__(self): - return self.length - - def append(self, s): - self.length = self.length + len(s) - - def prune(self): - pass - - def close(self): - pass diff --git a/waitress/tests/test_channel.py b/waitress/tests/test_channel.py deleted file mode 100644 index 14ef5a0..0000000 --- a/waitress/tests/test_channel.py +++ /dev/null @@ -1,882 +0,0 @@ -import unittest -import io - - -class TestHTTPChannel(unittest.TestCase): - def _makeOne(self, sock, addr, adj, map=None): - from waitress.channel import HTTPChannel - - server = DummyServer() - return HTTPChannel(server, sock, addr, adj=adj, map=map) - - def _makeOneWithMap(self, adj=None): - if adj is None: - adj = DummyAdjustments() - sock = DummySock() - map = {} - inst = self._makeOne(sock, "127.0.0.1", adj, map=map) - inst.outbuf_lock = DummyLock() - return inst, sock, map - - def test_ctor(self): - inst, _, map = self._makeOneWithMap() - self.assertEqual(inst.addr, "127.0.0.1") - self.assertEqual(inst.sendbuf_len, 2048) - self.assertEqual(map[100], inst) - - def test_total_outbufs_len_an_outbuf_size_gt_sys_maxint(self): - from waitress.compat import MAXINT - - inst, _, map = self._makeOneWithMap() - - class DummyBuffer(object): - chunks = [] - - def append(self, data): - self.chunks.append(data) - - class DummyData(object): - def __len__(self): - return MAXINT - - inst.total_outbufs_len = 1 - inst.outbufs = [DummyBuffer()] - inst.write_soon(DummyData()) - # we are testing that this method does not raise an OverflowError - # (see https://github.com/Pylons/waitress/issues/47) - self.assertEqual(inst.total_outbufs_len, MAXINT + 1) - - def test_writable_something_in_outbuf(self): - inst, sock, map = self._makeOneWithMap() - inst.total_outbufs_len = 3 - self.assertTrue(inst.writable()) - - def test_writable_nothing_in_outbuf(self): - inst, sock, map = self._makeOneWithMap() - self.assertFalse(inst.writable()) - - def test_writable_nothing_in_outbuf_will_close(self): - inst, sock, map = self._makeOneWithMap() - inst.will_close = True - self.assertTrue(inst.writable()) - - def test_handle_write_not_connected(self): - inst, sock, map = self._makeOneWithMap() - inst.connected = False - self.assertFalse(inst.handle_write()) - - def test_handle_write_with_requests(self): - inst, sock, map = self._makeOneWithMap() - inst.requests = True - inst.last_activity = 0 - result = inst.handle_write() - self.assertEqual(result, None) - self.assertEqual(inst.last_activity, 0) - - def test_handle_write_no_request_with_outbuf(self): - inst, sock, map = self._makeOneWithMap() - inst.requests = [] - inst.outbufs = [DummyBuffer(b"abc")] - inst.total_outbufs_len = len(inst.outbufs[0]) - inst.last_activity = 0 - result = inst.handle_write() - self.assertEqual(result, None) - self.assertNotEqual(inst.last_activity, 0) - self.assertEqual(sock.sent, b"abc") - - def test_handle_write_outbuf_raises_socketerror(self): - import socket - - inst, sock, map = self._makeOneWithMap() - inst.requests = [] - outbuf = DummyBuffer(b"abc", socket.error) - inst.outbufs = [outbuf] - inst.total_outbufs_len = len(outbuf) - inst.last_activity = 0 - inst.logger = DummyLogger() - result = inst.handle_write() - self.assertEqual(result, None) - self.assertEqual(inst.last_activity, 0) - self.assertEqual(sock.sent, b"") - self.assertEqual(len(inst.logger.exceptions), 1) - self.assertTrue(outbuf.closed) - - def test_handle_write_outbuf_raises_othererror(self): - inst, sock, map = self._makeOneWithMap() - inst.requests = [] - outbuf = DummyBuffer(b"abc", IOError) - inst.outbufs = [outbuf] - inst.total_outbufs_len = len(outbuf) - inst.last_activity = 0 - inst.logger = DummyLogger() - result = inst.handle_write() - self.assertEqual(result, None) - self.assertEqual(inst.last_activity, 0) - self.assertEqual(sock.sent, b"") - self.assertEqual(len(inst.logger.exceptions), 1) - self.assertTrue(outbuf.closed) - - def test_handle_write_no_requests_no_outbuf_will_close(self): - inst, sock, map = self._makeOneWithMap() - inst.requests = [] - outbuf = DummyBuffer(b"") - inst.outbufs = [outbuf] - inst.will_close = True - inst.last_activity = 0 - result = inst.handle_write() - self.assertEqual(result, None) - self.assertEqual(inst.connected, False) - self.assertEqual(sock.closed, True) - self.assertEqual(inst.last_activity, 0) - self.assertTrue(outbuf.closed) - - def test_handle_write_no_requests_outbuf_gt_send_bytes(self): - inst, sock, map = self._makeOneWithMap() - inst.requests = [True] - inst.outbufs = [DummyBuffer(b"abc")] - inst.total_outbufs_len = len(inst.outbufs[0]) - inst.adj.send_bytes = 2 - inst.will_close = False - inst.last_activity = 0 - result = inst.handle_write() - self.assertEqual(result, None) - self.assertEqual(inst.will_close, False) - self.assertTrue(inst.outbuf_lock.acquired) - self.assertEqual(sock.sent, b"abc") - - def test_handle_write_close_when_flushed(self): - inst, sock, map = self._makeOneWithMap() - outbuf = DummyBuffer(b"abc") - inst.outbufs = [outbuf] - inst.total_outbufs_len = len(outbuf) - inst.will_close = False - inst.close_when_flushed = True - inst.last_activity = 0 - result = inst.handle_write() - self.assertEqual(result, None) - self.assertEqual(inst.will_close, True) - self.assertEqual(inst.close_when_flushed, False) - self.assertEqual(sock.sent, b"abc") - self.assertTrue(outbuf.closed) - - def test_readable_no_requests_not_will_close(self): - inst, sock, map = self._makeOneWithMap() - inst.requests = [] - inst.will_close = False - self.assertEqual(inst.readable(), True) - - def test_readable_no_requests_will_close(self): - inst, sock, map = self._makeOneWithMap() - inst.requests = [] - inst.will_close = True - self.assertEqual(inst.readable(), False) - - def test_readable_with_requests(self): - inst, sock, map = self._makeOneWithMap() - inst.requests = True - self.assertEqual(inst.readable(), False) - - def test_handle_read_no_error(self): - inst, sock, map = self._makeOneWithMap() - inst.will_close = False - inst.recv = lambda *arg: b"abc" - inst.last_activity = 0 - L = [] - inst.received = lambda x: L.append(x) - result = inst.handle_read() - self.assertEqual(result, None) - self.assertNotEqual(inst.last_activity, 0) - self.assertEqual(L, [b"abc"]) - - def test_handle_read_error(self): - import socket - - inst, sock, map = self._makeOneWithMap() - inst.will_close = False - - def recv(b): - raise socket.error - - inst.recv = recv - inst.last_activity = 0 - inst.logger = DummyLogger() - result = inst.handle_read() - self.assertEqual(result, None) - self.assertEqual(inst.last_activity, 0) - self.assertEqual(len(inst.logger.exceptions), 1) - - def test_write_soon_empty_byte(self): - inst, sock, map = self._makeOneWithMap() - wrote = inst.write_soon(b"") - self.assertEqual(wrote, 0) - self.assertEqual(len(inst.outbufs[0]), 0) - - def test_write_soon_nonempty_byte(self): - inst, sock, map = self._makeOneWithMap() - wrote = inst.write_soon(b"a") - self.assertEqual(wrote, 1) - self.assertEqual(len(inst.outbufs[0]), 1) - - def test_write_soon_filewrapper(self): - from waitress.buffers import ReadOnlyFileBasedBuffer - - f = io.BytesIO(b"abc") - wrapper = ReadOnlyFileBasedBuffer(f, 8192) - wrapper.prepare() - inst, sock, map = self._makeOneWithMap() - outbufs = inst.outbufs - orig_outbuf = outbufs[0] - wrote = inst.write_soon(wrapper) - self.assertEqual(wrote, 3) - self.assertEqual(len(outbufs), 3) - self.assertEqual(outbufs[0], orig_outbuf) - self.assertEqual(outbufs[1], wrapper) - self.assertEqual(outbufs[2].__class__.__name__, "OverflowableBuffer") - - def test_write_soon_disconnected(self): - from waitress.channel import ClientDisconnected - - inst, sock, map = self._makeOneWithMap() - inst.connected = False - self.assertRaises(ClientDisconnected, lambda: inst.write_soon(b"stuff")) - - def test_write_soon_disconnected_while_over_watermark(self): - from waitress.channel import ClientDisconnected - - inst, sock, map = self._makeOneWithMap() - - def dummy_flush(): - inst.connected = False - - inst._flush_outbufs_below_high_watermark = dummy_flush - self.assertRaises(ClientDisconnected, lambda: inst.write_soon(b"stuff")) - - def test_write_soon_rotates_outbuf_on_overflow(self): - inst, sock, map = self._makeOneWithMap() - inst.adj.outbuf_high_watermark = 3 - inst.current_outbuf_count = 4 - wrote = inst.write_soon(b"xyz") - self.assertEqual(wrote, 3) - self.assertEqual(len(inst.outbufs), 2) - self.assertEqual(inst.outbufs[0].get(), b"") - self.assertEqual(inst.outbufs[1].get(), b"xyz") - - def test_write_soon_waits_on_backpressure(self): - inst, sock, map = self._makeOneWithMap() - inst.adj.outbuf_high_watermark = 3 - inst.total_outbufs_len = 4 - inst.current_outbuf_count = 4 - - class Lock(DummyLock): - def wait(self): - inst.total_outbufs_len = 0 - super(Lock, self).wait() - - inst.outbuf_lock = Lock() - wrote = inst.write_soon(b"xyz") - self.assertEqual(wrote, 3) - self.assertEqual(len(inst.outbufs), 2) - self.assertEqual(inst.outbufs[0].get(), b"") - self.assertEqual(inst.outbufs[1].get(), b"xyz") - self.assertTrue(inst.outbuf_lock.waited) - - def test_handle_write_notify_after_flush(self): - inst, sock, map = self._makeOneWithMap() - inst.requests = [True] - inst.outbufs = [DummyBuffer(b"abc")] - inst.total_outbufs_len = len(inst.outbufs[0]) - inst.adj.send_bytes = 1 - inst.adj.outbuf_high_watermark = 5 - inst.will_close = False - inst.last_activity = 0 - result = inst.handle_write() - self.assertEqual(result, None) - self.assertEqual(inst.will_close, False) - self.assertTrue(inst.outbuf_lock.acquired) - self.assertTrue(inst.outbuf_lock.notified) - self.assertEqual(sock.sent, b"abc") - - def test_handle_write_no_notify_after_flush(self): - inst, sock, map = self._makeOneWithMap() - inst.requests = [True] - inst.outbufs = [DummyBuffer(b"abc")] - inst.total_outbufs_len = len(inst.outbufs[0]) - inst.adj.send_bytes = 1 - inst.adj.outbuf_high_watermark = 2 - sock.send = lambda x: False - inst.will_close = False - inst.last_activity = 0 - result = inst.handle_write() - self.assertEqual(result, None) - self.assertEqual(inst.will_close, False) - self.assertTrue(inst.outbuf_lock.acquired) - self.assertFalse(inst.outbuf_lock.notified) - self.assertEqual(sock.sent, b"") - - def test__flush_some_empty_outbuf(self): - inst, sock, map = self._makeOneWithMap() - result = inst._flush_some() - self.assertEqual(result, False) - - def test__flush_some_full_outbuf_socket_returns_nonzero(self): - inst, sock, map = self._makeOneWithMap() - inst.outbufs[0].append(b"abc") - inst.total_outbufs_len = sum(len(x) for x in inst.outbufs) - result = inst._flush_some() - self.assertEqual(result, True) - - def test__flush_some_full_outbuf_socket_returns_zero(self): - inst, sock, map = self._makeOneWithMap() - sock.send = lambda x: False - inst.outbufs[0].append(b"abc") - inst.total_outbufs_len = sum(len(x) for x in inst.outbufs) - result = inst._flush_some() - self.assertEqual(result, False) - - def test_flush_some_multiple_buffers_first_empty(self): - inst, sock, map = self._makeOneWithMap() - sock.send = lambda x: len(x) - buffer = DummyBuffer(b"abc") - inst.outbufs.append(buffer) - inst.total_outbufs_len = sum(len(x) for x in inst.outbufs) - result = inst._flush_some() - self.assertEqual(result, True) - self.assertEqual(buffer.skipped, 3) - self.assertEqual(inst.outbufs, [buffer]) - - def test_flush_some_multiple_buffers_close_raises(self): - inst, sock, map = self._makeOneWithMap() - sock.send = lambda x: len(x) - buffer = DummyBuffer(b"abc") - inst.outbufs.append(buffer) - inst.total_outbufs_len = sum(len(x) for x in inst.outbufs) - inst.logger = DummyLogger() - - def doraise(): - raise NotImplementedError - - inst.outbufs[0].close = doraise - result = inst._flush_some() - self.assertEqual(result, True) - self.assertEqual(buffer.skipped, 3) - self.assertEqual(inst.outbufs, [buffer]) - self.assertEqual(len(inst.logger.exceptions), 1) - - def test__flush_some_outbuf_len_gt_sys_maxint(self): - from waitress.compat import MAXINT - - inst, sock, map = self._makeOneWithMap() - - class DummyHugeOutbuffer(object): - def __init__(self): - self.length = MAXINT + 1 - - def __len__(self): - return self.length - - def get(self, numbytes): - self.length = 0 - return b"123" - - buf = DummyHugeOutbuffer() - inst.outbufs = [buf] - inst.send = lambda *arg: 0 - result = inst._flush_some() - # we are testing that _flush_some doesn't raise an OverflowError - # when one of its outbufs has a __len__ that returns gt sys.maxint - self.assertEqual(result, False) - - def test_handle_close(self): - inst, sock, map = self._makeOneWithMap() - inst.handle_close() - self.assertEqual(inst.connected, False) - self.assertEqual(sock.closed, True) - - def test_handle_close_outbuf_raises_on_close(self): - inst, sock, map = self._makeOneWithMap() - - def doraise(): - raise NotImplementedError - - inst.outbufs[0].close = doraise - inst.logger = DummyLogger() - inst.handle_close() - self.assertEqual(inst.connected, False) - self.assertEqual(sock.closed, True) - self.assertEqual(len(inst.logger.exceptions), 1) - - def test_add_channel(self): - inst, sock, map = self._makeOneWithMap() - fileno = inst._fileno - inst.add_channel(map) - self.assertEqual(map[fileno], inst) - self.assertEqual(inst.server.active_channels[fileno], inst) - - def test_del_channel(self): - inst, sock, map = self._makeOneWithMap() - fileno = inst._fileno - inst.server.active_channels[fileno] = True - inst.del_channel(map) - self.assertEqual(map.get(fileno), None) - self.assertEqual(inst.server.active_channels.get(fileno), None) - - def test_received(self): - inst, sock, map = self._makeOneWithMap() - inst.server = DummyServer() - inst.received(b"GET / HTTP/1.1\r\n\r\n") - self.assertEqual(inst.server.tasks, [inst]) - self.assertTrue(inst.requests) - - def test_received_no_chunk(self): - inst, sock, map = self._makeOneWithMap() - self.assertEqual(inst.received(b""), False) - - def test_received_preq_not_completed(self): - inst, sock, map = self._makeOneWithMap() - inst.server = DummyServer() - preq = DummyParser() - inst.request = preq - preq.completed = False - preq.empty = True - inst.received(b"GET / HTTP/1.1\r\n\r\n") - self.assertEqual(inst.requests, ()) - self.assertEqual(inst.server.tasks, []) - - def test_received_preq_completed_empty(self): - inst, sock, map = self._makeOneWithMap() - inst.server = DummyServer() - preq = DummyParser() - inst.request = preq - preq.completed = True - preq.empty = True - inst.received(b"GET / HTTP/1.1\r\n\r\n") - self.assertEqual(inst.request, None) - self.assertEqual(inst.server.tasks, []) - - def test_received_preq_error(self): - inst, sock, map = self._makeOneWithMap() - inst.server = DummyServer() - preq = DummyParser() - inst.request = preq - preq.completed = True - preq.error = True - inst.received(b"GET / HTTP/1.1\r\n\r\n") - self.assertEqual(inst.request, None) - self.assertEqual(len(inst.server.tasks), 1) - self.assertTrue(inst.requests) - - def test_received_preq_completed_connection_close(self): - inst, sock, map = self._makeOneWithMap() - inst.server = DummyServer() - preq = DummyParser() - inst.request = preq - preq.completed = True - preq.empty = True - preq.connection_close = True - inst.received(b"GET / HTTP/1.1\r\n\r\n" + b"a" * 50000) - self.assertEqual(inst.request, None) - self.assertEqual(inst.server.tasks, []) - - def test_received_headers_finished_expect_continue_false(self): - inst, sock, map = self._makeOneWithMap() - inst.server = DummyServer() - preq = DummyParser() - inst.request = preq - preq.expect_continue = False - preq.headers_finished = True - preq.completed = False - preq.empty = False - preq.retval = 1 - inst.received(b"GET / HTTP/1.1\r\n\r\n") - self.assertEqual(inst.request, preq) - self.assertEqual(inst.server.tasks, []) - self.assertEqual(inst.outbufs[0].get(100), b"") - - def test_received_headers_finished_expect_continue_true(self): - inst, sock, map = self._makeOneWithMap() - inst.server = DummyServer() - preq = DummyParser() - inst.request = preq - preq.expect_continue = True - preq.headers_finished = True - preq.completed = False - preq.empty = False - inst.received(b"GET / HTTP/1.1\r\n\r\n") - self.assertEqual(inst.request, preq) - self.assertEqual(inst.server.tasks, []) - self.assertEqual(sock.sent, b"HTTP/1.1 100 Continue\r\n\r\n") - self.assertEqual(inst.sent_continue, True) - self.assertEqual(preq.completed, False) - - def test_received_headers_finished_expect_continue_true_sent_true(self): - inst, sock, map = self._makeOneWithMap() - inst.server = DummyServer() - preq = DummyParser() - inst.request = preq - preq.expect_continue = True - preq.headers_finished = True - preq.completed = False - preq.empty = False - inst.sent_continue = True - inst.received(b"GET / HTTP/1.1\r\n\r\n") - self.assertEqual(inst.request, preq) - self.assertEqual(inst.server.tasks, []) - self.assertEqual(sock.sent, b"") - self.assertEqual(inst.sent_continue, True) - self.assertEqual(preq.completed, False) - - def test_service_no_requests(self): - inst, sock, map = self._makeOneWithMap() - inst.requests = [] - inst.service() - self.assertEqual(inst.requests, []) - self.assertTrue(inst.server.trigger_pulled) - self.assertTrue(inst.last_activity) - - def test_service_with_one_request(self): - inst, sock, map = self._makeOneWithMap() - request = DummyRequest() - inst.task_class = DummyTaskClass() - inst.requests = [request] - inst.service() - self.assertEqual(inst.requests, []) - self.assertTrue(request.serviced) - self.assertTrue(request.closed) - - def test_service_with_one_error_request(self): - inst, sock, map = self._makeOneWithMap() - request = DummyRequest() - request.error = DummyError() - inst.error_task_class = DummyTaskClass() - inst.requests = [request] - inst.service() - self.assertEqual(inst.requests, []) - self.assertTrue(request.serviced) - self.assertTrue(request.closed) - - def test_service_with_multiple_requests(self): - inst, sock, map = self._makeOneWithMap() - request1 = DummyRequest() - request2 = DummyRequest() - inst.task_class = DummyTaskClass() - inst.requests = [request1, request2] - inst.service() - self.assertEqual(inst.requests, []) - self.assertTrue(request1.serviced) - self.assertTrue(request2.serviced) - self.assertTrue(request1.closed) - self.assertTrue(request2.closed) - - def test_service_with_request_raises(self): - inst, sock, map = self._makeOneWithMap() - inst.adj.expose_tracebacks = False - inst.server = DummyServer() - request = DummyRequest() - inst.requests = [request] - inst.task_class = DummyTaskClass(ValueError) - inst.task_class.wrote_header = False - inst.error_task_class = DummyTaskClass() - inst.logger = DummyLogger() - inst.service() - self.assertTrue(request.serviced) - self.assertEqual(inst.requests, []) - self.assertEqual(len(inst.logger.exceptions), 1) - self.assertTrue(inst.server.trigger_pulled) - self.assertTrue(inst.last_activity) - self.assertFalse(inst.will_close) - self.assertEqual(inst.error_task_class.serviced, True) - self.assertTrue(request.closed) - - def test_service_with_requests_raises_already_wrote_header(self): - inst, sock, map = self._makeOneWithMap() - inst.adj.expose_tracebacks = False - inst.server = DummyServer() - request = DummyRequest() - inst.requests = [request] - inst.task_class = DummyTaskClass(ValueError) - inst.error_task_class = DummyTaskClass() - inst.logger = DummyLogger() - inst.service() - self.assertTrue(request.serviced) - self.assertEqual(inst.requests, []) - self.assertEqual(len(inst.logger.exceptions), 1) - self.assertTrue(inst.server.trigger_pulled) - self.assertTrue(inst.last_activity) - self.assertTrue(inst.close_when_flushed) - self.assertEqual(inst.error_task_class.serviced, False) - self.assertTrue(request.closed) - - def test_service_with_requests_raises_didnt_write_header_expose_tbs(self): - inst, sock, map = self._makeOneWithMap() - inst.adj.expose_tracebacks = True - inst.server = DummyServer() - request = DummyRequest() - inst.requests = [request] - inst.task_class = DummyTaskClass(ValueError) - inst.task_class.wrote_header = False - inst.error_task_class = DummyTaskClass() - inst.logger = DummyLogger() - inst.service() - self.assertTrue(request.serviced) - self.assertFalse(inst.will_close) - self.assertEqual(inst.requests, []) - self.assertEqual(len(inst.logger.exceptions), 1) - self.assertTrue(inst.server.trigger_pulled) - self.assertTrue(inst.last_activity) - self.assertEqual(inst.error_task_class.serviced, True) - self.assertTrue(request.closed) - - def test_service_with_requests_raises_didnt_write_header(self): - inst, sock, map = self._makeOneWithMap() - inst.adj.expose_tracebacks = False - inst.server = DummyServer() - request = DummyRequest() - inst.requests = [request] - inst.task_class = DummyTaskClass(ValueError) - inst.task_class.wrote_header = False - inst.logger = DummyLogger() - inst.service() - self.assertTrue(request.serviced) - self.assertEqual(inst.requests, []) - self.assertEqual(len(inst.logger.exceptions), 1) - self.assertTrue(inst.server.trigger_pulled) - self.assertTrue(inst.last_activity) - self.assertTrue(inst.close_when_flushed) - self.assertTrue(request.closed) - - def test_service_with_request_raises_disconnect(self): - from waitress.channel import ClientDisconnected - - inst, sock, map = self._makeOneWithMap() - inst.adj.expose_tracebacks = False - inst.server = DummyServer() - request = DummyRequest() - inst.requests = [request] - inst.task_class = DummyTaskClass(ClientDisconnected) - inst.error_task_class = DummyTaskClass() - inst.logger = DummyLogger() - inst.service() - self.assertTrue(request.serviced) - self.assertEqual(inst.requests, []) - self.assertEqual(len(inst.logger.infos), 1) - self.assertTrue(inst.server.trigger_pulled) - self.assertTrue(inst.last_activity) - self.assertFalse(inst.will_close) - self.assertEqual(inst.error_task_class.serviced, False) - self.assertTrue(request.closed) - - def test_service_with_request_error_raises_disconnect(self): - from waitress.channel import ClientDisconnected - - inst, sock, map = self._makeOneWithMap() - inst.adj.expose_tracebacks = False - inst.server = DummyServer() - request = DummyRequest() - err_request = DummyRequest() - inst.requests = [request] - inst.parser_class = lambda x: err_request - inst.task_class = DummyTaskClass(RuntimeError) - inst.task_class.wrote_header = False - inst.error_task_class = DummyTaskClass(ClientDisconnected) - inst.logger = DummyLogger() - inst.service() - self.assertTrue(request.serviced) - self.assertTrue(err_request.serviced) - self.assertEqual(inst.requests, []) - self.assertEqual(len(inst.logger.exceptions), 1) - self.assertEqual(len(inst.logger.infos), 0) - self.assertTrue(inst.server.trigger_pulled) - self.assertTrue(inst.last_activity) - self.assertFalse(inst.will_close) - self.assertEqual(inst.task_class.serviced, True) - self.assertEqual(inst.error_task_class.serviced, True) - self.assertTrue(request.closed) - - def test_cancel_no_requests(self): - inst, sock, map = self._makeOneWithMap() - inst.requests = () - inst.cancel() - self.assertEqual(inst.requests, []) - - def test_cancel_with_requests(self): - inst, sock, map = self._makeOneWithMap() - inst.requests = [None] - inst.cancel() - self.assertEqual(inst.requests, []) - - -class DummySock(object): - blocking = False - closed = False - - def __init__(self): - self.sent = b"" - - def setblocking(self, *arg): - self.blocking = True - - def fileno(self): - return 100 - - def getpeername(self): - return "127.0.0.1" - - def getsockopt(self, level, option): - return 2048 - - def close(self): - self.closed = True - - def send(self, data): - self.sent += data - return len(data) - - -class DummyLock(object): - notified = False - - def __init__(self, acquirable=True): - self.acquirable = acquirable - - def acquire(self, val): - self.val = val - self.acquired = True - return self.acquirable - - def release(self): - self.released = True - - def notify(self): - self.notified = True - - def wait(self): - self.waited = True - - def __exit__(self, type, val, traceback): - self.acquire(True) - - def __enter__(self): - pass - - -class DummyBuffer(object): - closed = False - - def __init__(self, data, toraise=None): - self.data = data - self.toraise = toraise - - def get(self, *arg): - if self.toraise: - raise self.toraise - data = self.data - self.data = b"" - return data - - def skip(self, num, x): - self.skipped = num - - def __len__(self): - return len(self.data) - - def close(self): - self.closed = True - - -class DummyAdjustments(object): - outbuf_overflow = 1048576 - outbuf_high_watermark = 1048576 - inbuf_overflow = 512000 - cleanup_interval = 900 - url_scheme = "http" - channel_timeout = 300 - log_socket_errors = True - recv_bytes = 8192 - send_bytes = 1 - expose_tracebacks = True - ident = "waitress" - max_request_header_size = 10000 - - -class DummyServer(object): - trigger_pulled = False - adj = DummyAdjustments() - - def __init__(self): - self.tasks = [] - self.active_channels = {} - - def add_task(self, task): - self.tasks.append(task) - - def pull_trigger(self): - self.trigger_pulled = True - - -class DummyParser(object): - version = 1 - data = None - completed = True - empty = False - headers_finished = False - expect_continue = False - retval = None - error = None - connection_close = False - - def received(self, data): - self.data = data - if self.retval is not None: - return self.retval - return len(data) - - -class DummyRequest(object): - error = None - path = "/" - version = "1.0" - closed = False - - def __init__(self): - self.headers = {} - - def close(self): - self.closed = True - - -class DummyLogger(object): - def __init__(self): - self.exceptions = [] - self.infos = [] - self.warnings = [] - - def info(self, msg): - self.infos.append(msg) - - def exception(self, msg): - self.exceptions.append(msg) - - -class DummyError(object): - code = "431" - reason = "Bleh" - body = "My body" - - -class DummyTaskClass(object): - wrote_header = True - close_on_finish = False - serviced = False - - def __init__(self, toraise=None): - self.toraise = toraise - - def __call__(self, channel, request): - self.request = request - return self - - def service(self): - self.serviced = True - self.request.serviced = True - if self.toraise: - raise self.toraise diff --git a/waitress/tests/test_compat.py b/waitress/tests/test_compat.py deleted file mode 100644 index 37c2193..0000000 --- a/waitress/tests/test_compat.py +++ /dev/null @@ -1,22 +0,0 @@ -# -*- coding: utf-8 -*- - -import unittest - - -class Test_unquote_bytes_to_wsgi(unittest.TestCase): - def _callFUT(self, v): - from waitress.compat import unquote_bytes_to_wsgi - - return unquote_bytes_to_wsgi(v) - - def test_highorder(self): - from waitress.compat import PY3 - - val = b"/a%C5%9B" - result = self._callFUT(val) - if PY3: # pragma: no cover - # PEP 3333 urlunquoted-latin1-decoded-bytes - self.assertEqual(result, "/aÅ\x9b") - else: # pragma: no cover - # sanity - self.assertEqual(result, b"/a\xc5\x9b") diff --git a/waitress/tests/test_functional.py b/waitress/tests/test_functional.py deleted file mode 100644 index 8f4b262..0000000 --- a/waitress/tests/test_functional.py +++ /dev/null @@ -1,1667 +0,0 @@ -import errno -import logging -import multiprocessing -import os -import signal -import socket -import string -import subprocess -import sys -import time -import unittest -from waitress import server -from waitress.compat import httplib, tobytes -from waitress.utilities import cleanup_unix_socket - -dn = os.path.dirname -here = dn(__file__) - - -class NullHandler(logging.Handler): # pragma: no cover - """A logging handler that swallows all emitted messages. - """ - - def emit(self, record): - pass - - -def start_server(app, svr, queue, **kwargs): # pragma: no cover - """Run a fixture application. - """ - logging.getLogger("waitress").addHandler(NullHandler()) - try_register_coverage() - svr(app, queue, **kwargs).run() - - -def try_register_coverage(): # pragma: no cover - # Hack around multiprocessing exiting early and not triggering coverage's - # atexit handler by always registering a signal handler - - if "COVERAGE_PROCESS_START" in os.environ: - def sigterm(*args): - sys.exit(0) - - signal.signal(signal.SIGTERM, sigterm) - - -class FixtureTcpWSGIServer(server.TcpWSGIServer): - """A version of TcpWSGIServer that relays back what it's bound to. - """ - - family = socket.AF_INET # Testing - - def __init__(self, application, queue, **kw): # pragma: no cover - # Coverage doesn't see this as it's ran in a separate process. - kw["port"] = 0 # Bind to any available port. - super(FixtureTcpWSGIServer, self).__init__(application, **kw) - host, port = self.socket.getsockname() - if os.name == "nt": - host = "127.0.0.1" - queue.put((host, port)) - - -class SubprocessTests(object): - - # For nose: all tests may be ran in separate processes. - _multiprocess_can_split_ = True - - exe = sys.executable - - server = None - - def start_subprocess(self, target, **kw): - # Spawn a server process. - self.queue = multiprocessing.Queue() - - if "COVERAGE_RCFILE" in os.environ: - os.environ["COVERAGE_PROCESS_START"] = os.environ["COVERAGE_RCFILE"] - - self.proc = multiprocessing.Process( - target=start_server, args=(target, self.server, self.queue), kwargs=kw, - ) - self.proc.start() - - if self.proc.exitcode is not None: # pragma: no cover - raise RuntimeError("%s didn't start" % str(target)) - # Get the socket the server is listening on. - self.bound_to = self.queue.get(timeout=5) - self.sock = self.create_socket() - - def stop_subprocess(self): - if self.proc.exitcode is None: - self.proc.terminate() - self.sock.close() - # This give us one FD back ... - self.queue.close() - self.proc.join() - - def assertline(self, line, status, reason, version): - v, s, r = (x.strip() for x in line.split(None, 2)) - self.assertEqual(s, tobytes(status)) - self.assertEqual(r, tobytes(reason)) - self.assertEqual(v, tobytes(version)) - - def create_socket(self): - return socket.socket(self.server.family, socket.SOCK_STREAM) - - def connect(self): - self.sock.connect(self.bound_to) - - def make_http_connection(self): - raise NotImplementedError # pragma: no cover - - def send_check_error(self, to_send): - self.sock.send(to_send) - - -class TcpTests(SubprocessTests): - - server = FixtureTcpWSGIServer - - def make_http_connection(self): - return httplib.HTTPConnection(*self.bound_to) - - -class SleepyThreadTests(TcpTests, unittest.TestCase): - # test that sleepy thread doesnt block other requests - - def setUp(self): - from waitress.tests.fixtureapps import sleepy - - self.start_subprocess(sleepy.app) - - def tearDown(self): - self.stop_subprocess() - - def test_it(self): - getline = os.path.join(here, "fixtureapps", "getline.py") - cmds = ( - [self.exe, getline, "http://%s:%d/sleepy" % self.bound_to], - [self.exe, getline, "http://%s:%d/" % self.bound_to], - ) - r, w = os.pipe() - procs = [] - for cmd in cmds: - procs.append(subprocess.Popen(cmd, stdout=w)) - time.sleep(3) - for proc in procs: - if proc.returncode is not None: # pragma: no cover - proc.terminate() - proc.wait() - # the notsleepy response should always be first returned (it sleeps - # for 2 seconds, then returns; the notsleepy response should be - # processed in the meantime) - result = os.read(r, 10000) - os.close(r) - os.close(w) - self.assertEqual(result, b"notsleepy returnedsleepy returned") - - -class EchoTests(object): - def setUp(self): - from waitress.tests.fixtureapps import echo - - self.start_subprocess( - echo.app, - trusted_proxy="*", - trusted_proxy_count=1, - trusted_proxy_headers={"x-forwarded-for", "x-forwarded-proto"}, - clear_untrusted_proxy_headers=True, - ) - - def tearDown(self): - self.stop_subprocess() - - def _read_echo(self, fp): - from waitress.tests.fixtureapps import echo - - line, headers, body = read_http(fp) - return line, headers, echo.parse_response(body) - - def test_date_and_server(self): - to_send = "GET / HTTP/1.0\r\nContent-Length: 0\r\n\r\n" - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, echo = self._read_echo(fp) - self.assertline(line, "200", "OK", "HTTP/1.0") - self.assertEqual(headers.get("server"), "waitress") - self.assertTrue(headers.get("date")) - - def test_bad_host_header(self): - # https://corte.si/posts/code/pathod/pythonservers/index.html - to_send = "GET / HTTP/1.0\r\n Host: 0\r\n\r\n" - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - self.assertline(line, "400", "Bad Request", "HTTP/1.0") - self.assertEqual(headers.get("server"), "waitress") - self.assertTrue(headers.get("date")) - - def test_send_with_body(self): - to_send = "GET / HTTP/1.0\r\nContent-Length: 5\r\n\r\n" - to_send += "hello" - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, echo = self._read_echo(fp) - self.assertline(line, "200", "OK", "HTTP/1.0") - self.assertEqual(echo.content_length, "5") - self.assertEqual(echo.body, b"hello") - - def test_send_empty_body(self): - to_send = "GET / HTTP/1.0\r\nContent-Length: 0\r\n\r\n" - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, echo = self._read_echo(fp) - self.assertline(line, "200", "OK", "HTTP/1.0") - self.assertEqual(echo.content_length, "0") - self.assertEqual(echo.body, b"") - - def test_multiple_requests_with_body(self): - orig_sock = self.sock - for x in range(3): - self.sock = self.create_socket() - self.test_send_with_body() - self.sock.close() - self.sock = orig_sock - - def test_multiple_requests_without_body(self): - orig_sock = self.sock - for x in range(3): - self.sock = self.create_socket() - self.test_send_empty_body() - self.sock.close() - self.sock = orig_sock - - def test_without_crlf(self): - data = "Echo\r\nthis\r\nplease" - s = tobytes( - "GET / HTTP/1.0\r\n" - "Connection: close\r\n" - "Content-Length: %d\r\n" - "\r\n" - "%s" % (len(data), data) - ) - self.connect() - self.sock.send(s) - fp = self.sock.makefile("rb", 0) - line, headers, echo = self._read_echo(fp) - self.assertline(line, "200", "OK", "HTTP/1.0") - self.assertEqual(int(echo.content_length), len(data)) - self.assertEqual(len(echo.body), len(data)) - self.assertEqual(echo.body, tobytes(data)) - - def test_large_body(self): - # 1024 characters. - body = "This string has 32 characters.\r\n" * 32 - s = tobytes( - "GET / HTTP/1.0\r\nContent-Length: %d\r\n\r\n%s" % (len(body), body) - ) - self.connect() - self.sock.send(s) - fp = self.sock.makefile("rb", 0) - line, headers, echo = self._read_echo(fp) - self.assertline(line, "200", "OK", "HTTP/1.0") - self.assertEqual(echo.content_length, "1024") - self.assertEqual(echo.body, tobytes(body)) - - def test_many_clients(self): - conns = [] - for n in range(50): - h = self.make_http_connection() - h.request("GET", "/", headers={"Accept": "text/plain"}) - conns.append(h) - responses = [] - for h in conns: - response = h.getresponse() - self.assertEqual(response.status, 200) - responses.append(response) - for response in responses: - response.read() - for h in conns: - h.close() - - def test_chunking_request_without_content(self): - header = tobytes("GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n") - self.connect() - self.sock.send(header) - self.sock.send(b"0\r\n\r\n") - fp = self.sock.makefile("rb", 0) - line, headers, echo = self._read_echo(fp) - self.assertline(line, "200", "OK", "HTTP/1.1") - self.assertEqual(echo.body, b"") - self.assertEqual(echo.content_length, "0") - self.assertFalse("transfer-encoding" in headers) - - def test_chunking_request_with_content(self): - control_line = b"20;\r\n" # 20 hex = 32 dec - s = b"This string has 32 characters.\r\n" - expected = s * 12 - header = tobytes("GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n") - self.connect() - self.sock.send(header) - fp = self.sock.makefile("rb", 0) - for n in range(12): - self.sock.send(control_line) - self.sock.send(s) - self.sock.send(b"\r\n") # End the chunk - self.sock.send(b"0\r\n\r\n") - line, headers, echo = self._read_echo(fp) - self.assertline(line, "200", "OK", "HTTP/1.1") - self.assertEqual(echo.body, expected) - self.assertEqual(echo.content_length, str(len(expected))) - self.assertFalse("transfer-encoding" in headers) - - def test_broken_chunked_encoding(self): - control_line = "20;\r\n" # 20 hex = 32 dec - s = "This string has 32 characters.\r\n" - to_send = "GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n" - to_send += control_line + s + "\r\n" - # garbage in input - to_send += "garbage\r\n" - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - # receiver caught garbage and turned it into a 400 - self.assertline(line, "400", "Bad Request", "HTTP/1.1") - cl = int(headers["content-length"]) - self.assertEqual(cl, len(response_body)) - self.assertEqual( - sorted(headers.keys()), ["connection", "content-length", "content-type", "date", "server"] - ) - self.assertEqual(headers["content-type"], "text/plain") - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_broken_chunked_encoding_missing_chunk_end(self): - control_line = "20;\r\n" # 20 hex = 32 dec - s = "This string has 32 characters.\r\n" - to_send = "GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n" - to_send += control_line + s - # garbage in input - to_send += "garbage" - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - # receiver caught garbage and turned it into a 400 - self.assertline(line, "400", "Bad Request", "HTTP/1.1") - cl = int(headers["content-length"]) - self.assertEqual(cl, len(response_body)) - self.assertTrue(b"Chunk not properly terminated" in response_body) - self.assertEqual( - sorted(headers.keys()), ["connection", "content-length", "content-type", "date", "server"] - ) - self.assertEqual(headers["content-type"], "text/plain") - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_keepalive_http_10(self): - # Handling of Keep-Alive within HTTP 1.0 - data = "Default: Don't keep me alive" - s = tobytes( - "GET / HTTP/1.0\r\nContent-Length: %d\r\n\r\n%s" % (len(data), data) - ) - self.connect() - self.sock.send(s) - response = httplib.HTTPResponse(self.sock) - response.begin() - self.assertEqual(int(response.status), 200) - connection = response.getheader("Connection", "") - # We sent no Connection: Keep-Alive header - # Connection: close (or no header) is default. - self.assertTrue(connection != "Keep-Alive") - - def test_keepalive_http10_explicit(self): - # If header Connection: Keep-Alive is explicitly sent, - # we want to keept the connection open, we also need to return - # the corresponding header - data = "Keep me alive" - s = tobytes( - "GET / HTTP/1.0\r\n" - "Connection: Keep-Alive\r\n" - "Content-Length: %d\r\n" - "\r\n" - "%s" % (len(data), data) - ) - self.connect() - self.sock.send(s) - response = httplib.HTTPResponse(self.sock) - response.begin() - self.assertEqual(int(response.status), 200) - connection = response.getheader("Connection", "") - self.assertEqual(connection, "Keep-Alive") - - def test_keepalive_http_11(self): - # Handling of Keep-Alive within HTTP 1.1 - - # All connections are kept alive, unless stated otherwise - data = "Default: Keep me alive" - s = tobytes( - "GET / HTTP/1.1\r\nContent-Length: %d\r\n\r\n%s" % (len(data), data) - ) - self.connect() - self.sock.send(s) - response = httplib.HTTPResponse(self.sock) - response.begin() - self.assertEqual(int(response.status), 200) - self.assertTrue(response.getheader("connection") != "close") - - def test_keepalive_http11_explicit(self): - # Explicitly set keep-alive - data = "Default: Keep me alive" - s = tobytes( - "GET / HTTP/1.1\r\n" - "Connection: keep-alive\r\n" - "Content-Length: %d\r\n" - "\r\n" - "%s" % (len(data), data) - ) - self.connect() - self.sock.send(s) - response = httplib.HTTPResponse(self.sock) - response.begin() - self.assertEqual(int(response.status), 200) - self.assertTrue(response.getheader("connection") != "close") - - def test_keepalive_http11_connclose(self): - # specifying Connection: close explicitly - data = "Don't keep me alive" - s = tobytes( - "GET / HTTP/1.1\r\n" - "Connection: close\r\n" - "Content-Length: %d\r\n" - "\r\n" - "%s" % (len(data), data) - ) - self.connect() - self.sock.send(s) - response = httplib.HTTPResponse(self.sock) - response.begin() - self.assertEqual(int(response.status), 200) - self.assertEqual(response.getheader("connection"), "close") - - def test_proxy_headers(self): - to_send = ( - "GET / HTTP/1.0\r\n" - "Content-Length: 0\r\n" - "Host: www.google.com:8080\r\n" - "X-Forwarded-For: 192.168.1.1\r\n" - "X-Forwarded-Proto: https\r\n" - "X-Forwarded-Port: 5000\r\n\r\n" - ) - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, echo = self._read_echo(fp) - self.assertline(line, "200", "OK", "HTTP/1.0") - self.assertEqual(headers.get("server"), "waitress") - self.assertTrue(headers.get("date")) - self.assertIsNone(echo.headers.get("X_FORWARDED_PORT")) - self.assertEqual(echo.headers["HOST"], "www.google.com:8080") - self.assertEqual(echo.scheme, "https") - self.assertEqual(echo.remote_addr, "192.168.1.1") - self.assertEqual(echo.remote_host, "192.168.1.1") - - -class PipeliningTests(object): - def setUp(self): - from waitress.tests.fixtureapps import echo - - self.start_subprocess(echo.app_body_only) - - def tearDown(self): - self.stop_subprocess() - - def test_pipelining(self): - s = ( - "GET / HTTP/1.0\r\n" - "Connection: %s\r\n" - "Content-Length: %d\r\n" - "\r\n" - "%s" - ) - to_send = b"" - count = 25 - for n in range(count): - body = "Response #%d\r\n" % (n + 1) - if n + 1 < count: - conn = "keep-alive" - else: - conn = "close" - to_send += tobytes(s % (conn, len(body), body)) - - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - for n in range(count): - expect_body = tobytes("Response #%d\r\n" % (n + 1)) - line = fp.readline() # status line - version, status, reason = (x.strip() for x in line.split(None, 2)) - headers = parse_headers(fp) - length = int(headers.get("content-length")) or None - response_body = fp.read(length) - self.assertEqual(int(status), 200) - self.assertEqual(length, len(response_body)) - self.assertEqual(response_body, expect_body) - - -class ExpectContinueTests(object): - def setUp(self): - from waitress.tests.fixtureapps import echo - - self.start_subprocess(echo.app_body_only) - - def tearDown(self): - self.stop_subprocess() - - def test_expect_continue(self): - # specifying Connection: close explicitly - data = "I have expectations" - to_send = tobytes( - "GET / HTTP/1.1\r\n" - "Connection: close\r\n" - "Content-Length: %d\r\n" - "Expect: 100-continue\r\n" - "\r\n" - "%s" % (len(data), data) - ) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line = fp.readline() # continue status line - version, status, reason = (x.strip() for x in line.split(None, 2)) - self.assertEqual(int(status), 100) - self.assertEqual(reason, b"Continue") - self.assertEqual(version, b"HTTP/1.1") - fp.readline() # blank line - line = fp.readline() # next status line - version, status, reason = (x.strip() for x in line.split(None, 2)) - headers = parse_headers(fp) - length = int(headers.get("content-length")) or None - response_body = fp.read(length) - self.assertEqual(int(status), 200) - self.assertEqual(length, len(response_body)) - self.assertEqual(response_body, tobytes(data)) - - -class BadContentLengthTests(object): - def setUp(self): - from waitress.tests.fixtureapps import badcl - - self.start_subprocess(badcl.app) - - def tearDown(self): - self.stop_subprocess() - - def test_short_body(self): - # check to see if server closes connection when body is too short - # for cl header - to_send = tobytes( - "GET /short_body HTTP/1.0\r\n" - "Connection: Keep-Alive\r\n" - "Content-Length: 0\r\n" - "\r\n" - ) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line = fp.readline() # status line - version, status, reason = (x.strip() for x in line.split(None, 2)) - headers = parse_headers(fp) - content_length = int(headers.get("content-length")) - response_body = fp.read(content_length) - self.assertEqual(int(status), 200) - self.assertNotEqual(content_length, len(response_body)) - self.assertEqual(len(response_body), content_length - 1) - self.assertEqual(response_body, tobytes("abcdefghi")) - # remote closed connection (despite keepalive header); not sure why - # first send succeeds - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_long_body(self): - # check server doesnt close connection when body is too short - # for cl header - to_send = tobytes( - "GET /long_body HTTP/1.0\r\n" - "Connection: Keep-Alive\r\n" - "Content-Length: 0\r\n" - "\r\n" - ) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line = fp.readline() # status line - version, status, reason = (x.strip() for x in line.split(None, 2)) - headers = parse_headers(fp) - content_length = int(headers.get("content-length")) or None - response_body = fp.read(content_length) - self.assertEqual(int(status), 200) - self.assertEqual(content_length, len(response_body)) - self.assertEqual(response_body, tobytes("abcdefgh")) - # remote does not close connection (keepalive header) - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line = fp.readline() # status line - version, status, reason = (x.strip() for x in line.split(None, 2)) - headers = parse_headers(fp) - content_length = int(headers.get("content-length")) or None - response_body = fp.read(content_length) - self.assertEqual(int(status), 200) - - -class NoContentLengthTests(object): - def setUp(self): - from waitress.tests.fixtureapps import nocl - - self.start_subprocess(nocl.app) - - def tearDown(self): - self.stop_subprocess() - - def test_http10_generator(self): - body = string.ascii_letters - to_send = ( - "GET / HTTP/1.0\r\n" - "Connection: Keep-Alive\r\n" - "Content-Length: %d\r\n\r\n" % len(body) - ) - to_send += body - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - self.assertline(line, "200", "OK", "HTTP/1.0") - self.assertEqual(headers.get("content-length"), None) - self.assertEqual(headers.get("connection"), "close") - self.assertEqual(response_body, tobytes(body)) - # remote closed connection (despite keepalive header), because - # generators cannot have a content-length divined - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_http10_list(self): - body = string.ascii_letters - to_send = ( - "GET /list HTTP/1.0\r\n" - "Connection: Keep-Alive\r\n" - "Content-Length: %d\r\n\r\n" % len(body) - ) - to_send += body - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - self.assertline(line, "200", "OK", "HTTP/1.0") - self.assertEqual(headers["content-length"], str(len(body))) - self.assertEqual(headers.get("connection"), "Keep-Alive") - self.assertEqual(response_body, tobytes(body)) - # remote keeps connection open because it divined the content length - # from a length-1 list - self.sock.send(to_send) - line, headers, response_body = read_http(fp) - self.assertline(line, "200", "OK", "HTTP/1.0") - - def test_http10_listlentwo(self): - body = string.ascii_letters - to_send = ( - "GET /list_lentwo HTTP/1.0\r\n" - "Connection: Keep-Alive\r\n" - "Content-Length: %d\r\n\r\n" % len(body) - ) - to_send += body - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - self.assertline(line, "200", "OK", "HTTP/1.0") - self.assertEqual(headers.get("content-length"), None) - self.assertEqual(headers.get("connection"), "close") - self.assertEqual(response_body, tobytes(body)) - # remote closed connection (despite keepalive header), because - # lists of length > 1 cannot have their content length divined - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_http11_generator(self): - body = string.ascii_letters - to_send = "GET / HTTP/1.1\r\nContent-Length: %s\r\n\r\n" % len(body) - to_send += body - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb") - line, headers, response_body = read_http(fp) - self.assertline(line, "200", "OK", "HTTP/1.1") - expected = b"" - for chunk in chunks(body, 10): - expected += tobytes( - "%s\r\n%s\r\n" % (str(hex(len(chunk))[2:].upper()), chunk) - ) - expected += b"0\r\n\r\n" - self.assertEqual(response_body, expected) - # connection is always closed at the end of a chunked response - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_http11_list(self): - body = string.ascii_letters - to_send = "GET /list HTTP/1.1\r\nContent-Length: %d\r\n\r\n" % len(body) - to_send += body - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - self.assertline(line, "200", "OK", "HTTP/1.1") - self.assertEqual(headers["content-length"], str(len(body))) - self.assertEqual(response_body, tobytes(body)) - # remote keeps connection open because it divined the content length - # from a length-1 list - self.sock.send(to_send) - line, headers, response_body = read_http(fp) - self.assertline(line, "200", "OK", "HTTP/1.1") - - def test_http11_listlentwo(self): - body = string.ascii_letters - to_send = "GET /list_lentwo HTTP/1.1\r\nContent-Length: %s\r\n\r\n" % len(body) - to_send += body - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb") - line, headers, response_body = read_http(fp) - self.assertline(line, "200", "OK", "HTTP/1.1") - expected = b"" - for chunk in (body[0], body[1:]): - expected += tobytes( - "%s\r\n%s\r\n" % (str(hex(len(chunk))[2:].upper()), chunk) - ) - expected += b"0\r\n\r\n" - self.assertEqual(response_body, expected) - # connection is always closed at the end of a chunked response - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - -class WriteCallbackTests(object): - def setUp(self): - from waitress.tests.fixtureapps import writecb - - self.start_subprocess(writecb.app) - - def tearDown(self): - self.stop_subprocess() - - def test_short_body(self): - # check to see if server closes connection when body is too short - # for cl header - to_send = tobytes( - "GET /short_body HTTP/1.0\r\n" - "Connection: Keep-Alive\r\n" - "Content-Length: 0\r\n" - "\r\n" - ) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - # server trusts the content-length header (5) - self.assertline(line, "200", "OK", "HTTP/1.0") - cl = int(headers["content-length"]) - self.assertEqual(cl, 9) - self.assertNotEqual(cl, len(response_body)) - self.assertEqual(len(response_body), cl - 1) - self.assertEqual(response_body, tobytes("abcdefgh")) - # remote closed connection (despite keepalive header) - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_long_body(self): - # check server doesnt close connection when body is too long - # for cl header - to_send = tobytes( - "GET /long_body HTTP/1.0\r\n" - "Connection: Keep-Alive\r\n" - "Content-Length: 0\r\n" - "\r\n" - ) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - content_length = int(headers.get("content-length")) or None - self.assertEqual(content_length, 9) - self.assertEqual(content_length, len(response_body)) - self.assertEqual(response_body, tobytes("abcdefghi")) - # remote does not close connection (keepalive header) - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - self.assertline(line, "200", "OK", "HTTP/1.0") - - def test_equal_body(self): - # check server doesnt close connection when body is equal to - # cl header - to_send = tobytes( - "GET /equal_body HTTP/1.0\r\n" - "Connection: Keep-Alive\r\n" - "Content-Length: 0\r\n" - "\r\n" - ) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - content_length = int(headers.get("content-length")) or None - self.assertEqual(content_length, 9) - self.assertline(line, "200", "OK", "HTTP/1.0") - self.assertEqual(content_length, len(response_body)) - self.assertEqual(response_body, tobytes("abcdefghi")) - # remote does not close connection (keepalive header) - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - self.assertline(line, "200", "OK", "HTTP/1.0") - - def test_no_content_length(self): - # wtf happens when there's no content-length - to_send = tobytes( - "GET /no_content_length HTTP/1.0\r\n" - "Connection: Keep-Alive\r\n" - "Content-Length: 0\r\n" - "\r\n" - ) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line = fp.readline() # status line - line, headers, response_body = read_http(fp) - content_length = headers.get("content-length") - self.assertEqual(content_length, None) - self.assertEqual(response_body, tobytes("abcdefghi")) - # remote closed connection (despite keepalive header) - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - -class TooLargeTests(object): - - toobig = 1050 - - def setUp(self): - from waitress.tests.fixtureapps import toolarge - - self.start_subprocess( - toolarge.app, max_request_header_size=1000, max_request_body_size=1000 - ) - - def tearDown(self): - self.stop_subprocess() - - def test_request_body_too_large_with_wrong_cl_http10(self): - body = "a" * self.toobig - to_send = "GET / HTTP/1.0\r\nContent-Length: 5\r\n\r\n" - to_send += body - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb") - # first request succeeds (content-length 5) - line, headers, response_body = read_http(fp) - self.assertline(line, "200", "OK", "HTTP/1.0") - cl = int(headers["content-length"]) - self.assertEqual(cl, len(response_body)) - # server trusts the content-length header; no pipelining, - # so request fulfilled, extra bytes are thrown away - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_request_body_too_large_with_wrong_cl_http10_keepalive(self): - body = "a" * self.toobig - to_send = "GET / HTTP/1.0\r\nContent-Length: 5\r\nConnection: Keep-Alive\r\n\r\n" - to_send += body - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb") - # first request succeeds (content-length 5) - line, headers, response_body = read_http(fp) - self.assertline(line, "200", "OK", "HTTP/1.0") - cl = int(headers["content-length"]) - self.assertEqual(cl, len(response_body)) - line, headers, response_body = read_http(fp) - self.assertline(line, "431", "Request Header Fields Too Large", "HTTP/1.0") - cl = int(headers["content-length"]) - self.assertEqual(cl, len(response_body)) - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_request_body_too_large_with_no_cl_http10(self): - body = "a" * self.toobig - to_send = "GET / HTTP/1.0\r\n\r\n" - to_send += body - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - self.assertline(line, "200", "OK", "HTTP/1.0") - cl = int(headers["content-length"]) - self.assertEqual(cl, len(response_body)) - # extra bytes are thrown away (no pipelining), connection closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_request_body_too_large_with_no_cl_http10_keepalive(self): - body = "a" * self.toobig - to_send = "GET / HTTP/1.0\r\nConnection: Keep-Alive\r\n\r\n" - to_send += body - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - # server trusts the content-length header (assumed zero) - self.assertline(line, "200", "OK", "HTTP/1.0") - cl = int(headers["content-length"]) - self.assertEqual(cl, len(response_body)) - line, headers, response_body = read_http(fp) - # next response overruns because the extra data appears to be - # header data - self.assertline(line, "431", "Request Header Fields Too Large", "HTTP/1.0") - cl = int(headers["content-length"]) - self.assertEqual(cl, len(response_body)) - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_request_body_too_large_with_wrong_cl_http11(self): - body = "a" * self.toobig - to_send = "GET / HTTP/1.1\r\nContent-Length: 5\r\n\r\n" - to_send += body - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb") - # first request succeeds (content-length 5) - line, headers, response_body = read_http(fp) - self.assertline(line, "200", "OK", "HTTP/1.1") - cl = int(headers["content-length"]) - self.assertEqual(cl, len(response_body)) - # second response is an error response - line, headers, response_body = read_http(fp) - self.assertline(line, "431", "Request Header Fields Too Large", "HTTP/1.0") - cl = int(headers["content-length"]) - self.assertEqual(cl, len(response_body)) - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_request_body_too_large_with_wrong_cl_http11_connclose(self): - body = "a" * self.toobig - to_send = "GET / HTTP/1.1\r\nContent-Length: 5\r\nConnection: close\r\n\r\n" - to_send += body - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - # server trusts the content-length header (5) - self.assertline(line, "200", "OK", "HTTP/1.1") - cl = int(headers["content-length"]) - self.assertEqual(cl, len(response_body)) - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_request_body_too_large_with_no_cl_http11(self): - body = "a" * self.toobig - to_send = "GET / HTTP/1.1\r\n\r\n" - to_send += body - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb") - # server trusts the content-length header (assumed 0) - line, headers, response_body = read_http(fp) - self.assertline(line, "200", "OK", "HTTP/1.1") - cl = int(headers["content-length"]) - self.assertEqual(cl, len(response_body)) - # server assumes pipelined requests due to http/1.1, and the first - # request was assumed c-l 0 because it had no content-length header, - # so entire body looks like the header of the subsequent request - # second response is an error response - line, headers, response_body = read_http(fp) - self.assertline(line, "431", "Request Header Fields Too Large", "HTTP/1.0") - cl = int(headers["content-length"]) - self.assertEqual(cl, len(response_body)) - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_request_body_too_large_with_no_cl_http11_connclose(self): - body = "a" * self.toobig - to_send = "GET / HTTP/1.1\r\nConnection: close\r\n\r\n" - to_send += body - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - # server trusts the content-length header (assumed 0) - self.assertline(line, "200", "OK", "HTTP/1.1") - cl = int(headers["content-length"]) - self.assertEqual(cl, len(response_body)) - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_request_body_too_large_chunked_encoding(self): - control_line = "20;\r\n" # 20 hex = 32 dec - s = "This string has 32 characters.\r\n" - to_send = "GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n" - repeat = control_line + s - to_send += repeat * ((self.toobig // len(repeat)) + 1) - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - # body bytes counter caught a max_request_body_size overrun - self.assertline(line, "413", "Request Entity Too Large", "HTTP/1.1") - cl = int(headers["content-length"]) - self.assertEqual(cl, len(response_body)) - self.assertEqual(headers["content-type"], "text/plain") - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - -class InternalServerErrorTests(object): - def setUp(self): - from waitress.tests.fixtureapps import error - - self.start_subprocess(error.app, expose_tracebacks=True) - - def tearDown(self): - self.stop_subprocess() - - def test_before_start_response_http_10(self): - to_send = "GET /before_start_response HTTP/1.0\r\n\r\n" - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - self.assertline(line, "500", "Internal Server Error", "HTTP/1.0") - cl = int(headers["content-length"]) - self.assertEqual(cl, len(response_body)) - self.assertTrue(response_body.startswith(b"Internal Server Error")) - self.assertEqual(headers["connection"], "close") - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_before_start_response_http_11(self): - to_send = "GET /before_start_response HTTP/1.1\r\n\r\n" - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - self.assertline(line, "500", "Internal Server Error", "HTTP/1.1") - cl = int(headers["content-length"]) - self.assertEqual(cl, len(response_body)) - self.assertTrue(response_body.startswith(b"Internal Server Error")) - self.assertEqual( - sorted(headers.keys()), ["connection", "content-length", "content-type", "date", "server"] - ) - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_before_start_response_http_11_close(self): - to_send = tobytes( - "GET /before_start_response HTTP/1.1\r\nConnection: close\r\n\r\n" - ) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - self.assertline(line, "500", "Internal Server Error", "HTTP/1.1") - cl = int(headers["content-length"]) - self.assertEqual(cl, len(response_body)) - self.assertTrue(response_body.startswith(b"Internal Server Error")) - self.assertEqual( - sorted(headers.keys()), - ["connection", "content-length", "content-type", "date", "server"], - ) - self.assertEqual(headers["connection"], "close") - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_after_start_response_http10(self): - to_send = "GET /after_start_response HTTP/1.0\r\n\r\n" - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - self.assertline(line, "500", "Internal Server Error", "HTTP/1.0") - cl = int(headers["content-length"]) - self.assertEqual(cl, len(response_body)) - self.assertTrue(response_body.startswith(b"Internal Server Error")) - self.assertEqual( - sorted(headers.keys()), - ["connection", "content-length", "content-type", "date", "server"], - ) - self.assertEqual(headers["connection"], "close") - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_after_start_response_http11(self): - to_send = "GET /after_start_response HTTP/1.1\r\n\r\n" - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - self.assertline(line, "500", "Internal Server Error", "HTTP/1.1") - cl = int(headers["content-length"]) - self.assertEqual(cl, len(response_body)) - self.assertTrue(response_body.startswith(b"Internal Server Error")) - self.assertEqual( - sorted(headers.keys()), ["connection", "content-length", "content-type", "date", "server"] - ) - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_after_start_response_http11_close(self): - to_send = tobytes( - "GET /after_start_response HTTP/1.1\r\nConnection: close\r\n\r\n" - ) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - self.assertline(line, "500", "Internal Server Error", "HTTP/1.1") - cl = int(headers["content-length"]) - self.assertEqual(cl, len(response_body)) - self.assertTrue(response_body.startswith(b"Internal Server Error")) - self.assertEqual( - sorted(headers.keys()), - ["connection", "content-length", "content-type", "date", "server"], - ) - self.assertEqual(headers["connection"], "close") - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_after_write_cb(self): - to_send = "GET /after_write_cb HTTP/1.1\r\n\r\n" - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - self.assertline(line, "200", "OK", "HTTP/1.1") - self.assertEqual(response_body, b"") - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_in_generator(self): - to_send = "GET /in_generator HTTP/1.1\r\n\r\n" - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - self.assertline(line, "200", "OK", "HTTP/1.1") - self.assertEqual(response_body, b"") - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - -class FileWrapperTests(object): - def setUp(self): - from waitress.tests.fixtureapps import filewrapper - - self.start_subprocess(filewrapper.app) - - def tearDown(self): - self.stop_subprocess() - - def test_filelike_http11(self): - to_send = "GET /filelike HTTP/1.1\r\n\r\n" - to_send = tobytes(to_send) - - self.connect() - - for t in range(0, 2): - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - self.assertline(line, "200", "OK", "HTTP/1.1") - cl = int(headers["content-length"]) - self.assertEqual(cl, len(response_body)) - ct = headers["content-type"] - self.assertEqual(ct, "image/jpeg") - self.assertTrue(b"\377\330\377" in response_body) - # connection has not been closed - - def test_filelike_nocl_http11(self): - to_send = "GET /filelike_nocl HTTP/1.1\r\n\r\n" - to_send = tobytes(to_send) - - self.connect() - - for t in range(0, 2): - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - self.assertline(line, "200", "OK", "HTTP/1.1") - cl = int(headers["content-length"]) - self.assertEqual(cl, len(response_body)) - ct = headers["content-type"] - self.assertEqual(ct, "image/jpeg") - self.assertTrue(b"\377\330\377" in response_body) - # connection has not been closed - - def test_filelike_shortcl_http11(self): - to_send = "GET /filelike_shortcl HTTP/1.1\r\n\r\n" - to_send = tobytes(to_send) - - self.connect() - - for t in range(0, 2): - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - self.assertline(line, "200", "OK", "HTTP/1.1") - cl = int(headers["content-length"]) - self.assertEqual(cl, 1) - self.assertEqual(cl, len(response_body)) - ct = headers["content-type"] - self.assertEqual(ct, "image/jpeg") - self.assertTrue(b"\377" in response_body) - # connection has not been closed - - def test_filelike_longcl_http11(self): - to_send = "GET /filelike_longcl HTTP/1.1\r\n\r\n" - to_send = tobytes(to_send) - - self.connect() - - for t in range(0, 2): - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - self.assertline(line, "200", "OK", "HTTP/1.1") - cl = int(headers["content-length"]) - self.assertEqual(cl, len(response_body)) - ct = headers["content-type"] - self.assertEqual(ct, "image/jpeg") - self.assertTrue(b"\377\330\377" in response_body) - # connection has not been closed - - def test_notfilelike_http11(self): - to_send = "GET /notfilelike HTTP/1.1\r\n\r\n" - to_send = tobytes(to_send) - - self.connect() - - for t in range(0, 2): - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - self.assertline(line, "200", "OK", "HTTP/1.1") - cl = int(headers["content-length"]) - self.assertEqual(cl, len(response_body)) - ct = headers["content-type"] - self.assertEqual(ct, "image/jpeg") - self.assertTrue(b"\377\330\377" in response_body) - # connection has not been closed - - def test_notfilelike_iobase_http11(self): - to_send = "GET /notfilelike_iobase HTTP/1.1\r\n\r\n" - to_send = tobytes(to_send) - - self.connect() - - for t in range(0, 2): - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - self.assertline(line, "200", "OK", "HTTP/1.1") - cl = int(headers["content-length"]) - self.assertEqual(cl, len(response_body)) - ct = headers["content-type"] - self.assertEqual(ct, "image/jpeg") - self.assertTrue(b"\377\330\377" in response_body) - # connection has not been closed - - def test_notfilelike_nocl_http11(self): - to_send = "GET /notfilelike_nocl HTTP/1.1\r\n\r\n" - to_send = tobytes(to_send) - - self.connect() - - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - self.assertline(line, "200", "OK", "HTTP/1.1") - ct = headers["content-type"] - self.assertEqual(ct, "image/jpeg") - self.assertTrue(b"\377\330\377" in response_body) - # connection has been closed (no content-length) - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_notfilelike_shortcl_http11(self): - to_send = "GET /notfilelike_shortcl HTTP/1.1\r\n\r\n" - to_send = tobytes(to_send) - - self.connect() - - for t in range(0, 2): - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - self.assertline(line, "200", "OK", "HTTP/1.1") - cl = int(headers["content-length"]) - self.assertEqual(cl, 1) - self.assertEqual(cl, len(response_body)) - ct = headers["content-type"] - self.assertEqual(ct, "image/jpeg") - self.assertTrue(b"\377" in response_body) - # connection has not been closed - - def test_notfilelike_longcl_http11(self): - to_send = "GET /notfilelike_longcl HTTP/1.1\r\n\r\n" - to_send = tobytes(to_send) - - self.connect() - - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - self.assertline(line, "200", "OK", "HTTP/1.1") - cl = int(headers["content-length"]) - self.assertEqual(cl, len(response_body) + 10) - ct = headers["content-type"] - self.assertEqual(ct, "image/jpeg") - self.assertTrue(b"\377\330\377" in response_body) - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_filelike_http10(self): - to_send = "GET /filelike HTTP/1.0\r\n\r\n" - to_send = tobytes(to_send) - - self.connect() - - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - self.assertline(line, "200", "OK", "HTTP/1.0") - cl = int(headers["content-length"]) - self.assertEqual(cl, len(response_body)) - ct = headers["content-type"] - self.assertEqual(ct, "image/jpeg") - self.assertTrue(b"\377\330\377" in response_body) - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_filelike_nocl_http10(self): - to_send = "GET /filelike_nocl HTTP/1.0\r\n\r\n" - to_send = tobytes(to_send) - - self.connect() - - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - self.assertline(line, "200", "OK", "HTTP/1.0") - cl = int(headers["content-length"]) - self.assertEqual(cl, len(response_body)) - ct = headers["content-type"] - self.assertEqual(ct, "image/jpeg") - self.assertTrue(b"\377\330\377" in response_body) - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_notfilelike_http10(self): - to_send = "GET /notfilelike HTTP/1.0\r\n\r\n" - to_send = tobytes(to_send) - - self.connect() - - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - self.assertline(line, "200", "OK", "HTTP/1.0") - cl = int(headers["content-length"]) - self.assertEqual(cl, len(response_body)) - ct = headers["content-type"] - self.assertEqual(ct, "image/jpeg") - self.assertTrue(b"\377\330\377" in response_body) - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_notfilelike_nocl_http10(self): - to_send = "GET /notfilelike_nocl HTTP/1.0\r\n\r\n" - to_send = tobytes(to_send) - - self.connect() - - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - self.assertline(line, "200", "OK", "HTTP/1.0") - ct = headers["content-type"] - self.assertEqual(ct, "image/jpeg") - self.assertTrue(b"\377\330\377" in response_body) - # connection has been closed (no content-length) - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - -class TcpEchoTests(EchoTests, TcpTests, unittest.TestCase): - pass - - -class TcpPipeliningTests(PipeliningTests, TcpTests, unittest.TestCase): - pass - - -class TcpExpectContinueTests(ExpectContinueTests, TcpTests, unittest.TestCase): - pass - - -class TcpBadContentLengthTests(BadContentLengthTests, TcpTests, unittest.TestCase): - pass - - -class TcpNoContentLengthTests(NoContentLengthTests, TcpTests, unittest.TestCase): - pass - - -class TcpWriteCallbackTests(WriteCallbackTests, TcpTests, unittest.TestCase): - pass - - -class TcpTooLargeTests(TooLargeTests, TcpTests, unittest.TestCase): - pass - - -class TcpInternalServerErrorTests( - InternalServerErrorTests, TcpTests, unittest.TestCase -): - pass - - -class TcpFileWrapperTests(FileWrapperTests, TcpTests, unittest.TestCase): - pass - - -if hasattr(socket, "AF_UNIX"): - - class FixtureUnixWSGIServer(server.UnixWSGIServer): - """A version of UnixWSGIServer that relays back what it's bound to. - """ - - family = socket.AF_UNIX # Testing - - def __init__(self, application, queue, **kw): # pragma: no cover - # Coverage doesn't see this as it's ran in a separate process. - # To permit parallel testing, use a PID-dependent socket. - kw["unix_socket"] = "/tmp/waitress.test-%d.sock" % os.getpid() - super(FixtureUnixWSGIServer, self).__init__(application, **kw) - queue.put(self.socket.getsockname()) - - class UnixTests(SubprocessTests): - - server = FixtureUnixWSGIServer - - def make_http_connection(self): - return UnixHTTPConnection(self.bound_to) - - def stop_subprocess(self): - super(UnixTests, self).stop_subprocess() - cleanup_unix_socket(self.bound_to) - - def send_check_error(self, to_send): - # Unlike inet domain sockets, Unix domain sockets can trigger a - # 'Broken pipe' error when the socket it closed. - try: - self.sock.send(to_send) - except socket.error as exc: - self.assertEqual(get_errno(exc), errno.EPIPE) - - class UnixEchoTests(EchoTests, UnixTests, unittest.TestCase): - pass - - class UnixPipeliningTests(PipeliningTests, UnixTests, unittest.TestCase): - pass - - class UnixExpectContinueTests(ExpectContinueTests, UnixTests, unittest.TestCase): - pass - - class UnixBadContentLengthTests( - BadContentLengthTests, UnixTests, unittest.TestCase - ): - pass - - class UnixNoContentLengthTests(NoContentLengthTests, UnixTests, unittest.TestCase): - pass - - class UnixWriteCallbackTests(WriteCallbackTests, UnixTests, unittest.TestCase): - pass - - class UnixTooLargeTests(TooLargeTests, UnixTests, unittest.TestCase): - pass - - class UnixInternalServerErrorTests( - InternalServerErrorTests, UnixTests, unittest.TestCase - ): - pass - - class UnixFileWrapperTests(FileWrapperTests, UnixTests, unittest.TestCase): - pass - - -def parse_headers(fp): - """Parses only RFC2822 headers from a file pointer. - """ - headers = {} - while True: - line = fp.readline() - if line in (b"\r\n", b"\n", b""): - break - line = line.decode("iso-8859-1") - name, value = line.strip().split(":", 1) - headers[name.lower().strip()] = value.lower().strip() - return headers - - -class UnixHTTPConnection(httplib.HTTPConnection): - """Patched version of HTTPConnection that uses Unix domain sockets. - """ - - def __init__(self, path): - httplib.HTTPConnection.__init__(self, "localhost") - self.path = path - - def connect(self): - sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - sock.connect(self.path) - self.sock = sock - - -class ConnectionClosed(Exception): - pass - - -# stolen from gevent -def read_http(fp): # pragma: no cover - try: - response_line = fp.readline() - except socket.error as exc: - fp.close() - # errno 104 is ENOTRECOVERABLE, In WinSock 10054 is ECONNRESET - if get_errno(exc) in (errno.ECONNABORTED, errno.ECONNRESET, 104, 10054): - raise ConnectionClosed - raise - if not response_line: - raise ConnectionClosed - - header_lines = [] - while True: - line = fp.readline() - if line in (b"\r\n", b"\r\n", b""): - break - else: - header_lines.append(line) - headers = dict() - for x in header_lines: - x = x.strip() - if not x: - continue - key, value = x.split(b": ", 1) - key = key.decode("iso-8859-1").lower() - value = value.decode("iso-8859-1") - assert key not in headers, "%s header duplicated" % key - headers[key] = value - - if "content-length" in headers: - num = int(headers["content-length"]) - body = b"" - left = num - while left > 0: - data = fp.read(left) - if not data: - break - body += data - left -= len(data) - else: - # read until EOF - body = fp.read() - - return response_line, headers, body - - -# stolen from gevent -def get_errno(exc): # pragma: no cover - """ Get the error code out of socket.error objects. - socket.error in <2.5 does not have errno attribute - socket.error in 3.x does not allow indexing access - e.args[0] works for all. - There are cases when args[0] is not errno. - i.e. http://bugs.python.org/issue6471 - Maybe there are cases when errno is set, but it is not the first argument? - """ - try: - if exc.errno is not None: - return exc.errno - except AttributeError: - pass - try: - return exc.args[0] - except IndexError: - return None - - -def chunks(l, n): - """ Yield successive n-sized chunks from l. - """ - for i in range(0, len(l), n): - yield l[i : i + n] diff --git a/waitress/tests/test_init.py b/waitress/tests/test_init.py deleted file mode 100644 index f9b91d7..0000000 --- a/waitress/tests/test_init.py +++ /dev/null @@ -1,51 +0,0 @@ -import unittest - - -class Test_serve(unittest.TestCase): - def _callFUT(self, app, **kw): - from waitress import serve - - return serve(app, **kw) - - def test_it(self): - server = DummyServerFactory() - app = object() - result = self._callFUT(app, _server=server, _quiet=True) - self.assertEqual(server.app, app) - self.assertEqual(result, None) - self.assertEqual(server.ran, True) - - -class Test_serve_paste(unittest.TestCase): - def _callFUT(self, app, **kw): - from waitress import serve_paste - - return serve_paste(app, None, **kw) - - def test_it(self): - server = DummyServerFactory() - app = object() - result = self._callFUT(app, _server=server, _quiet=True) - self.assertEqual(server.app, app) - self.assertEqual(result, 0) - self.assertEqual(server.ran, True) - - -class DummyServerFactory(object): - ran = False - - def __call__(self, app, **kw): - self.adj = DummyAdj(kw) - self.app = app - self.kw = kw - return self - - def run(self): - self.ran = True - - -class DummyAdj(object): - verbose = False - - def __init__(self, kw): - self.__dict__.update(kw) diff --git a/waitress/tests/test_parser.py b/waitress/tests/test_parser.py deleted file mode 100644 index 91837c7..0000000 --- a/waitress/tests/test_parser.py +++ /dev/null @@ -1,732 +0,0 @@ -############################################################################## -# -# Copyright (c) 2002 Zope Foundation and Contributors. -# All Rights Reserved. -# -# This software is subject to the provisions of the Zope Public License, -# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. -# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED -# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS -# FOR A PARTICULAR PURPOSE. -# -############################################################################## -"""HTTP Request Parser tests -""" -import unittest - -from waitress.compat import text_, tobytes - - -class TestHTTPRequestParser(unittest.TestCase): - def setUp(self): - from waitress.parser import HTTPRequestParser - from waitress.adjustments import Adjustments - - my_adj = Adjustments() - self.parser = HTTPRequestParser(my_adj) - - def test_get_body_stream_None(self): - self.parser.body_recv = None - result = self.parser.get_body_stream() - self.assertEqual(result.getvalue(), b"") - - def test_get_body_stream_nonNone(self): - body_rcv = DummyBodyStream() - self.parser.body_rcv = body_rcv - result = self.parser.get_body_stream() - self.assertEqual(result, body_rcv) - - def test_received_get_no_headers(self): - data = b"HTTP/1.0 GET /foobar\r\n\r\n" - result = self.parser.received(data) - self.assertEqual(result, 24) - self.assertTrue(self.parser.completed) - self.assertEqual(self.parser.headers, {}) - - def test_received_bad_host_header(self): - from waitress.utilities import BadRequest - - data = b"HTTP/1.0 GET /foobar\r\n Host: foo\r\n\r\n" - result = self.parser.received(data) - self.assertEqual(result, 36) - self.assertTrue(self.parser.completed) - self.assertEqual(self.parser.error.__class__, BadRequest) - - def test_received_bad_transfer_encoding(self): - from waitress.utilities import ServerNotImplemented - - data = ( - b"GET /foobar HTTP/1.1\r\n" - b"Transfer-Encoding: foo\r\n" - b"\r\n" - b"1d;\r\n" - b"This string has 29 characters\r\n" - b"0\r\n\r\n" - ) - result = self.parser.received(data) - self.assertEqual(result, 48) - self.assertTrue(self.parser.completed) - self.assertEqual(self.parser.error.__class__, ServerNotImplemented) - - def test_received_nonsense_nothing(self): - data = b"\r\n\r\n" - result = self.parser.received(data) - self.assertEqual(result, 4) - self.assertTrue(self.parser.completed) - self.assertEqual(self.parser.headers, {}) - - def test_received_no_doublecr(self): - data = b"GET /foobar HTTP/8.4\r\n" - result = self.parser.received(data) - self.assertEqual(result, 22) - self.assertFalse(self.parser.completed) - self.assertEqual(self.parser.headers, {}) - - def test_received_already_completed(self): - self.parser.completed = True - result = self.parser.received(b"a") - self.assertEqual(result, 0) - - def test_received_cl_too_large(self): - from waitress.utilities import RequestEntityTooLarge - - self.parser.adj.max_request_body_size = 2 - data = b"GET /foobar HTTP/8.4\r\nContent-Length: 10\r\n\r\n" - result = self.parser.received(data) - self.assertEqual(result, 44) - self.assertTrue(self.parser.completed) - self.assertTrue(isinstance(self.parser.error, RequestEntityTooLarge)) - - def test_received_headers_too_large(self): - from waitress.utilities import RequestHeaderFieldsTooLarge - - self.parser.adj.max_request_header_size = 2 - data = b"GET /foobar HTTP/8.4\r\nX-Foo: 1\r\n\r\n" - result = self.parser.received(data) - self.assertEqual(result, 34) - self.assertTrue(self.parser.completed) - self.assertTrue(isinstance(self.parser.error, RequestHeaderFieldsTooLarge)) - - def test_received_body_too_large(self): - from waitress.utilities import RequestEntityTooLarge - - self.parser.adj.max_request_body_size = 2 - data = ( - b"GET /foobar HTTP/1.1\r\n" - b"Transfer-Encoding: chunked\r\n" - b"X-Foo: 1\r\n" - b"\r\n" - b"1d;\r\n" - b"This string has 29 characters\r\n" - b"0\r\n\r\n" - ) - - result = self.parser.received(data) - self.assertEqual(result, 62) - self.parser.received(data[result:]) - self.assertTrue(self.parser.completed) - self.assertTrue(isinstance(self.parser.error, RequestEntityTooLarge)) - - def test_received_error_from_parser(self): - from waitress.utilities import BadRequest - - data = ( - b"GET /foobar HTTP/1.1\r\n" - b"Transfer-Encoding: chunked\r\n" - b"X-Foo: 1\r\n" - b"\r\n" - b"garbage\r\n" - ) - # header - result = self.parser.received(data) - # body - result = self.parser.received(data[result:]) - self.assertEqual(result, 9) - self.assertTrue(self.parser.completed) - self.assertTrue(isinstance(self.parser.error, BadRequest)) - - def test_received_chunked_completed_sets_content_length(self): - data = ( - b"GET /foobar HTTP/1.1\r\n" - b"Transfer-Encoding: chunked\r\n" - b"X-Foo: 1\r\n" - b"\r\n" - b"1d;\r\n" - b"This string has 29 characters\r\n" - b"0\r\n\r\n" - ) - result = self.parser.received(data) - self.assertEqual(result, 62) - data = data[result:] - result = self.parser.received(data) - self.assertTrue(self.parser.completed) - self.assertTrue(self.parser.error is None) - self.assertEqual(self.parser.headers["CONTENT_LENGTH"], "29") - - def test_parse_header_gardenpath(self): - data = b"GET /foobar HTTP/8.4\r\nfoo: bar\r\n" - self.parser.parse_header(data) - self.assertEqual(self.parser.first_line, b"GET /foobar HTTP/8.4") - self.assertEqual(self.parser.headers["FOO"], "bar") - - def test_parse_header_no_cr_in_headerplus(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/8.4" - - try: - self.parser.parse_header(data) - except ParsingError: - pass - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_bad_content_length(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/8.4\r\ncontent-length: abc\r\n" - - try: - self.parser.parse_header(data) - except ParsingError as e: - self.assertIn("Content-Length is invalid", e.args[0]) - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_multiple_content_length(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/8.4\r\ncontent-length: 10\r\ncontent-length: 20\r\n" - - try: - self.parser.parse_header(data) - except ParsingError as e: - self.assertIn("Content-Length is invalid", e.args[0]) - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_11_te_chunked(self): - # NB: test that capitalization of header value is unimportant - data = b"GET /foobar HTTP/1.1\r\ntransfer-encoding: ChUnKed\r\n" - self.parser.parse_header(data) - self.assertEqual(self.parser.body_rcv.__class__.__name__, "ChunkedReceiver") - - def test_parse_header_transfer_encoding_invalid(self): - from waitress.parser import TransferEncodingNotImplemented - - data = b"GET /foobar HTTP/1.1\r\ntransfer-encoding: gzip\r\n" - - try: - self.parser.parse_header(data) - except TransferEncodingNotImplemented as e: - self.assertIn("Transfer-Encoding requested is not supported.", e.args[0]) - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_transfer_encoding_invalid_multiple(self): - from waitress.parser import TransferEncodingNotImplemented - - data = b"GET /foobar HTTP/1.1\r\ntransfer-encoding: gzip\r\ntransfer-encoding: chunked\r\n" - - try: - self.parser.parse_header(data) - except TransferEncodingNotImplemented as e: - self.assertIn("Transfer-Encoding requested is not supported.", e.args[0]) - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_transfer_encoding_invalid_whitespace(self): - from waitress.parser import TransferEncodingNotImplemented - - data = b"GET /foobar HTTP/1.1\r\nTransfer-Encoding:\x85chunked\r\n" - - try: - self.parser.parse_header(data) - except TransferEncodingNotImplemented as e: - self.assertIn("Transfer-Encoding requested is not supported.", e.args[0]) - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_transfer_encoding_invalid_unicode(self): - from waitress.parser import TransferEncodingNotImplemented - - # This is the binary encoding for the UTF-8 character - # https://www.compart.com/en/unicode/U+212A "unicode character "K"" - # which if waitress were to accidentally do the wrong thing get - # lowercased to just the ascii "k" due to unicode collisions during - # transformation - data = b"GET /foobar HTTP/1.1\r\nTransfer-Encoding: chun\xe2\x84\xaaed\r\n" - - try: - self.parser.parse_header(data) - except TransferEncodingNotImplemented as e: - self.assertIn("Transfer-Encoding requested is not supported.", e.args[0]) - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_11_expect_continue(self): - data = b"GET /foobar HTTP/1.1\r\nexpect: 100-continue\r\n" - self.parser.parse_header(data) - self.assertEqual(self.parser.expect_continue, True) - - def test_parse_header_connection_close(self): - data = b"GET /foobar HTTP/1.1\r\nConnection: close\r\n" - self.parser.parse_header(data) - self.assertEqual(self.parser.connection_close, True) - - def test_close_with_body_rcv(self): - body_rcv = DummyBodyStream() - self.parser.body_rcv = body_rcv - self.parser.close() - self.assertTrue(body_rcv.closed) - - def test_close_with_no_body_rcv(self): - self.parser.body_rcv = None - self.parser.close() # doesn't raise - - def test_parse_header_lf_only(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/8.4\nfoo: bar" - - try: - self.parser.parse_header(data) - except ParsingError: - pass - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_cr_only(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/8.4\rfoo: bar" - try: - self.parser.parse_header(data) - except ParsingError: - pass - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_extra_lf_in_header(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/8.4\r\nfoo: \nbar\r\n" - try: - self.parser.parse_header(data) - except ParsingError as e: - self.assertIn("Bare CR or LF found in header line", e.args[0]) - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_extra_lf_in_first_line(self): - from waitress.parser import ParsingError - - data = b"GET /foobar\n HTTP/8.4\r\n" - try: - self.parser.parse_header(data) - except ParsingError as e: - self.assertIn("Bare CR or LF found in HTTP message", e.args[0]) - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_invalid_whitespace(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/8.4\r\nfoo : bar\r\n" - try: - self.parser.parse_header(data) - except ParsingError as e: - self.assertIn("Invalid header", e.args[0]) - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_invalid_whitespace_vtab(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/1.1\r\nfoo:\x0bbar\r\n" - try: - self.parser.parse_header(data) - except ParsingError as e: - self.assertIn("Invalid header", e.args[0]) - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_invalid_no_colon(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/1.1\r\nfoo: bar\r\nnotvalid\r\n" - try: - self.parser.parse_header(data) - except ParsingError as e: - self.assertIn("Invalid header", e.args[0]) - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_invalid_folding_spacing(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/1.1\r\nfoo: bar\r\n\t\x0bbaz\r\n" - try: - self.parser.parse_header(data) - except ParsingError as e: - self.assertIn("Invalid header", e.args[0]) - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_invalid_chars(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/1.1\r\nfoo: bar\r\nfoo: \x0bbaz\r\n" - try: - self.parser.parse_header(data) - except ParsingError as e: - self.assertIn("Invalid header", e.args[0]) - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_empty(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/1.1\r\nfoo: bar\r\nempty:\r\n" - self.parser.parse_header(data) - - self.assertIn("EMPTY", self.parser.headers) - self.assertIn("FOO", self.parser.headers) - self.assertEqual(self.parser.headers["EMPTY"], "") - self.assertEqual(self.parser.headers["FOO"], "bar") - - def test_parse_header_multiple_values(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/1.1\r\nfoo: bar, whatever, more, please, yes\r\n" - self.parser.parse_header(data) - - self.assertIn("FOO", self.parser.headers) - self.assertEqual(self.parser.headers["FOO"], "bar, whatever, more, please, yes") - - def test_parse_header_multiple_values_header_folded(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/1.1\r\nfoo: bar, whatever,\r\n more, please, yes\r\n" - self.parser.parse_header(data) - - self.assertIn("FOO", self.parser.headers) - self.assertEqual(self.parser.headers["FOO"], "bar, whatever, more, please, yes") - - def test_parse_header_multiple_values_header_folded_multiple(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/1.1\r\nfoo: bar, whatever,\r\n more\r\nfoo: please, yes\r\n" - self.parser.parse_header(data) - - self.assertIn("FOO", self.parser.headers) - self.assertEqual(self.parser.headers["FOO"], "bar, whatever, more, please, yes") - - def test_parse_header_multiple_values_extra_space(self): - # Tests errata from: https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189 - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/1.1\r\nfoo: abrowser/0.001 (C O M M E N T)\r\n" - self.parser.parse_header(data) - - self.assertIn("FOO", self.parser.headers) - self.assertEqual(self.parser.headers["FOO"], "abrowser/0.001 (C O M M E N T)") - - def test_parse_header_invalid_backtrack_bad(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/1.1\r\nfoo: bar\r\nfoo: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\x10\r\n" - try: - self.parser.parse_header(data) - except ParsingError as e: - self.assertIn("Invalid header", e.args[0]) - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_short_values(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/1.1\r\none: 1\r\ntwo: 22\r\n" - self.parser.parse_header(data) - - self.assertIn("ONE", self.parser.headers) - self.assertIn("TWO", self.parser.headers) - self.assertEqual(self.parser.headers["ONE"], "1") - self.assertEqual(self.parser.headers["TWO"], "22") - - -class Test_split_uri(unittest.TestCase): - def _callFUT(self, uri): - from waitress.parser import split_uri - - ( - self.proxy_scheme, - self.proxy_netloc, - self.path, - self.query, - self.fragment, - ) = split_uri(uri) - - def test_split_uri_unquoting_unneeded(self): - self._callFUT(b"http://localhost:8080/abc def") - self.assertEqual(self.path, "/abc def") - - def test_split_uri_unquoting_needed(self): - self._callFUT(b"http://localhost:8080/abc%20def") - self.assertEqual(self.path, "/abc def") - - def test_split_url_with_query(self): - self._callFUT(b"http://localhost:8080/abc?a=1&b=2") - self.assertEqual(self.path, "/abc") - self.assertEqual(self.query, "a=1&b=2") - - def test_split_url_with_query_empty(self): - self._callFUT(b"http://localhost:8080/abc?") - self.assertEqual(self.path, "/abc") - self.assertEqual(self.query, "") - - def test_split_url_with_fragment(self): - self._callFUT(b"http://localhost:8080/#foo") - self.assertEqual(self.path, "/") - self.assertEqual(self.fragment, "foo") - - def test_split_url_https(self): - self._callFUT(b"https://localhost:8080/") - self.assertEqual(self.path, "/") - self.assertEqual(self.proxy_scheme, "https") - self.assertEqual(self.proxy_netloc, "localhost:8080") - - def test_split_uri_unicode_error_raises_parsing_error(self): - # See https://github.com/Pylons/waitress/issues/64 - from waitress.parser import ParsingError - - # Either pass or throw a ParsingError, just don't throw another type of - # exception as that will cause the connection to close badly: - try: - self._callFUT(b"/\xd0") - except ParsingError: - pass - - def test_split_uri_path(self): - self._callFUT(b"//testing/whatever") - self.assertEqual(self.path, "//testing/whatever") - self.assertEqual(self.proxy_scheme, "") - self.assertEqual(self.proxy_netloc, "") - self.assertEqual(self.query, "") - self.assertEqual(self.fragment, "") - - def test_split_uri_path_query(self): - self._callFUT(b"//testing/whatever?a=1&b=2") - self.assertEqual(self.path, "//testing/whatever") - self.assertEqual(self.proxy_scheme, "") - self.assertEqual(self.proxy_netloc, "") - self.assertEqual(self.query, "a=1&b=2") - self.assertEqual(self.fragment, "") - - def test_split_uri_path_query_fragment(self): - self._callFUT(b"//testing/whatever?a=1&b=2#fragment") - self.assertEqual(self.path, "//testing/whatever") - self.assertEqual(self.proxy_scheme, "") - self.assertEqual(self.proxy_netloc, "") - self.assertEqual(self.query, "a=1&b=2") - self.assertEqual(self.fragment, "fragment") - - -class Test_get_header_lines(unittest.TestCase): - def _callFUT(self, data): - from waitress.parser import get_header_lines - - return get_header_lines(data) - - def test_get_header_lines(self): - result = self._callFUT(b"slam\r\nslim") - self.assertEqual(result, [b"slam", b"slim"]) - - def test_get_header_lines_folded(self): - # From RFC2616: - # HTTP/1.1 header field values can be folded onto multiple lines if the - # continuation line begins with a space or horizontal tab. All linear - # white space, including folding, has the same semantics as SP. A - # recipient MAY replace any linear white space with a single SP before - # interpreting the field value or forwarding the message downstream. - - # We are just preserving the whitespace that indicates folding. - result = self._callFUT(b"slim\r\n slam") - self.assertEqual(result, [b"slim slam"]) - - def test_get_header_lines_tabbed(self): - result = self._callFUT(b"slam\r\n\tslim") - self.assertEqual(result, [b"slam\tslim"]) - - def test_get_header_lines_malformed(self): - # https://corte.si/posts/code/pathod/pythonservers/index.html - from waitress.parser import ParsingError - - self.assertRaises(ParsingError, self._callFUT, b" Host: localhost\r\n\r\n") - - -class Test_crack_first_line(unittest.TestCase): - def _callFUT(self, line): - from waitress.parser import crack_first_line - - return crack_first_line(line) - - def test_crack_first_line_matchok(self): - result = self._callFUT(b"GET / HTTP/1.0") - self.assertEqual(result, (b"GET", b"/", b"1.0")) - - def test_crack_first_line_lowercase_method(self): - from waitress.parser import ParsingError - - self.assertRaises(ParsingError, self._callFUT, b"get / HTTP/1.0") - - def test_crack_first_line_nomatch(self): - result = self._callFUT(b"GET / bleh") - self.assertEqual(result, (b"", b"", b"")) - - result = self._callFUT(b"GET /info?txtAirPlay&txtRAOP RTSP/1.0") - self.assertEqual(result, (b"", b"", b"")) - - def test_crack_first_line_missing_version(self): - result = self._callFUT(b"GET /") - self.assertEqual(result, (b"GET", b"/", b"")) - - -class TestHTTPRequestParserIntegration(unittest.TestCase): - def setUp(self): - from waitress.parser import HTTPRequestParser - from waitress.adjustments import Adjustments - - my_adj = Adjustments() - self.parser = HTTPRequestParser(my_adj) - - def feed(self, data): - parser = self.parser - - for n in range(100): # make sure we never loop forever - consumed = parser.received(data) - data = data[consumed:] - - if parser.completed: - return - raise ValueError("Looping") # pragma: no cover - - def testSimpleGET(self): - data = ( - b"GET /foobar HTTP/8.4\r\n" - b"FirstName: mickey\r\n" - b"lastname: Mouse\r\n" - b"content-length: 6\r\n" - b"\r\n" - b"Hello." - ) - parser = self.parser - self.feed(data) - self.assertTrue(parser.completed) - self.assertEqual(parser.version, "8.4") - self.assertFalse(parser.empty) - self.assertEqual( - parser.headers, - {"FIRSTNAME": "mickey", "LASTNAME": "Mouse", "CONTENT_LENGTH": "6",}, - ) - self.assertEqual(parser.path, "/foobar") - self.assertEqual(parser.command, "GET") - self.assertEqual(parser.query, "") - self.assertEqual(parser.proxy_scheme, "") - self.assertEqual(parser.proxy_netloc, "") - self.assertEqual(parser.get_body_stream().getvalue(), b"Hello.") - - def testComplexGET(self): - data = ( - b"GET /foo/a+%2B%2F%C3%A4%3D%26a%3Aint?d=b+%2B%2F%3D%26b%3Aint&c+%2B%2F%3D%26c%3Aint=6 HTTP/8.4\r\n" - b"FirstName: mickey\r\n" - b"lastname: Mouse\r\n" - b"content-length: 10\r\n" - b"\r\n" - b"Hello mickey." - ) - parser = self.parser - self.feed(data) - self.assertEqual(parser.command, "GET") - self.assertEqual(parser.version, "8.4") - self.assertFalse(parser.empty) - self.assertEqual( - parser.headers, - {"FIRSTNAME": "mickey", "LASTNAME": "Mouse", "CONTENT_LENGTH": "10"}, - ) - # path should be utf-8 encoded - self.assertEqual( - tobytes(parser.path).decode("utf-8"), - text_(b"/foo/a++/\xc3\xa4=&a:int", "utf-8"), - ) - self.assertEqual( - parser.query, "d=b+%2B%2F%3D%26b%3Aint&c+%2B%2F%3D%26c%3Aint=6" - ) - self.assertEqual(parser.get_body_stream().getvalue(), b"Hello mick") - - def testProxyGET(self): - data = ( - b"GET https://example.com:8080/foobar HTTP/8.4\r\n" - b"content-length: 6\r\n" - b"\r\n" - b"Hello." - ) - parser = self.parser - self.feed(data) - self.assertTrue(parser.completed) - self.assertEqual(parser.version, "8.4") - self.assertFalse(parser.empty) - self.assertEqual(parser.headers, {"CONTENT_LENGTH": "6"}) - self.assertEqual(parser.path, "/foobar") - self.assertEqual(parser.command, "GET") - self.assertEqual(parser.proxy_scheme, "https") - self.assertEqual(parser.proxy_netloc, "example.com:8080") - self.assertEqual(parser.command, "GET") - self.assertEqual(parser.query, "") - self.assertEqual(parser.get_body_stream().getvalue(), b"Hello.") - - def testDuplicateHeaders(self): - # Ensure that headers with the same key get concatenated as per - # RFC2616. - data = ( - b"GET /foobar HTTP/8.4\r\n" - b"x-forwarded-for: 10.11.12.13\r\n" - b"x-forwarded-for: unknown,127.0.0.1\r\n" - b"X-Forwarded_for: 255.255.255.255\r\n" - b"content-length: 6\r\n" - b"\r\n" - b"Hello." - ) - self.feed(data) - self.assertTrue(self.parser.completed) - self.assertEqual( - self.parser.headers, - { - "CONTENT_LENGTH": "6", - "X_FORWARDED_FOR": "10.11.12.13, unknown,127.0.0.1", - }, - ) - - def testSpoofedHeadersDropped(self): - data = ( - b"GET /foobar HTTP/8.4\r\n" - b"x-auth_user: bob\r\n" - b"content-length: 6\r\n" - b"\r\n" - b"Hello." - ) - self.feed(data) - self.assertTrue(self.parser.completed) - self.assertEqual(self.parser.headers, {"CONTENT_LENGTH": "6",}) - - -class DummyBodyStream(object): - def getfile(self): - return self - - def getbuf(self): - return self - - def close(self): - self.closed = True diff --git a/waitress/tests/test_proxy_headers.py b/waitress/tests/test_proxy_headers.py deleted file mode 100644 index 15b4a08..0000000 --- a/waitress/tests/test_proxy_headers.py +++ /dev/null @@ -1,724 +0,0 @@ -import unittest - -from waitress.compat import tobytes - - -class TestProxyHeadersMiddleware(unittest.TestCase): - def _makeOne(self, app, **kw): - from waitress.proxy_headers import proxy_headers_middleware - - return proxy_headers_middleware(app, **kw) - - def _callFUT(self, app, **kw): - response = DummyResponse() - environ = DummyEnviron(**kw) - - def start_response(status, response_headers): - response.status = status - response.headers = response_headers - - response.steps = list(app(environ, start_response)) - response.body = b"".join(tobytes(s) for s in response.steps) - return response - - def test_get_environment_values_w_scheme_override_untrusted(self): - inner = DummyApp() - app = self._makeOne(inner) - response = self._callFUT( - app, headers={"X_FOO": "BAR", "X_FORWARDED_PROTO": "https",} - ) - self.assertEqual(response.status, "200 OK") - self.assertEqual(inner.environ["wsgi.url_scheme"], "http") - - def test_get_environment_values_w_scheme_override_trusted(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="192.168.1.1", - trusted_proxy_headers={"x-forwarded-proto"}, - ) - response = self._callFUT( - app, - addr=["192.168.1.1", 8080], - headers={"X_FOO": "BAR", "X_FORWARDED_PROTO": "https",}, - ) - - environ = inner.environ - self.assertEqual(response.status, "200 OK") - self.assertEqual(environ["SERVER_PORT"], "443") - self.assertEqual(environ["SERVER_NAME"], "localhost") - self.assertEqual(environ["REMOTE_ADDR"], "192.168.1.1") - self.assertEqual(environ["HTTP_X_FOO"], "BAR") - - def test_get_environment_values_w_bogus_scheme_override(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="192.168.1.1", - trusted_proxy_headers={"x-forwarded-proto"}, - ) - response = self._callFUT( - app, - addr=["192.168.1.1", 80], - headers={ - "X_FOO": "BAR", - "X_FORWARDED_PROTO": "http://p02n3e.com?url=http", - }, - ) - self.assertEqual(response.status, "400 Bad Request") - self.assertIn(b'Header "X-Forwarded-Proto" malformed', response.body) - - def test_get_environment_warning_other_proxy_headers(self): - inner = DummyApp() - logger = DummyLogger() - app = self._makeOne( - inner, - trusted_proxy="192.168.1.1", - trusted_proxy_count=1, - trusted_proxy_headers={"forwarded"}, - log_untrusted=True, - logger=logger, - ) - response = self._callFUT( - app, - addr=["192.168.1.1", 80], - headers={ - "X_FORWARDED_FOR": "[2001:db8::1]", - "FORWARDED": "For=198.51.100.2;host=example.com:8080;proto=https", - }, - ) - self.assertEqual(response.status, "200 OK") - - self.assertEqual(len(logger.logged), 1) - - environ = inner.environ - self.assertNotIn("HTTP_X_FORWARDED_FOR", environ) - self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") - self.assertEqual(environ["SERVER_NAME"], "example.com") - self.assertEqual(environ["HTTP_HOST"], "example.com:8080") - self.assertEqual(environ["SERVER_PORT"], "8080") - self.assertEqual(environ["wsgi.url_scheme"], "https") - - def test_get_environment_contains_all_headers_including_untrusted(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="192.168.1.1", - trusted_proxy_count=1, - trusted_proxy_headers={"x-forwarded-by"}, - clear_untrusted=False, - ) - headers_orig = { - "X_FORWARDED_FOR": "198.51.100.2", - "X_FORWARDED_BY": "Waitress", - "X_FORWARDED_PROTO": "https", - "X_FORWARDED_HOST": "example.org", - } - response = self._callFUT( - app, addr=["192.168.1.1", 80], headers=headers_orig.copy(), - ) - self.assertEqual(response.status, "200 OK") - environ = inner.environ - for k, expected in headers_orig.items(): - result = environ["HTTP_%s" % k] - self.assertEqual(result, expected) - - def test_get_environment_contains_only_trusted_headers(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="192.168.1.1", - trusted_proxy_count=1, - trusted_proxy_headers={"x-forwarded-by"}, - clear_untrusted=True, - ) - response = self._callFUT( - app, - addr=["192.168.1.1", 80], - headers={ - "X_FORWARDED_FOR": "198.51.100.2", - "X_FORWARDED_BY": "Waitress", - "X_FORWARDED_PROTO": "https", - "X_FORWARDED_HOST": "example.org", - }, - ) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["HTTP_X_FORWARDED_BY"], "Waitress") - self.assertNotIn("HTTP_X_FORWARDED_FOR", environ) - self.assertNotIn("HTTP_X_FORWARDED_PROTO", environ) - self.assertNotIn("HTTP_X_FORWARDED_HOST", environ) - - def test_get_environment_clears_headers_if_untrusted_proxy(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="192.168.1.1", - trusted_proxy_count=1, - trusted_proxy_headers={"x-forwarded-by"}, - clear_untrusted=True, - ) - response = self._callFUT( - app, - addr=["192.168.1.255", 80], - headers={ - "X_FORWARDED_FOR": "198.51.100.2", - "X_FORWARDED_BY": "Waitress", - "X_FORWARDED_PROTO": "https", - "X_FORWARDED_HOST": "example.org", - }, - ) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertNotIn("HTTP_X_FORWARDED_BY", environ) - self.assertNotIn("HTTP_X_FORWARDED_FOR", environ) - self.assertNotIn("HTTP_X_FORWARDED_PROTO", environ) - self.assertNotIn("HTTP_X_FORWARDED_HOST", environ) - - def test_parse_proxy_headers_forwarded_for(self): - inner = DummyApp() - app = self._makeOne( - inner, trusted_proxy="*", trusted_proxy_headers={"x-forwarded-for"}, - ) - response = self._callFUT(app, headers={"X_FORWARDED_FOR": "192.0.2.1"}) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["REMOTE_ADDR"], "192.0.2.1") - - def test_parse_proxy_headers_forwarded_for_v6_missing_brackets(self): - inner = DummyApp() - app = self._makeOne( - inner, trusted_proxy="*", trusted_proxy_headers={"x-forwarded-for"}, - ) - response = self._callFUT(app, headers={"X_FORWARDED_FOR": "2001:db8::0"}) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["REMOTE_ADDR"], "2001:db8::0") - - def test_parse_proxy_headers_forwared_for_multiple(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=2, - trusted_proxy_headers={"x-forwarded-for"}, - ) - response = self._callFUT( - app, headers={"X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1"} - ) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") - - def test_parse_forwarded_multiple_proxies_trust_only_two(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=2, - trusted_proxy_headers={"forwarded"}, - ) - response = self._callFUT( - app, - headers={ - "FORWARDED": ( - "For=192.0.2.1;host=fake.com, " - "For=198.51.100.2;host=example.com:8080, " - "For=203.0.113.1" - ), - }, - ) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") - self.assertEqual(environ["SERVER_NAME"], "example.com") - self.assertEqual(environ["HTTP_HOST"], "example.com:8080") - self.assertEqual(environ["SERVER_PORT"], "8080") - - def test_parse_forwarded_multiple_proxies(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=2, - trusted_proxy_headers={"forwarded"}, - ) - response = self._callFUT( - app, - headers={ - "FORWARDED": ( - 'for="[2001:db8::1]:3821";host="example.com:8443";proto="https", ' - 'for=192.0.2.1;host="example.internal:8080"' - ), - }, - ) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["REMOTE_ADDR"], "2001:db8::1") - self.assertEqual(environ["REMOTE_PORT"], "3821") - self.assertEqual(environ["SERVER_NAME"], "example.com") - self.assertEqual(environ["HTTP_HOST"], "example.com:8443") - self.assertEqual(environ["SERVER_PORT"], "8443") - self.assertEqual(environ["wsgi.url_scheme"], "https") - - def test_parse_forwarded_multiple_proxies_minimal(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=2, - trusted_proxy_headers={"forwarded"}, - ) - response = self._callFUT( - app, - headers={ - "FORWARDED": ( - 'for="[2001:db8::1]";proto="https", ' - 'for=192.0.2.1;host="example.org"' - ), - }, - ) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["REMOTE_ADDR"], "2001:db8::1") - self.assertEqual(environ["SERVER_NAME"], "example.org") - self.assertEqual(environ["HTTP_HOST"], "example.org") - self.assertEqual(environ["SERVER_PORT"], "443") - self.assertEqual(environ["wsgi.url_scheme"], "https") - - def test_parse_proxy_headers_forwarded_host_with_port(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=2, - trusted_proxy_headers={ - "x-forwarded-for", - "x-forwarded-proto", - "x-forwarded-host", - }, - ) - response = self._callFUT( - app, - headers={ - "X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1", - "X_FORWARDED_PROTO": "http", - "X_FORWARDED_HOST": "example.com:8080", - }, - ) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") - self.assertEqual(environ["SERVER_NAME"], "example.com") - self.assertEqual(environ["HTTP_HOST"], "example.com:8080") - self.assertEqual(environ["SERVER_PORT"], "8080") - - def test_parse_proxy_headers_forwarded_host_without_port(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=2, - trusted_proxy_headers={ - "x-forwarded-for", - "x-forwarded-proto", - "x-forwarded-host", - }, - ) - response = self._callFUT( - app, - headers={ - "X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1", - "X_FORWARDED_PROTO": "http", - "X_FORWARDED_HOST": "example.com", - }, - ) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") - self.assertEqual(environ["SERVER_NAME"], "example.com") - self.assertEqual(environ["HTTP_HOST"], "example.com") - self.assertEqual(environ["SERVER_PORT"], "80") - - def test_parse_proxy_headers_forwarded_host_with_forwarded_port(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=2, - trusted_proxy_headers={ - "x-forwarded-for", - "x-forwarded-proto", - "x-forwarded-host", - "x-forwarded-port", - }, - ) - response = self._callFUT( - app, - headers={ - "X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1", - "X_FORWARDED_PROTO": "http", - "X_FORWARDED_HOST": "example.com", - "X_FORWARDED_PORT": "8080", - }, - ) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") - self.assertEqual(environ["SERVER_NAME"], "example.com") - self.assertEqual(environ["HTTP_HOST"], "example.com:8080") - self.assertEqual(environ["SERVER_PORT"], "8080") - - def test_parse_proxy_headers_forwarded_host_multiple_with_forwarded_port(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=2, - trusted_proxy_headers={ - "x-forwarded-for", - "x-forwarded-proto", - "x-forwarded-host", - "x-forwarded-port", - }, - ) - response = self._callFUT( - app, - headers={ - "X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1", - "X_FORWARDED_PROTO": "http", - "X_FORWARDED_HOST": "example.com, example.org", - "X_FORWARDED_PORT": "8080", - }, - ) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") - self.assertEqual(environ["SERVER_NAME"], "example.com") - self.assertEqual(environ["HTTP_HOST"], "example.com:8080") - self.assertEqual(environ["SERVER_PORT"], "8080") - - def test_parse_proxy_headers_forwarded_host_multiple_with_forwarded_port_limit_one_trusted( - self, - ): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=1, - trusted_proxy_headers={ - "x-forwarded-for", - "x-forwarded-proto", - "x-forwarded-host", - "x-forwarded-port", - }, - ) - response = self._callFUT( - app, - headers={ - "X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1", - "X_FORWARDED_PROTO": "http", - "X_FORWARDED_HOST": "example.com, example.org", - "X_FORWARDED_PORT": "8080", - }, - ) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["REMOTE_ADDR"], "203.0.113.1") - self.assertEqual(environ["SERVER_NAME"], "example.org") - self.assertEqual(environ["HTTP_HOST"], "example.org:8080") - self.assertEqual(environ["SERVER_PORT"], "8080") - - def test_parse_forwarded(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=1, - trusted_proxy_headers={"forwarded"}, - ) - response = self._callFUT( - app, - headers={ - "FORWARDED": "For=198.51.100.2:5858;host=example.com:8080;proto=https", - }, - ) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") - self.assertEqual(environ["REMOTE_PORT"], "5858") - self.assertEqual(environ["SERVER_NAME"], "example.com") - self.assertEqual(environ["HTTP_HOST"], "example.com:8080") - self.assertEqual(environ["SERVER_PORT"], "8080") - self.assertEqual(environ["wsgi.url_scheme"], "https") - - def test_parse_forwarded_empty_pair(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=1, - trusted_proxy_headers={"forwarded"}, - ) - response = self._callFUT( - app, headers={"FORWARDED": "For=198.51.100.2;;proto=https;by=_unused",} - ) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") - - def test_parse_forwarded_pair_token_whitespace(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=1, - trusted_proxy_headers={"forwarded"}, - ) - response = self._callFUT( - app, headers={"FORWARDED": "For=198.51.100.2; proto =https",} - ) - self.assertEqual(response.status, "400 Bad Request") - self.assertIn(b'Header "Forwarded" malformed', response.body) - - def test_parse_forwarded_pair_value_whitespace(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=1, - trusted_proxy_headers={"forwarded"}, - ) - response = self._callFUT( - app, headers={"FORWARDED": 'For= "198.51.100.2"; proto =https',} - ) - self.assertEqual(response.status, "400 Bad Request") - self.assertIn(b'Header "Forwarded" malformed', response.body) - - def test_parse_forwarded_pair_no_equals(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=1, - trusted_proxy_headers={"forwarded"}, - ) - response = self._callFUT(app, headers={"FORWARDED": "For"}) - self.assertEqual(response.status, "400 Bad Request") - self.assertIn(b'Header "Forwarded" malformed', response.body) - - def test_parse_forwarded_warning_unknown_token(self): - inner = DummyApp() - logger = DummyLogger() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=1, - trusted_proxy_headers={"forwarded"}, - logger=logger, - ) - response = self._callFUT( - app, - headers={ - "FORWARDED": ( - "For=198.51.100.2;host=example.com:8080;proto=https;" - 'unknown="yolo"' - ), - }, - ) - self.assertEqual(response.status, "200 OK") - - self.assertEqual(len(logger.logged), 1) - self.assertIn("Unknown Forwarded token", logger.logged[0]) - - environ = inner.environ - self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") - self.assertEqual(environ["SERVER_NAME"], "example.com") - self.assertEqual(environ["HTTP_HOST"], "example.com:8080") - self.assertEqual(environ["SERVER_PORT"], "8080") - self.assertEqual(environ["wsgi.url_scheme"], "https") - - def test_parse_no_valid_proxy_headers(self): - inner = DummyApp() - app = self._makeOne(inner, trusted_proxy="*", trusted_proxy_count=1,) - response = self._callFUT( - app, - headers={ - "X_FORWARDED_FOR": "198.51.100.2", - "FORWARDED": "For=198.51.100.2;host=example.com:8080;proto=https", - }, - ) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["REMOTE_ADDR"], "127.0.0.1") - self.assertEqual(environ["SERVER_NAME"], "localhost") - self.assertEqual(environ["HTTP_HOST"], "192.168.1.1:80") - self.assertEqual(environ["SERVER_PORT"], "8080") - self.assertEqual(environ["wsgi.url_scheme"], "http") - - def test_parse_multiple_x_forwarded_proto(self): - inner = DummyApp() - logger = DummyLogger() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=1, - trusted_proxy_headers={"x-forwarded-proto"}, - logger=logger, - ) - response = self._callFUT(app, headers={"X_FORWARDED_PROTO": "http, https",}) - self.assertEqual(response.status, "400 Bad Request") - self.assertIn(b'Header "X-Forwarded-Proto" malformed', response.body) - - def test_parse_multiple_x_forwarded_port(self): - inner = DummyApp() - logger = DummyLogger() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=1, - trusted_proxy_headers={"x-forwarded-port"}, - logger=logger, - ) - response = self._callFUT(app, headers={"X_FORWARDED_PORT": "443, 80",}) - self.assertEqual(response.status, "400 Bad Request") - self.assertIn(b'Header "X-Forwarded-Port" malformed', response.body) - - def test_parse_forwarded_port_wrong_proto_port_80(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=1, - trusted_proxy_headers={ - "x-forwarded-port", - "x-forwarded-host", - "x-forwarded-proto", - }, - ) - response = self._callFUT( - app, - headers={ - "X_FORWARDED_PORT": "80", - "X_FORWARDED_PROTO": "https", - "X_FORWARDED_HOST": "example.com", - }, - ) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["SERVER_NAME"], "example.com") - self.assertEqual(environ["HTTP_HOST"], "example.com:80") - self.assertEqual(environ["SERVER_PORT"], "80") - self.assertEqual(environ["wsgi.url_scheme"], "https") - - def test_parse_forwarded_port_wrong_proto_port_443(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=1, - trusted_proxy_headers={ - "x-forwarded-port", - "x-forwarded-host", - "x-forwarded-proto", - }, - ) - response = self._callFUT( - app, - headers={ - "X_FORWARDED_PORT": "443", - "X_FORWARDED_PROTO": "http", - "X_FORWARDED_HOST": "example.com", - }, - ) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["SERVER_NAME"], "example.com") - self.assertEqual(environ["HTTP_HOST"], "example.com:443") - self.assertEqual(environ["SERVER_PORT"], "443") - self.assertEqual(environ["wsgi.url_scheme"], "http") - - def test_parse_forwarded_for_bad_quote(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=1, - trusted_proxy_headers={"x-forwarded-for"}, - ) - response = self._callFUT(app, headers={"X_FORWARDED_FOR": '"foo'}) - self.assertEqual(response.status, "400 Bad Request") - self.assertIn(b'Header "X-Forwarded-For" malformed', response.body) - - def test_parse_forwarded_host_bad_quote(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=1, - trusted_proxy_headers={"x-forwarded-host"}, - ) - response = self._callFUT(app, headers={"X_FORWARDED_HOST": '"foo'}) - self.assertEqual(response.status, "400 Bad Request") - self.assertIn(b'Header "X-Forwarded-Host" malformed', response.body) - - -class DummyLogger(object): - def __init__(self): - self.logged = [] - - def warning(self, msg, *args): - self.logged.append(msg % args) - - -class DummyApp(object): - def __call__(self, environ, start_response): - self.environ = environ - start_response("200 OK", [("Content-Type", "text/plain")]) - yield "hello" - - -class DummyResponse(object): - status = None - headers = None - body = None - - -def DummyEnviron( - addr=("127.0.0.1", 8080), scheme="http", server="localhost", headers=None, -): - environ = { - "REMOTE_ADDR": addr[0], - "REMOTE_HOST": addr[0], - "REMOTE_PORT": addr[1], - "SERVER_PORT": str(addr[1]), - "SERVER_NAME": server, - "wsgi.url_scheme": scheme, - "HTTP_HOST": "192.168.1.1:80", - } - if headers: - environ.update( - { - "HTTP_" + key.upper().replace("-", "_"): value - for key, value in headers.items() - } - ) - return environ diff --git a/waitress/tests/test_receiver.py b/waitress/tests/test_receiver.py deleted file mode 100644 index b4910bb..0000000 --- a/waitress/tests/test_receiver.py +++ /dev/null @@ -1,242 +0,0 @@ -import unittest - - -class TestFixedStreamReceiver(unittest.TestCase): - def _makeOne(self, cl, buf): - from waitress.receiver import FixedStreamReceiver - - return FixedStreamReceiver(cl, buf) - - def test_received_remain_lt_1(self): - buf = DummyBuffer() - inst = self._makeOne(0, buf) - result = inst.received("a") - self.assertEqual(result, 0) - self.assertEqual(inst.completed, True) - - def test_received_remain_lte_datalen(self): - buf = DummyBuffer() - inst = self._makeOne(1, buf) - result = inst.received("aa") - self.assertEqual(result, 1) - self.assertEqual(inst.completed, True) - self.assertEqual(inst.completed, 1) - self.assertEqual(inst.remain, 0) - self.assertEqual(buf.data, ["a"]) - - def test_received_remain_gt_datalen(self): - buf = DummyBuffer() - inst = self._makeOne(10, buf) - result = inst.received("aa") - self.assertEqual(result, 2) - self.assertEqual(inst.completed, False) - self.assertEqual(inst.remain, 8) - self.assertEqual(buf.data, ["aa"]) - - def test_getfile(self): - buf = DummyBuffer() - inst = self._makeOne(10, buf) - self.assertEqual(inst.getfile(), buf) - - def test_getbuf(self): - buf = DummyBuffer() - inst = self._makeOne(10, buf) - self.assertEqual(inst.getbuf(), buf) - - def test___len__(self): - buf = DummyBuffer(["1", "2"]) - inst = self._makeOne(10, buf) - self.assertEqual(inst.__len__(), 2) - - -class TestChunkedReceiver(unittest.TestCase): - def _makeOne(self, buf): - from waitress.receiver import ChunkedReceiver - - return ChunkedReceiver(buf) - - def test_alreadycompleted(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - inst.completed = True - result = inst.received(b"a") - self.assertEqual(result, 0) - self.assertEqual(inst.completed, True) - - def test_received_remain_gt_zero(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - inst.chunk_remainder = 100 - result = inst.received(b"a") - self.assertEqual(inst.chunk_remainder, 99) - self.assertEqual(result, 1) - self.assertEqual(inst.completed, False) - - def test_received_control_line_notfinished(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - result = inst.received(b"a") - self.assertEqual(inst.control_line, b"a") - self.assertEqual(result, 1) - self.assertEqual(inst.completed, False) - - def test_received_control_line_finished_garbage_in_input(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - result = inst.received(b"garbage\r\n") - self.assertEqual(result, 9) - self.assertTrue(inst.error) - - def test_received_control_line_finished_all_chunks_not_received(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - result = inst.received(b"a;discard\r\n") - self.assertEqual(inst.control_line, b"") - self.assertEqual(inst.chunk_remainder, 10) - self.assertEqual(inst.all_chunks_received, False) - self.assertEqual(result, 11) - self.assertEqual(inst.completed, False) - - def test_received_control_line_finished_all_chunks_received(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - result = inst.received(b"0;discard\r\n") - self.assertEqual(inst.control_line, b"") - self.assertEqual(inst.all_chunks_received, True) - self.assertEqual(result, 11) - self.assertEqual(inst.completed, False) - - def test_received_trailer_startswith_crlf(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - inst.all_chunks_received = True - result = inst.received(b"\r\n") - self.assertEqual(result, 2) - self.assertEqual(inst.completed, True) - - def test_received_trailer_startswith_lf(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - inst.all_chunks_received = True - result = inst.received(b"\n") - self.assertEqual(result, 1) - self.assertEqual(inst.completed, False) - - def test_received_trailer_not_finished(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - inst.all_chunks_received = True - result = inst.received(b"a") - self.assertEqual(result, 1) - self.assertEqual(inst.completed, False) - - def test_received_trailer_finished(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - inst.all_chunks_received = True - result = inst.received(b"abc\r\n\r\n") - self.assertEqual(inst.trailer, b"abc\r\n\r\n") - self.assertEqual(result, 7) - self.assertEqual(inst.completed, True) - - def test_getfile(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - self.assertEqual(inst.getfile(), buf) - - def test_getbuf(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - self.assertEqual(inst.getbuf(), buf) - - def test___len__(self): - buf = DummyBuffer(["1", "2"]) - inst = self._makeOne(buf) - self.assertEqual(inst.__len__(), 2) - - def test_received_chunk_is_properly_terminated(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - data = b"4\r\nWiki\r\n" - result = inst.received(data) - self.assertEqual(result, len(data)) - self.assertEqual(inst.completed, False) - self.assertEqual(buf.data[0], b"Wiki") - - def test_received_chunk_not_properly_terminated(self): - from waitress.utilities import BadRequest - - buf = DummyBuffer() - inst = self._makeOne(buf) - data = b"4\r\nWikibadchunk\r\n" - result = inst.received(data) - self.assertEqual(result, len(data)) - self.assertEqual(inst.completed, False) - self.assertEqual(buf.data[0], b"Wiki") - self.assertEqual(inst.error.__class__, BadRequest) - - def test_received_multiple_chunks(self): - from waitress.utilities import BadRequest - - buf = DummyBuffer() - inst = self._makeOne(buf) - data = ( - b"4\r\n" - b"Wiki\r\n" - b"5\r\n" - b"pedia\r\n" - b"E\r\n" - b" in\r\n" - b"\r\n" - b"chunks.\r\n" - b"0\r\n" - b"\r\n" - ) - result = inst.received(data) - self.assertEqual(result, len(data)) - self.assertEqual(inst.completed, True) - self.assertEqual(b"".join(buf.data), b"Wikipedia in\r\n\r\nchunks.") - self.assertEqual(inst.error, None) - - def test_received_multiple_chunks_split(self): - from waitress.utilities import BadRequest - - buf = DummyBuffer() - inst = self._makeOne(buf) - data1 = b"4\r\nWiki\r" - result = inst.received(data1) - self.assertEqual(result, len(data1)) - - data2 = ( - b"\n5\r\n" - b"pedia\r\n" - b"E\r\n" - b" in\r\n" - b"\r\n" - b"chunks.\r\n" - b"0\r\n" - b"\r\n" - ) - - result = inst.received(data2) - self.assertEqual(result, len(data2)) - - self.assertEqual(inst.completed, True) - self.assertEqual(b"".join(buf.data), b"Wikipedia in\r\n\r\nchunks.") - self.assertEqual(inst.error, None) - - -class DummyBuffer(object): - def __init__(self, data=None): - if data is None: - data = [] - self.data = data - - def append(self, s): - self.data.append(s) - - def getfile(self): - return self - - def __len__(self): - return len(self.data) diff --git a/waitress/tests/test_regression.py b/waitress/tests/test_regression.py deleted file mode 100644 index 3c4c6c2..0000000 --- a/waitress/tests/test_regression.py +++ /dev/null @@ -1,147 +0,0 @@ -############################################################################## -# -# Copyright (c) 2005 Zope Foundation and Contributors. -# All Rights Reserved. -# -# This software is subject to the provisions of the Zope Public License, -# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. -# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED -# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS -# FOR A PARTICULAR PURPOSE. -# -############################################################################## -"""Tests for waitress.channel maintenance logic -""" -import doctest - - -class FakeSocket: # pragma: no cover - data = "" - setblocking = lambda *_: None - close = lambda *_: None - - def __init__(self, no): - self.no = no - - def fileno(self): - return self.no - - def getpeername(self): - return ("localhost", self.no) - - def send(self, data): - self.data += data - return len(data) - - def recv(self, data): - return "data" - - -def zombies_test(): - """Regression test for HTTPChannel.maintenance method - - Bug: This method checks for channels that have been "inactive" for a - configured time. The bug was that last_activity is set at creation time - but never updated during async channel activity (reads and writes), so - any channel older than the configured timeout will be closed when a new - channel is created, regardless of activity. - - >>> import time - >>> import waitress.adjustments - >>> config = waitress.adjustments.Adjustments() - - >>> from waitress.server import HTTPServer - >>> class TestServer(HTTPServer): - ... def bind(self, (ip, port)): - ... print "Listening on %s:%d" % (ip or '*', port) - >>> sb = TestServer('127.0.0.1', 80, start=False, verbose=True) - Listening on 127.0.0.1:80 - - First we confirm the correct behavior, where a channel with no activity - for the timeout duration gets closed. - - >>> from waitress.channel import HTTPChannel - >>> socket = FakeSocket(42) - >>> channel = HTTPChannel(sb, socket, ('localhost', 42)) - - >>> channel.connected - True - - >>> channel.last_activity -= int(config.channel_timeout) + 1 - - >>> channel.next_channel_cleanup[0] = channel.creation_time - int( - ... config.cleanup_interval) - 1 - - >>> socket2 = FakeSocket(7) - >>> channel2 = HTTPChannel(sb, socket2, ('localhost', 7)) - - >>> channel.connected - False - - Write Activity - -------------- - - Now we make sure that if there is activity the channel doesn't get closed - incorrectly. - - >>> channel2.connected - True - - >>> channel2.last_activity -= int(config.channel_timeout) + 1 - - >>> channel2.handle_write() - - >>> channel2.next_channel_cleanup[0] = channel2.creation_time - int( - ... config.cleanup_interval) - 1 - - >>> socket3 = FakeSocket(3) - >>> channel3 = HTTPChannel(sb, socket3, ('localhost', 3)) - - >>> channel2.connected - True - - Read Activity - -------------- - - We should test to see that read activity will update a channel as well. - - >>> channel3.connected - True - - >>> channel3.last_activity -= int(config.channel_timeout) + 1 - - >>> import waitress.parser - >>> channel3.parser_class = ( - ... waitress.parser.HTTPRequestParser) - >>> channel3.handle_read() - - >>> channel3.next_channel_cleanup[0] = channel3.creation_time - int( - ... config.cleanup_interval) - 1 - - >>> socket4 = FakeSocket(4) - >>> channel4 = HTTPChannel(sb, socket4, ('localhost', 4)) - - >>> channel3.connected - True - - Main loop window - ---------------- - - There is also a corner case we'll do a shallow test for where a - channel can be closed waiting for the main loop. - - >>> channel4.last_activity -= 1 - - >>> last_active = channel4.last_activity - - >>> channel4.set_async() - - >>> channel4.last_activity != last_active - True - -""" - - -def test_suite(): - return doctest.DocTestSuite() diff --git a/waitress/tests/test_runner.py b/waitress/tests/test_runner.py deleted file mode 100644 index 127757e..0000000 --- a/waitress/tests/test_runner.py +++ /dev/null @@ -1,191 +0,0 @@ -import contextlib -import os -import sys - -if sys.version_info[:2] == (2, 6): # pragma: no cover - import unittest2 as unittest -else: # pragma: no cover - import unittest - -from waitress import runner - - -class Test_match(unittest.TestCase): - def test_empty(self): - self.assertRaisesRegexp( - ValueError, "^Malformed application ''$", runner.match, "" - ) - - def test_module_only(self): - self.assertRaisesRegexp( - ValueError, r"^Malformed application 'foo\.bar'$", runner.match, "foo.bar" - ) - - def test_bad_module(self): - self.assertRaisesRegexp( - ValueError, - r"^Malformed application 'foo#bar:barney'$", - runner.match, - "foo#bar:barney", - ) - - def test_module_obj(self): - self.assertTupleEqual( - runner.match("foo.bar:fred.barney"), ("foo.bar", "fred.barney") - ) - - -class Test_resolve(unittest.TestCase): - def test_bad_module(self): - self.assertRaises( - ImportError, runner.resolve, "nonexistent", "nonexistent_function" - ) - - def test_nonexistent_function(self): - self.assertRaisesRegexp( - AttributeError, - r"has no attribute 'nonexistent_function'", - runner.resolve, - "os.path", - "nonexistent_function", - ) - - def test_simple_happy_path(self): - from os.path import exists - - self.assertIs(runner.resolve("os.path", "exists"), exists) - - def test_complex_happy_path(self): - # Ensure we can recursively resolve object attributes if necessary. - self.assertEquals(runner.resolve("os.path", "exists.__name__"), "exists") - - -class Test_run(unittest.TestCase): - def match_output(self, argv, code, regex): - argv = ["waitress-serve"] + argv - with capture() as captured: - self.assertEqual(runner.run(argv=argv), code) - self.assertRegexpMatches(captured.getvalue(), regex) - captured.close() - - def test_bad(self): - self.match_output(["--bad-opt"], 1, "^Error: option --bad-opt not recognized") - - def test_help(self): - self.match_output(["--help"], 0, "^Usage:\n\n waitress-serve") - - def test_no_app(self): - self.match_output([], 1, "^Error: Specify one application only") - - def test_multiple_apps_app(self): - self.match_output(["a:a", "b:b"], 1, "^Error: Specify one application only") - - def test_bad_apps_app(self): - self.match_output(["a"], 1, "^Error: Malformed application 'a'") - - def test_bad_app_module(self): - self.match_output(["nonexistent:a"], 1, "^Error: Bad module 'nonexistent'") - - self.match_output( - ["nonexistent:a"], - 1, - ( - r"There was an exception \((ImportError|ModuleNotFoundError)\) " - "importing your module.\n\nIt had these arguments: \n" - "1. No module named '?nonexistent'?" - ), - ) - - def test_cwd_added_to_path(self): - def null_serve(app, **kw): - pass - - sys_path = sys.path - current_dir = os.getcwd() - try: - os.chdir(os.path.dirname(__file__)) - argv = [ - "waitress-serve", - "fixtureapps.runner:app", - ] - self.assertEqual(runner.run(argv=argv, _serve=null_serve), 0) - finally: - sys.path = sys_path - os.chdir(current_dir) - - def test_bad_app_object(self): - self.match_output( - ["waitress.tests.fixtureapps.runner:a"], 1, "^Error: Bad object name 'a'" - ) - - def test_simple_call(self): - import waitress.tests.fixtureapps.runner as _apps - - def check_server(app, **kw): - self.assertIs(app, _apps.app) - self.assertDictEqual(kw, {"port": "80"}) - - argv = [ - "waitress-serve", - "--port=80", - "waitress.tests.fixtureapps.runner:app", - ] - self.assertEqual(runner.run(argv=argv, _serve=check_server), 0) - - def test_returned_app(self): - import waitress.tests.fixtureapps.runner as _apps - - def check_server(app, **kw): - self.assertIs(app, _apps.app) - self.assertDictEqual(kw, {"port": "80"}) - - argv = [ - "waitress-serve", - "--port=80", - "--call", - "waitress.tests.fixtureapps.runner:returns_app", - ] - self.assertEqual(runner.run(argv=argv, _serve=check_server), 0) - - -class Test_helper(unittest.TestCase): - def test_exception_logging(self): - from waitress.runner import show_exception - - regex = ( - r"There was an exception \(ImportError\) importing your module." - r"\n\nIt had these arguments: \n1. My reason" - ) - - with capture() as captured: - try: - raise ImportError("My reason") - except ImportError: - self.assertEqual(show_exception(sys.stderr), None) - self.assertRegexpMatches(captured.getvalue(), regex) - captured.close() - - regex = ( - r"There was an exception \(ImportError\) importing your module." - r"\n\nIt had no arguments." - ) - - with capture() as captured: - try: - raise ImportError - except ImportError: - self.assertEqual(show_exception(sys.stderr), None) - self.assertRegexpMatches(captured.getvalue(), regex) - captured.close() - - -@contextlib.contextmanager -def capture(): - from waitress.compat import NativeIO - - fd = NativeIO() - sys.stdout = fd - sys.stderr = fd - yield fd - sys.stdout = sys.__stdout__ - sys.stderr = sys.__stderr__ diff --git a/waitress/tests/test_server.py b/waitress/tests/test_server.py deleted file mode 100644 index 9134fb8..0000000 --- a/waitress/tests/test_server.py +++ /dev/null @@ -1,533 +0,0 @@ -import errno -import socket -import unittest - -dummy_app = object() - - -class TestWSGIServer(unittest.TestCase): - def _makeOne( - self, - application=dummy_app, - host="127.0.0.1", - port=0, - _dispatcher=None, - adj=None, - map=None, - _start=True, - _sock=None, - _server=None, - ): - from waitress.server import create_server - - self.inst = create_server( - application, - host=host, - port=port, - map=map, - _dispatcher=_dispatcher, - _start=_start, - _sock=_sock, - ) - return self.inst - - def _makeOneWithMap( - self, adj=None, _start=True, host="127.0.0.1", port=0, app=dummy_app - ): - sock = DummySock() - task_dispatcher = DummyTaskDispatcher() - map = {} - return self._makeOne( - app, - host=host, - port=port, - map=map, - _sock=sock, - _dispatcher=task_dispatcher, - _start=_start, - ) - - def _makeOneWithMulti( - self, adj=None, _start=True, app=dummy_app, listen="127.0.0.1:0 127.0.0.1:0" - ): - sock = DummySock() - task_dispatcher = DummyTaskDispatcher() - map = {} - from waitress.server import create_server - - self.inst = create_server( - app, - listen=listen, - map=map, - _dispatcher=task_dispatcher, - _start=_start, - _sock=sock, - ) - return self.inst - - def _makeWithSockets( - self, - application=dummy_app, - _dispatcher=None, - map=None, - _start=True, - _sock=None, - _server=None, - sockets=None, - ): - from waitress.server import create_server - - _sockets = [] - if sockets is not None: - _sockets = sockets - self.inst = create_server( - application, - map=map, - _dispatcher=_dispatcher, - _start=_start, - _sock=_sock, - sockets=_sockets, - ) - return self.inst - - def tearDown(self): - if self.inst is not None: - self.inst.close() - - def test_ctor_app_is_None(self): - self.inst = None - self.assertRaises(ValueError, self._makeOneWithMap, app=None) - - def test_ctor_start_true(self): - inst = self._makeOneWithMap(_start=True) - self.assertEqual(inst.accepting, True) - self.assertEqual(inst.socket.listened, 1024) - - def test_ctor_makes_dispatcher(self): - inst = self._makeOne(_start=False, map={}) - self.assertEqual( - inst.task_dispatcher.__class__.__name__, "ThreadedTaskDispatcher" - ) - - def test_ctor_start_false(self): - inst = self._makeOneWithMap(_start=False) - self.assertEqual(inst.accepting, False) - - def test_get_server_name_empty(self): - inst = self._makeOneWithMap(_start=False) - self.assertRaises(ValueError, inst.get_server_name, "") - - def test_get_server_name_with_ip(self): - inst = self._makeOneWithMap(_start=False) - result = inst.get_server_name("127.0.0.1") - self.assertTrue(result) - - def test_get_server_name_with_hostname(self): - inst = self._makeOneWithMap(_start=False) - result = inst.get_server_name("fred.flintstone.com") - self.assertEqual(result, "fred.flintstone.com") - - def test_get_server_name_0000(self): - inst = self._makeOneWithMap(_start=False) - result = inst.get_server_name("0.0.0.0") - self.assertTrue(len(result) != 0) - - def test_get_server_name_double_colon(self): - inst = self._makeOneWithMap(_start=False) - result = inst.get_server_name("::") - self.assertTrue(len(result) != 0) - - def test_get_server_name_ipv6(self): - inst = self._makeOneWithMap(_start=False) - result = inst.get_server_name("2001:DB8::ffff") - self.assertEqual("[2001:DB8::ffff]", result) - - def test_get_server_multi(self): - inst = self._makeOneWithMulti() - self.assertEqual(inst.__class__.__name__, "MultiSocketServer") - - def test_run(self): - inst = self._makeOneWithMap(_start=False) - inst.asyncore = DummyAsyncore() - inst.task_dispatcher = DummyTaskDispatcher() - inst.run() - self.assertTrue(inst.task_dispatcher.was_shutdown) - - def test_run_base_server(self): - inst = self._makeOneWithMulti(_start=False) - inst.asyncore = DummyAsyncore() - inst.task_dispatcher = DummyTaskDispatcher() - inst.run() - self.assertTrue(inst.task_dispatcher.was_shutdown) - - def test_pull_trigger(self): - inst = self._makeOneWithMap(_start=False) - inst.trigger.close() - inst.trigger = DummyTrigger() - inst.pull_trigger() - self.assertEqual(inst.trigger.pulled, True) - - def test_add_task(self): - task = DummyTask() - inst = self._makeOneWithMap() - inst.add_task(task) - self.assertEqual(inst.task_dispatcher.tasks, [task]) - self.assertFalse(task.serviced) - - def test_readable_not_accepting(self): - inst = self._makeOneWithMap() - inst.accepting = False - self.assertFalse(inst.readable()) - - def test_readable_maplen_gt_connection_limit(self): - inst = self._makeOneWithMap() - inst.accepting = True - inst.adj = DummyAdj - inst._map = {"a": 1, "b": 2} - self.assertFalse(inst.readable()) - - def test_readable_maplen_lt_connection_limit(self): - inst = self._makeOneWithMap() - inst.accepting = True - inst.adj = DummyAdj - inst._map = {} - self.assertTrue(inst.readable()) - - def test_readable_maintenance_false(self): - import time - - inst = self._makeOneWithMap() - then = time.time() + 1000 - inst.next_channel_cleanup = then - L = [] - inst.maintenance = lambda t: L.append(t) - inst.readable() - self.assertEqual(L, []) - self.assertEqual(inst.next_channel_cleanup, then) - - def test_readable_maintenance_true(self): - inst = self._makeOneWithMap() - inst.next_channel_cleanup = 0 - L = [] - inst.maintenance = lambda t: L.append(t) - inst.readable() - self.assertEqual(len(L), 1) - self.assertNotEqual(inst.next_channel_cleanup, 0) - - def test_writable(self): - inst = self._makeOneWithMap() - self.assertFalse(inst.writable()) - - def test_handle_read(self): - inst = self._makeOneWithMap() - self.assertEqual(inst.handle_read(), None) - - def test_handle_connect(self): - inst = self._makeOneWithMap() - self.assertEqual(inst.handle_connect(), None) - - def test_handle_accept_wouldblock_socket_error(self): - inst = self._makeOneWithMap() - ewouldblock = socket.error(errno.EWOULDBLOCK) - inst.socket = DummySock(toraise=ewouldblock) - inst.handle_accept() - self.assertEqual(inst.socket.accepted, False) - - def test_handle_accept_other_socket_error(self): - inst = self._makeOneWithMap() - eaborted = socket.error(errno.ECONNABORTED) - inst.socket = DummySock(toraise=eaborted) - inst.adj = DummyAdj - - def foo(): - raise socket.error - - inst.accept = foo - inst.logger = DummyLogger() - inst.handle_accept() - self.assertEqual(inst.socket.accepted, False) - self.assertEqual(len(inst.logger.logged), 1) - - def test_handle_accept_noerror(self): - inst = self._makeOneWithMap() - innersock = DummySock() - inst.socket = DummySock(acceptresult=(innersock, None)) - inst.adj = DummyAdj - L = [] - inst.channel_class = lambda *arg, **kw: L.append(arg) - inst.handle_accept() - self.assertEqual(inst.socket.accepted, True) - self.assertEqual(innersock.opts, [("level", "optname", "value")]) - self.assertEqual(L, [(inst, innersock, None, inst.adj)]) - - def test_maintenance(self): - inst = self._makeOneWithMap() - - class DummyChannel(object): - requests = [] - - zombie = DummyChannel() - zombie.last_activity = 0 - zombie.running_tasks = False - inst.active_channels[100] = zombie - inst.maintenance(10000) - self.assertEqual(zombie.will_close, True) - - def test_backward_compatibility(self): - from waitress.server import WSGIServer, TcpWSGIServer - from waitress.adjustments import Adjustments - - self.assertTrue(WSGIServer is TcpWSGIServer) - self.inst = WSGIServer(None, _start=False, port=1234) - # Ensure the adjustment was actually applied. - self.assertNotEqual(Adjustments.port, 1234) - self.assertEqual(self.inst.adj.port, 1234) - - def test_create_with_one_tcp_socket(self): - from waitress.server import TcpWSGIServer - - sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] - sockets[0].bind(("127.0.0.1", 0)) - inst = self._makeWithSockets(_start=False, sockets=sockets) - self.assertTrue(isinstance(inst, TcpWSGIServer)) - - def test_create_with_multiple_tcp_sockets(self): - from waitress.server import MultiSocketServer - - sockets = [ - socket.socket(socket.AF_INET, socket.SOCK_STREAM), - socket.socket(socket.AF_INET, socket.SOCK_STREAM), - ] - sockets[0].bind(("127.0.0.1", 0)) - sockets[1].bind(("127.0.0.1", 0)) - inst = self._makeWithSockets(_start=False, sockets=sockets) - self.assertTrue(isinstance(inst, MultiSocketServer)) - self.assertEqual(len(inst.effective_listen), 2) - - def test_create_with_one_socket_should_not_bind_socket(self): - innersock = DummySock() - sockets = [DummySock(acceptresult=(innersock, None))] - sockets[0].bind(("127.0.0.1", 80)) - sockets[0].bind_called = False - inst = self._makeWithSockets(_start=False, sockets=sockets) - self.assertEqual(inst.socket.bound, ("127.0.0.1", 80)) - self.assertFalse(inst.socket.bind_called) - - def test_create_with_one_socket_handle_accept_noerror(self): - innersock = DummySock() - sockets = [DummySock(acceptresult=(innersock, None))] - sockets[0].bind(("127.0.0.1", 80)) - inst = self._makeWithSockets(sockets=sockets) - L = [] - inst.channel_class = lambda *arg, **kw: L.append(arg) - inst.adj = DummyAdj - inst.handle_accept() - self.assertEqual(sockets[0].accepted, True) - self.assertEqual(innersock.opts, [("level", "optname", "value")]) - self.assertEqual(L, [(inst, innersock, None, inst.adj)]) - - -if hasattr(socket, "AF_UNIX"): - - class TestUnixWSGIServer(unittest.TestCase): - unix_socket = "/tmp/waitress.test.sock" - - def _makeOne(self, _start=True, _sock=None): - from waitress.server import create_server - - self.inst = create_server( - dummy_app, - map={}, - _start=_start, - _sock=_sock, - _dispatcher=DummyTaskDispatcher(), - unix_socket=self.unix_socket, - unix_socket_perms="600", - ) - return self.inst - - def _makeWithSockets( - self, - application=dummy_app, - _dispatcher=None, - map=None, - _start=True, - _sock=None, - _server=None, - sockets=None, - ): - from waitress.server import create_server - - _sockets = [] - if sockets is not None: - _sockets = sockets - self.inst = create_server( - application, - map=map, - _dispatcher=_dispatcher, - _start=_start, - _sock=_sock, - sockets=_sockets, - ) - return self.inst - - def tearDown(self): - self.inst.close() - - def _makeDummy(self, *args, **kwargs): - sock = DummySock(*args, **kwargs) - sock.family = socket.AF_UNIX - return sock - - def test_unix(self): - inst = self._makeOne(_start=False) - self.assertEqual(inst.socket.family, socket.AF_UNIX) - self.assertEqual(inst.socket.getsockname(), self.unix_socket) - - def test_handle_accept(self): - # Working on the assumption that we only have to test the happy path - # for Unix domain sockets as the other paths should've been covered - # by inet sockets. - client = self._makeDummy() - listen = self._makeDummy(acceptresult=(client, None)) - inst = self._makeOne(_sock=listen) - self.assertEqual(inst.accepting, True) - self.assertEqual(inst.socket.listened, 1024) - L = [] - inst.channel_class = lambda *arg, **kw: L.append(arg) - inst.handle_accept() - self.assertEqual(inst.socket.accepted, True) - self.assertEqual(client.opts, []) - self.assertEqual(L, [(inst, client, ("localhost", None), inst.adj)]) - - def test_creates_new_sockinfo(self): - from waitress.server import UnixWSGIServer - - self.inst = UnixWSGIServer( - dummy_app, unix_socket=self.unix_socket, unix_socket_perms="600" - ) - - self.assertEqual(self.inst.sockinfo[0], socket.AF_UNIX) - - def test_create_with_unix_socket(self): - from waitress.server import ( - MultiSocketServer, - BaseWSGIServer, - TcpWSGIServer, - UnixWSGIServer, - ) - - sockets = [ - socket.socket(socket.AF_UNIX, socket.SOCK_STREAM), - socket.socket(socket.AF_UNIX, socket.SOCK_STREAM), - ] - inst = self._makeWithSockets(sockets=sockets, _start=False) - self.assertTrue(isinstance(inst, MultiSocketServer)) - server = list( - filter(lambda s: isinstance(s, BaseWSGIServer), inst.map.values()) - ) - self.assertTrue(isinstance(server[0], UnixWSGIServer)) - self.assertTrue(isinstance(server[1], UnixWSGIServer)) - - -class DummySock(socket.socket): - accepted = False - blocking = False - family = socket.AF_INET - type = socket.SOCK_STREAM - proto = 0 - - def __init__(self, toraise=None, acceptresult=(None, None)): - self.toraise = toraise - self.acceptresult = acceptresult - self.bound = None - self.opts = [] - self.bind_called = False - - def bind(self, addr): - self.bind_called = True - self.bound = addr - - def accept(self): - if self.toraise: - raise self.toraise - self.accepted = True - return self.acceptresult - - def setblocking(self, x): - self.blocking = True - - def fileno(self): - return 10 - - def getpeername(self): - return "127.0.0.1" - - def setsockopt(self, *arg): - self.opts.append(arg) - - def getsockopt(self, *arg): - return 1 - - def listen(self, num): - self.listened = num - - def getsockname(self): - return self.bound - - def close(self): - pass - - -class DummyTaskDispatcher(object): - def __init__(self): - self.tasks = [] - - def add_task(self, task): - self.tasks.append(task) - - def shutdown(self): - self.was_shutdown = True - - -class DummyTask(object): - serviced = False - start_response_called = False - wrote_header = False - status = "200 OK" - - def __init__(self): - self.response_headers = {} - self.written = "" - - def service(self): # pragma: no cover - self.serviced = True - - -class DummyAdj: - connection_limit = 1 - log_socket_errors = True - socket_options = [("level", "optname", "value")] - cleanup_interval = 900 - channel_timeout = 300 - - -class DummyAsyncore(object): - def loop(self, timeout=30.0, use_poll=False, map=None, count=None): - raise SystemExit - - -class DummyTrigger(object): - def pull_trigger(self): - self.pulled = True - - def close(self): - pass - - -class DummyLogger(object): - def __init__(self): - self.logged = [] - - def warning(self, msg, **kw): - self.logged.append(msg) diff --git a/waitress/tests/test_task.py b/waitress/tests/test_task.py deleted file mode 100644 index 1a86245..0000000 --- a/waitress/tests/test_task.py +++ /dev/null @@ -1,1001 +0,0 @@ -import unittest -import io - - -class TestThreadedTaskDispatcher(unittest.TestCase): - def _makeOne(self): - from waitress.task import ThreadedTaskDispatcher - - return ThreadedTaskDispatcher() - - def test_handler_thread_task_raises(self): - inst = self._makeOne() - inst.threads.add(0) - inst.logger = DummyLogger() - - class BadDummyTask(DummyTask): - def service(self): - super(BadDummyTask, self).service() - inst.stop_count += 1 - raise Exception - - task = BadDummyTask() - inst.logger = DummyLogger() - inst.queue.append(task) - inst.active_count += 1 - inst.handler_thread(0) - self.assertEqual(inst.stop_count, 0) - self.assertEqual(inst.active_count, 0) - self.assertEqual(inst.threads, set()) - self.assertEqual(len(inst.logger.logged), 1) - - def test_set_thread_count_increase(self): - inst = self._makeOne() - L = [] - inst.start_new_thread = lambda *x: L.append(x) - inst.set_thread_count(1) - self.assertEqual(L, [(inst.handler_thread, (0,))]) - - def test_set_thread_count_increase_with_existing(self): - inst = self._makeOne() - L = [] - inst.threads = {0} - inst.start_new_thread = lambda *x: L.append(x) - inst.set_thread_count(2) - self.assertEqual(L, [(inst.handler_thread, (1,))]) - - def test_set_thread_count_decrease(self): - inst = self._makeOne() - inst.threads = {0, 1} - inst.set_thread_count(1) - self.assertEqual(inst.stop_count, 1) - - def test_set_thread_count_same(self): - inst = self._makeOne() - L = [] - inst.start_new_thread = lambda *x: L.append(x) - inst.threads = {0} - inst.set_thread_count(1) - self.assertEqual(L, []) - - def test_add_task_with_idle_threads(self): - task = DummyTask() - inst = self._makeOne() - inst.threads.add(0) - inst.queue_logger = DummyLogger() - inst.add_task(task) - self.assertEqual(len(inst.queue), 1) - self.assertEqual(len(inst.queue_logger.logged), 0) - - def test_add_task_with_all_busy_threads(self): - task = DummyTask() - inst = self._makeOne() - inst.queue_logger = DummyLogger() - inst.add_task(task) - self.assertEqual(len(inst.queue_logger.logged), 1) - inst.add_task(task) - self.assertEqual(len(inst.queue_logger.logged), 2) - - def test_shutdown_one_thread(self): - inst = self._makeOne() - inst.threads.add(0) - inst.logger = DummyLogger() - task = DummyTask() - inst.queue.append(task) - self.assertEqual(inst.shutdown(timeout=0.01), True) - self.assertEqual( - inst.logger.logged, - ["1 thread(s) still running", "Canceling 1 pending task(s)",], - ) - self.assertEqual(task.cancelled, True) - - def test_shutdown_no_threads(self): - inst = self._makeOne() - self.assertEqual(inst.shutdown(timeout=0.01), True) - - def test_shutdown_no_cancel_pending(self): - inst = self._makeOne() - self.assertEqual(inst.shutdown(cancel_pending=False, timeout=0.01), False) - - -class TestTask(unittest.TestCase): - def _makeOne(self, channel=None, request=None): - if channel is None: - channel = DummyChannel() - if request is None: - request = DummyParser() - from waitress.task import Task - - return Task(channel, request) - - def test_ctor_version_not_in_known(self): - request = DummyParser() - request.version = "8.4" - inst = self._makeOne(request=request) - self.assertEqual(inst.version, "1.0") - - def test_build_response_header_bad_http_version(self): - inst = self._makeOne() - inst.request = DummyParser() - inst.version = "8.4" - self.assertRaises(AssertionError, inst.build_response_header) - - def test_build_response_header_v10_keepalive_no_content_length(self): - inst = self._makeOne() - inst.request = DummyParser() - inst.request.headers["CONNECTION"] = "keep-alive" - inst.version = "1.0" - result = inst.build_response_header() - lines = filter_lines(result) - self.assertEqual(len(lines), 4) - self.assertEqual(lines[0], b"HTTP/1.0 200 OK") - self.assertEqual(lines[1], b"Connection: close") - self.assertTrue(lines[2].startswith(b"Date:")) - self.assertEqual(lines[3], b"Server: waitress") - self.assertEqual(inst.close_on_finish, True) - self.assertTrue(("Connection", "close") in inst.response_headers) - - def test_build_response_header_v10_keepalive_with_content_length(self): - inst = self._makeOne() - inst.request = DummyParser() - inst.request.headers["CONNECTION"] = "keep-alive" - inst.response_headers = [("Content-Length", "10")] - inst.version = "1.0" - inst.content_length = 0 - result = inst.build_response_header() - lines = filter_lines(result) - self.assertEqual(len(lines), 5) - self.assertEqual(lines[0], b"HTTP/1.0 200 OK") - self.assertEqual(lines[1], b"Connection: Keep-Alive") - self.assertEqual(lines[2], b"Content-Length: 10") - self.assertTrue(lines[3].startswith(b"Date:")) - self.assertEqual(lines[4], b"Server: waitress") - self.assertEqual(inst.close_on_finish, False) - - def test_build_response_header_v11_connection_closed_by_client(self): - inst = self._makeOne() - inst.request = DummyParser() - inst.version = "1.1" - inst.request.headers["CONNECTION"] = "close" - result = inst.build_response_header() - lines = filter_lines(result) - self.assertEqual(len(lines), 5) - self.assertEqual(lines[0], b"HTTP/1.1 200 OK") - self.assertEqual(lines[1], b"Connection: close") - self.assertTrue(lines[2].startswith(b"Date:")) - self.assertEqual(lines[3], b"Server: waitress") - self.assertEqual(lines[4], b"Transfer-Encoding: chunked") - self.assertTrue(("Connection", "close") in inst.response_headers) - self.assertEqual(inst.close_on_finish, True) - - def test_build_response_header_v11_connection_keepalive_by_client(self): - inst = self._makeOne() - inst.request = DummyParser() - inst.request.headers["CONNECTION"] = "keep-alive" - inst.version = "1.1" - result = inst.build_response_header() - lines = filter_lines(result) - self.assertEqual(len(lines), 5) - self.assertEqual(lines[0], b"HTTP/1.1 200 OK") - self.assertEqual(lines[1], b"Connection: close") - self.assertTrue(lines[2].startswith(b"Date:")) - self.assertEqual(lines[3], b"Server: waitress") - self.assertEqual(lines[4], b"Transfer-Encoding: chunked") - self.assertTrue(("Connection", "close") in inst.response_headers) - self.assertEqual(inst.close_on_finish, True) - - def test_build_response_header_v11_200_no_content_length(self): - inst = self._makeOne() - inst.request = DummyParser() - inst.version = "1.1" - result = inst.build_response_header() - lines = filter_lines(result) - self.assertEqual(len(lines), 5) - self.assertEqual(lines[0], b"HTTP/1.1 200 OK") - self.assertEqual(lines[1], b"Connection: close") - self.assertTrue(lines[2].startswith(b"Date:")) - self.assertEqual(lines[3], b"Server: waitress") - self.assertEqual(lines[4], b"Transfer-Encoding: chunked") - self.assertEqual(inst.close_on_finish, True) - self.assertTrue(("Connection", "close") in inst.response_headers) - - def test_build_response_header_v11_204_no_content_length_or_transfer_encoding(self): - # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length - # for any response with a status code of 1xx or 204. - inst = self._makeOne() - inst.request = DummyParser() - inst.version = "1.1" - inst.status = "204 No Content" - result = inst.build_response_header() - lines = filter_lines(result) - self.assertEqual(len(lines), 4) - self.assertEqual(lines[0], b"HTTP/1.1 204 No Content") - self.assertEqual(lines[1], b"Connection: close") - self.assertTrue(lines[2].startswith(b"Date:")) - self.assertEqual(lines[3], b"Server: waitress") - self.assertEqual(inst.close_on_finish, True) - self.assertTrue(("Connection", "close") in inst.response_headers) - - def test_build_response_header_v11_1xx_no_content_length_or_transfer_encoding(self): - # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length - # for any response with a status code of 1xx or 204. - inst = self._makeOne() - inst.request = DummyParser() - inst.version = "1.1" - inst.status = "100 Continue" - result = inst.build_response_header() - lines = filter_lines(result) - self.assertEqual(len(lines), 4) - self.assertEqual(lines[0], b"HTTP/1.1 100 Continue") - self.assertEqual(lines[1], b"Connection: close") - self.assertTrue(lines[2].startswith(b"Date:")) - self.assertEqual(lines[3], b"Server: waitress") - self.assertEqual(inst.close_on_finish, True) - self.assertTrue(("Connection", "close") in inst.response_headers) - - def test_build_response_header_v11_304_no_content_length_or_transfer_encoding(self): - # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length - # for any response with a status code of 1xx, 204 or 304. - inst = self._makeOne() - inst.request = DummyParser() - inst.version = "1.1" - inst.status = "304 Not Modified" - result = inst.build_response_header() - lines = filter_lines(result) - self.assertEqual(len(lines), 4) - self.assertEqual(lines[0], b"HTTP/1.1 304 Not Modified") - self.assertEqual(lines[1], b"Connection: close") - self.assertTrue(lines[2].startswith(b"Date:")) - self.assertEqual(lines[3], b"Server: waitress") - self.assertEqual(inst.close_on_finish, True) - self.assertTrue(("Connection", "close") in inst.response_headers) - - def test_build_response_header_via_added(self): - inst = self._makeOne() - inst.request = DummyParser() - inst.version = "1.0" - inst.response_headers = [("Server", "abc")] - result = inst.build_response_header() - lines = filter_lines(result) - self.assertEqual(len(lines), 5) - self.assertEqual(lines[0], b"HTTP/1.0 200 OK") - self.assertEqual(lines[1], b"Connection: close") - self.assertTrue(lines[2].startswith(b"Date:")) - self.assertEqual(lines[3], b"Server: abc") - self.assertEqual(lines[4], b"Via: waitress") - - def test_build_response_header_date_exists(self): - inst = self._makeOne() - inst.request = DummyParser() - inst.version = "1.0" - inst.response_headers = [("Date", "date")] - result = inst.build_response_header() - lines = filter_lines(result) - self.assertEqual(len(lines), 4) - self.assertEqual(lines[0], b"HTTP/1.0 200 OK") - self.assertEqual(lines[1], b"Connection: close") - self.assertTrue(lines[2].startswith(b"Date:")) - self.assertEqual(lines[3], b"Server: waitress") - - def test_build_response_header_preexisting_content_length(self): - inst = self._makeOne() - inst.request = DummyParser() - inst.version = "1.1" - inst.content_length = 100 - result = inst.build_response_header() - lines = filter_lines(result) - self.assertEqual(len(lines), 4) - self.assertEqual(lines[0], b"HTTP/1.1 200 OK") - self.assertEqual(lines[1], b"Content-Length: 100") - self.assertTrue(lines[2].startswith(b"Date:")) - self.assertEqual(lines[3], b"Server: waitress") - - def test_remove_content_length_header(self): - inst = self._makeOne() - inst.response_headers = [("Content-Length", "70")] - inst.remove_content_length_header() - self.assertEqual(inst.response_headers, []) - - def test_remove_content_length_header_with_other(self): - inst = self._makeOne() - inst.response_headers = [ - ("Content-Length", "70"), - ("Content-Type", "text/html"), - ] - inst.remove_content_length_header() - self.assertEqual(inst.response_headers, [("Content-Type", "text/html")]) - - def test_start(self): - inst = self._makeOne() - inst.start() - self.assertTrue(inst.start_time) - - def test_finish_didnt_write_header(self): - inst = self._makeOne() - inst.wrote_header = False - inst.complete = True - inst.finish() - self.assertTrue(inst.channel.written) - - def test_finish_wrote_header(self): - inst = self._makeOne() - inst.wrote_header = True - inst.finish() - self.assertFalse(inst.channel.written) - - def test_finish_chunked_response(self): - inst = self._makeOne() - inst.wrote_header = True - inst.chunked_response = True - inst.finish() - self.assertEqual(inst.channel.written, b"0\r\n\r\n") - - def test_write_wrote_header(self): - inst = self._makeOne() - inst.wrote_header = True - inst.complete = True - inst.content_length = 3 - inst.write(b"abc") - self.assertEqual(inst.channel.written, b"abc") - - def test_write_header_not_written(self): - inst = self._makeOne() - inst.wrote_header = False - inst.complete = True - inst.write(b"abc") - self.assertTrue(inst.channel.written) - self.assertEqual(inst.wrote_header, True) - - def test_write_start_response_uncalled(self): - inst = self._makeOne() - self.assertRaises(RuntimeError, inst.write, b"") - - def test_write_chunked_response(self): - inst = self._makeOne() - inst.wrote_header = True - inst.chunked_response = True - inst.complete = True - inst.write(b"abc") - self.assertEqual(inst.channel.written, b"3\r\nabc\r\n") - - def test_write_preexisting_content_length(self): - inst = self._makeOne() - inst.wrote_header = True - inst.complete = True - inst.content_length = 1 - inst.logger = DummyLogger() - inst.write(b"abc") - self.assertTrue(inst.channel.written) - self.assertEqual(inst.logged_write_excess, True) - self.assertEqual(len(inst.logger.logged), 1) - - -class TestWSGITask(unittest.TestCase): - def _makeOne(self, channel=None, request=None): - if channel is None: - channel = DummyChannel() - if request is None: - request = DummyParser() - from waitress.task import WSGITask - - return WSGITask(channel, request) - - def test_service(self): - inst = self._makeOne() - - def execute(): - inst.executed = True - - inst.execute = execute - inst.complete = True - inst.service() - self.assertTrue(inst.start_time) - self.assertTrue(inst.close_on_finish) - self.assertTrue(inst.channel.written) - self.assertEqual(inst.executed, True) - - def test_service_server_raises_socket_error(self): - import socket - - inst = self._makeOne() - - def execute(): - raise socket.error - - inst.execute = execute - self.assertRaises(socket.error, inst.service) - self.assertTrue(inst.start_time) - self.assertTrue(inst.close_on_finish) - self.assertFalse(inst.channel.written) - - def test_execute_app_calls_start_response_twice_wo_exc_info(self): - def app(environ, start_response): - start_response("200 OK", []) - start_response("200 OK", []) - - inst = self._makeOne() - inst.channel.server.application = app - self.assertRaises(AssertionError, inst.execute) - - def test_execute_app_calls_start_response_w_exc_info_complete(self): - def app(environ, start_response): - start_response("200 OK", [], [ValueError, ValueError(), None]) - return [b"a"] - - inst = self._makeOne() - inst.complete = True - inst.channel.server.application = app - inst.execute() - self.assertTrue(inst.complete) - self.assertEqual(inst.status, "200 OK") - self.assertTrue(inst.channel.written) - - def test_execute_app_calls_start_response_w_excinf_headers_unwritten(self): - def app(environ, start_response): - start_response("200 OK", [], [ValueError, None, None]) - return [b"a"] - - inst = self._makeOne() - inst.wrote_header = False - inst.channel.server.application = app - inst.response_headers = [("a", "b")] - inst.execute() - self.assertTrue(inst.complete) - self.assertEqual(inst.status, "200 OK") - self.assertTrue(inst.channel.written) - self.assertFalse(("a", "b") in inst.response_headers) - - def test_execute_app_calls_start_response_w_excinf_headers_written(self): - def app(environ, start_response): - start_response("200 OK", [], [ValueError, ValueError(), None]) - - inst = self._makeOne() - inst.complete = True - inst.wrote_header = True - inst.channel.server.application = app - self.assertRaises(ValueError, inst.execute) - - def test_execute_bad_header_key(self): - def app(environ, start_response): - start_response("200 OK", [(None, "a")]) - - inst = self._makeOne() - inst.channel.server.application = app - self.assertRaises(AssertionError, inst.execute) - - def test_execute_bad_header_value(self): - def app(environ, start_response): - start_response("200 OK", [("a", None)]) - - inst = self._makeOne() - inst.channel.server.application = app - self.assertRaises(AssertionError, inst.execute) - - def test_execute_hopbyhop_header(self): - def app(environ, start_response): - start_response("200 OK", [("Connection", "close")]) - - inst = self._makeOne() - inst.channel.server.application = app - self.assertRaises(AssertionError, inst.execute) - - def test_execute_bad_header_value_control_characters(self): - def app(environ, start_response): - start_response("200 OK", [("a", "\n")]) - - inst = self._makeOne() - inst.channel.server.application = app - self.assertRaises(ValueError, inst.execute) - - def test_execute_bad_header_name_control_characters(self): - def app(environ, start_response): - start_response("200 OK", [("a\r", "value")]) - - inst = self._makeOne() - inst.channel.server.application = app - self.assertRaises(ValueError, inst.execute) - - def test_execute_bad_status_control_characters(self): - def app(environ, start_response): - start_response("200 OK\r", []) - - inst = self._makeOne() - inst.channel.server.application = app - self.assertRaises(ValueError, inst.execute) - - def test_preserve_header_value_order(self): - def app(environ, start_response): - write = start_response("200 OK", [("C", "b"), ("A", "b"), ("A", "a")]) - write(b"abc") - return [] - - inst = self._makeOne() - inst.channel.server.application = app - inst.execute() - self.assertTrue(b"A: b\r\nA: a\r\nC: b\r\n" in inst.channel.written) - - def test_execute_bad_status_value(self): - def app(environ, start_response): - start_response(None, []) - - inst = self._makeOne() - inst.channel.server.application = app - self.assertRaises(AssertionError, inst.execute) - - def test_execute_with_content_length_header(self): - def app(environ, start_response): - start_response("200 OK", [("Content-Length", "1")]) - return [b"a"] - - inst = self._makeOne() - inst.channel.server.application = app - inst.execute() - self.assertEqual(inst.content_length, 1) - - def test_execute_app_calls_write(self): - def app(environ, start_response): - write = start_response("200 OK", [("Content-Length", "3")]) - write(b"abc") - return [] - - inst = self._makeOne() - inst.channel.server.application = app - inst.execute() - self.assertEqual(inst.channel.written[-3:], b"abc") - - def test_execute_app_returns_len1_chunk_without_cl(self): - def app(environ, start_response): - start_response("200 OK", []) - return [b"abc"] - - inst = self._makeOne() - inst.channel.server.application = app - inst.execute() - self.assertEqual(inst.content_length, 3) - - def test_execute_app_returns_empty_chunk_as_first(self): - def app(environ, start_response): - start_response("200 OK", []) - return ["", b"abc"] - - inst = self._makeOne() - inst.channel.server.application = app - inst.execute() - self.assertEqual(inst.content_length, None) - - def test_execute_app_returns_too_many_bytes(self): - def app(environ, start_response): - start_response("200 OK", [("Content-Length", "1")]) - return [b"abc"] - - inst = self._makeOne() - inst.channel.server.application = app - inst.logger = DummyLogger() - inst.execute() - self.assertEqual(inst.close_on_finish, True) - self.assertEqual(len(inst.logger.logged), 1) - - def test_execute_app_returns_too_few_bytes(self): - def app(environ, start_response): - start_response("200 OK", [("Content-Length", "3")]) - return [b"a"] - - inst = self._makeOne() - inst.channel.server.application = app - inst.logger = DummyLogger() - inst.execute() - self.assertEqual(inst.close_on_finish, True) - self.assertEqual(len(inst.logger.logged), 1) - - def test_execute_app_do_not_warn_on_head(self): - def app(environ, start_response): - start_response("200 OK", [("Content-Length", "3")]) - return [b""] - - inst = self._makeOne() - inst.request.command = "HEAD" - inst.channel.server.application = app - inst.logger = DummyLogger() - inst.execute() - self.assertEqual(inst.close_on_finish, True) - self.assertEqual(len(inst.logger.logged), 0) - - def test_execute_app_without_body_204_logged(self): - def app(environ, start_response): - start_response("204 No Content", [("Content-Length", "3")]) - return [b"abc"] - - inst = self._makeOne() - inst.channel.server.application = app - inst.logger = DummyLogger() - inst.execute() - self.assertEqual(inst.close_on_finish, True) - self.assertNotIn(b"abc", inst.channel.written) - self.assertNotIn(b"Content-Length", inst.channel.written) - self.assertNotIn(b"Transfer-Encoding", inst.channel.written) - self.assertEqual(len(inst.logger.logged), 1) - - def test_execute_app_without_body_304_logged(self): - def app(environ, start_response): - start_response("304 Not Modified", [("Content-Length", "3")]) - return [b"abc"] - - inst = self._makeOne() - inst.channel.server.application = app - inst.logger = DummyLogger() - inst.execute() - self.assertEqual(inst.close_on_finish, True) - self.assertNotIn(b"abc", inst.channel.written) - self.assertNotIn(b"Content-Length", inst.channel.written) - self.assertNotIn(b"Transfer-Encoding", inst.channel.written) - self.assertEqual(len(inst.logger.logged), 1) - - def test_execute_app_returns_closeable(self): - class closeable(list): - def close(self): - self.closed = True - - foo = closeable([b"abc"]) - - def app(environ, start_response): - start_response("200 OK", [("Content-Length", "3")]) - return foo - - inst = self._makeOne() - inst.channel.server.application = app - inst.execute() - self.assertEqual(foo.closed, True) - - def test_execute_app_returns_filewrapper_prepare_returns_True(self): - from waitress.buffers import ReadOnlyFileBasedBuffer - - f = io.BytesIO(b"abc") - app_iter = ReadOnlyFileBasedBuffer(f, 8192) - - def app(environ, start_response): - start_response("200 OK", [("Content-Length", "3")]) - return app_iter - - inst = self._makeOne() - inst.channel.server.application = app - inst.execute() - self.assertTrue(inst.channel.written) # header - self.assertEqual(inst.channel.otherdata, [app_iter]) - - def test_execute_app_returns_filewrapper_prepare_returns_True_nocl(self): - from waitress.buffers import ReadOnlyFileBasedBuffer - - f = io.BytesIO(b"abc") - app_iter = ReadOnlyFileBasedBuffer(f, 8192) - - def app(environ, start_response): - start_response("200 OK", []) - return app_iter - - inst = self._makeOne() - inst.channel.server.application = app - inst.execute() - self.assertTrue(inst.channel.written) # header - self.assertEqual(inst.channel.otherdata, [app_iter]) - self.assertEqual(inst.content_length, 3) - - def test_execute_app_returns_filewrapper_prepare_returns_True_badcl(self): - from waitress.buffers import ReadOnlyFileBasedBuffer - - f = io.BytesIO(b"abc") - app_iter = ReadOnlyFileBasedBuffer(f, 8192) - - def app(environ, start_response): - start_response("200 OK", []) - return app_iter - - inst = self._makeOne() - inst.channel.server.application = app - inst.content_length = 10 - inst.response_headers = [("Content-Length", "10")] - inst.execute() - self.assertTrue(inst.channel.written) # header - self.assertEqual(inst.channel.otherdata, [app_iter]) - self.assertEqual(inst.content_length, 3) - self.assertEqual(dict(inst.response_headers)["Content-Length"], "3") - - def test_get_environment_already_cached(self): - inst = self._makeOne() - inst.environ = object() - self.assertEqual(inst.get_environment(), inst.environ) - - def test_get_environment_path_startswith_more_than_one_slash(self): - inst = self._makeOne() - request = DummyParser() - request.path = "///abc" - inst.request = request - environ = inst.get_environment() - self.assertEqual(environ["PATH_INFO"], "/abc") - - def test_get_environment_path_empty(self): - inst = self._makeOne() - request = DummyParser() - request.path = "" - inst.request = request - environ = inst.get_environment() - self.assertEqual(environ["PATH_INFO"], "") - - def test_get_environment_no_query(self): - inst = self._makeOne() - request = DummyParser() - inst.request = request - environ = inst.get_environment() - self.assertEqual(environ["QUERY_STRING"], "") - - def test_get_environment_with_query(self): - inst = self._makeOne() - request = DummyParser() - request.query = "abc" - inst.request = request - environ = inst.get_environment() - self.assertEqual(environ["QUERY_STRING"], "abc") - - def test_get_environ_with_url_prefix_miss(self): - inst = self._makeOne() - inst.channel.server.adj.url_prefix = "/foo" - request = DummyParser() - request.path = "/bar" - inst.request = request - environ = inst.get_environment() - self.assertEqual(environ["PATH_INFO"], "/bar") - self.assertEqual(environ["SCRIPT_NAME"], "/foo") - - def test_get_environ_with_url_prefix_hit(self): - inst = self._makeOne() - inst.channel.server.adj.url_prefix = "/foo" - request = DummyParser() - request.path = "/foo/fuz" - inst.request = request - environ = inst.get_environment() - self.assertEqual(environ["PATH_INFO"], "/fuz") - self.assertEqual(environ["SCRIPT_NAME"], "/foo") - - def test_get_environ_with_url_prefix_empty_path(self): - inst = self._makeOne() - inst.channel.server.adj.url_prefix = "/foo" - request = DummyParser() - request.path = "/foo" - inst.request = request - environ = inst.get_environment() - self.assertEqual(environ["PATH_INFO"], "") - self.assertEqual(environ["SCRIPT_NAME"], "/foo") - - def test_get_environment_values(self): - import sys - - inst = self._makeOne() - request = DummyParser() - request.headers = { - "CONTENT_TYPE": "abc", - "CONTENT_LENGTH": "10", - "X_FOO": "BAR", - "CONNECTION": "close", - } - request.query = "abc" - inst.request = request - environ = inst.get_environment() - - # nail the keys of environ - self.assertEqual( - sorted(environ.keys()), - [ - "CONTENT_LENGTH", - "CONTENT_TYPE", - "HTTP_CONNECTION", - "HTTP_X_FOO", - "PATH_INFO", - "QUERY_STRING", - "REMOTE_ADDR", - "REMOTE_HOST", - "REMOTE_PORT", - "REQUEST_METHOD", - "SCRIPT_NAME", - "SERVER_NAME", - "SERVER_PORT", - "SERVER_PROTOCOL", - "SERVER_SOFTWARE", - "wsgi.errors", - "wsgi.file_wrapper", - "wsgi.input", - "wsgi.input_terminated", - "wsgi.multiprocess", - "wsgi.multithread", - "wsgi.run_once", - "wsgi.url_scheme", - "wsgi.version", - ], - ) - - self.assertEqual(environ["REQUEST_METHOD"], "GET") - self.assertEqual(environ["SERVER_PORT"], "80") - self.assertEqual(environ["SERVER_NAME"], "localhost") - self.assertEqual(environ["SERVER_SOFTWARE"], "waitress") - self.assertEqual(environ["SERVER_PROTOCOL"], "HTTP/1.0") - self.assertEqual(environ["SCRIPT_NAME"], "") - self.assertEqual(environ["HTTP_CONNECTION"], "close") - self.assertEqual(environ["PATH_INFO"], "/") - self.assertEqual(environ["QUERY_STRING"], "abc") - self.assertEqual(environ["REMOTE_ADDR"], "127.0.0.1") - self.assertEqual(environ["REMOTE_HOST"], "127.0.0.1") - self.assertEqual(environ["REMOTE_PORT"], "39830") - self.assertEqual(environ["CONTENT_TYPE"], "abc") - self.assertEqual(environ["CONTENT_LENGTH"], "10") - self.assertEqual(environ["HTTP_X_FOO"], "BAR") - self.assertEqual(environ["wsgi.version"], (1, 0)) - self.assertEqual(environ["wsgi.url_scheme"], "http") - self.assertEqual(environ["wsgi.errors"], sys.stderr) - self.assertEqual(environ["wsgi.multithread"], True) - self.assertEqual(environ["wsgi.multiprocess"], False) - self.assertEqual(environ["wsgi.run_once"], False) - self.assertEqual(environ["wsgi.input"], "stream") - self.assertEqual(environ["wsgi.input_terminated"], True) - self.assertEqual(inst.environ, environ) - - -class TestErrorTask(unittest.TestCase): - def _makeOne(self, channel=None, request=None): - if channel is None: - channel = DummyChannel() - if request is None: - request = DummyParser() - request.error = self._makeDummyError() - from waitress.task import ErrorTask - - return ErrorTask(channel, request) - - def _makeDummyError(self): - from waitress.utilities import Error - - e = Error("body") - e.code = 432 - e.reason = "Too Ugly" - return e - - def test_execute_http_10(self): - inst = self._makeOne() - inst.execute() - lines = filter_lines(inst.channel.written) - self.assertEqual(len(lines), 9) - self.assertEqual(lines[0], b"HTTP/1.0 432 Too Ugly") - self.assertEqual(lines[1], b"Connection: close") - self.assertEqual(lines[2], b"Content-Length: 43") - self.assertEqual(lines[3], b"Content-Type: text/plain") - self.assertTrue(lines[4]) - self.assertEqual(lines[5], b"Server: waitress") - self.assertEqual(lines[6], b"Too Ugly") - self.assertEqual(lines[7], b"body") - self.assertEqual(lines[8], b"(generated by waitress)") - - def test_execute_http_11(self): - inst = self._makeOne() - inst.version = "1.1" - inst.execute() - lines = filter_lines(inst.channel.written) - self.assertEqual(len(lines), 9) - self.assertEqual(lines[0], b"HTTP/1.1 432 Too Ugly") - self.assertEqual(lines[1], b"Connection: close") - self.assertEqual(lines[2], b"Content-Length: 43") - self.assertEqual(lines[3], b"Content-Type: text/plain") - self.assertTrue(lines[4]) - self.assertEqual(lines[5], b"Server: waitress") - self.assertEqual(lines[6], b"Too Ugly") - self.assertEqual(lines[7], b"body") - self.assertEqual(lines[8], b"(generated by waitress)") - - def test_execute_http_11_close(self): - inst = self._makeOne() - inst.version = "1.1" - inst.request.headers["CONNECTION"] = "close" - inst.execute() - lines = filter_lines(inst.channel.written) - self.assertEqual(len(lines), 9) - self.assertEqual(lines[0], b"HTTP/1.1 432 Too Ugly") - self.assertEqual(lines[1], b"Connection: close") - self.assertEqual(lines[2], b"Content-Length: 43") - self.assertEqual(lines[3], b"Content-Type: text/plain") - self.assertTrue(lines[4]) - self.assertEqual(lines[5], b"Server: waitress") - self.assertEqual(lines[6], b"Too Ugly") - self.assertEqual(lines[7], b"body") - self.assertEqual(lines[8], b"(generated by waitress)") - - def test_execute_http_11_keep_forces_close(self): - inst = self._makeOne() - inst.version = "1.1" - inst.request.headers["CONNECTION"] = "keep-alive" - inst.execute() - lines = filter_lines(inst.channel.written) - self.assertEqual(len(lines), 9) - self.assertEqual(lines[0], b"HTTP/1.1 432 Too Ugly") - self.assertEqual(lines[1], b"Connection: close") - self.assertEqual(lines[2], b"Content-Length: 43") - self.assertEqual(lines[3], b"Content-Type: text/plain") - self.assertTrue(lines[4]) - self.assertEqual(lines[5], b"Server: waitress") - self.assertEqual(lines[6], b"Too Ugly") - self.assertEqual(lines[7], b"body") - self.assertEqual(lines[8], b"(generated by waitress)") - - -class DummyTask(object): - serviced = False - cancelled = False - - def service(self): - self.serviced = True - - def cancel(self): - self.cancelled = True - - -class DummyAdj(object): - log_socket_errors = True - ident = "waitress" - host = "127.0.0.1" - port = 80 - url_prefix = "" - - -class DummyServer(object): - server_name = "localhost" - effective_port = 80 - - def __init__(self): - self.adj = DummyAdj() - - -class DummyChannel(object): - closed_when_done = False - adj = DummyAdj() - creation_time = 0 - addr = ("127.0.0.1", 39830) - - def __init__(self, server=None): - if server is None: - server = DummyServer() - self.server = server - self.written = b"" - self.otherdata = [] - - def write_soon(self, data): - if isinstance(data, bytes): - self.written += data - else: - self.otherdata.append(data) - return len(data) - - -class DummyParser(object): - version = "1.0" - command = "GET" - path = "/" - query = "" - url_scheme = "http" - expect_continue = False - headers_finished = False - - def __init__(self): - self.headers = {} - - def get_body_stream(self): - return "stream" - - -def filter_lines(s): - return list(filter(None, s.split(b"\r\n"))) - - -class DummyLogger(object): - def __init__(self): - self.logged = [] - - def warning(self, msg, *args): - self.logged.append(msg % args) - - def exception(self, msg, *args): - self.logged.append(msg % args) diff --git a/waitress/tests/test_trigger.py b/waitress/tests/test_trigger.py deleted file mode 100644 index af740f6..0000000 --- a/waitress/tests/test_trigger.py +++ /dev/null @@ -1,111 +0,0 @@ -import unittest -import os -import sys - -if not sys.platform.startswith("win"): - - class Test_trigger(unittest.TestCase): - def _makeOne(self, map): - from waitress.trigger import trigger - - self.inst = trigger(map) - return self.inst - - def tearDown(self): - self.inst.close() # prevent __del__ warning from file_dispatcher - - def test__close(self): - map = {} - inst = self._makeOne(map) - fd1, fd2 = inst._fds - inst.close() - self.assertRaises(OSError, os.read, fd1, 1) - self.assertRaises(OSError, os.read, fd2, 1) - - def test__physical_pull(self): - map = {} - inst = self._makeOne(map) - inst._physical_pull() - r = os.read(inst._fds[0], 1) - self.assertEqual(r, b"x") - - def test_readable(self): - map = {} - inst = self._makeOne(map) - self.assertEqual(inst.readable(), True) - - def test_writable(self): - map = {} - inst = self._makeOne(map) - self.assertEqual(inst.writable(), False) - - def test_handle_connect(self): - map = {} - inst = self._makeOne(map) - self.assertEqual(inst.handle_connect(), None) - - def test_close(self): - map = {} - inst = self._makeOne(map) - self.assertEqual(inst.close(), None) - self.assertEqual(inst._closed, True) - - def test_handle_close(self): - map = {} - inst = self._makeOne(map) - self.assertEqual(inst.handle_close(), None) - self.assertEqual(inst._closed, True) - - def test_pull_trigger_nothunk(self): - map = {} - inst = self._makeOne(map) - self.assertEqual(inst.pull_trigger(), None) - r = os.read(inst._fds[0], 1) - self.assertEqual(r, b"x") - - def test_pull_trigger_thunk(self): - map = {} - inst = self._makeOne(map) - self.assertEqual(inst.pull_trigger(True), None) - self.assertEqual(len(inst.thunks), 1) - r = os.read(inst._fds[0], 1) - self.assertEqual(r, b"x") - - def test_handle_read_socket_error(self): - map = {} - inst = self._makeOne(map) - result = inst.handle_read() - self.assertEqual(result, None) - - def test_handle_read_no_socket_error(self): - map = {} - inst = self._makeOne(map) - inst.pull_trigger() - result = inst.handle_read() - self.assertEqual(result, None) - - def test_handle_read_thunk(self): - map = {} - inst = self._makeOne(map) - inst.pull_trigger() - L = [] - inst.thunks = [lambda: L.append(True)] - result = inst.handle_read() - self.assertEqual(result, None) - self.assertEqual(L, [True]) - self.assertEqual(inst.thunks, []) - - def test_handle_read_thunk_error(self): - map = {} - inst = self._makeOne(map) - - def errorthunk(): - raise ValueError - - inst.pull_trigger(errorthunk) - L = [] - inst.log_info = lambda *arg: L.append(arg) - result = inst.handle_read() - self.assertEqual(result, None) - self.assertEqual(len(L), 1) - self.assertEqual(inst.thunks, []) diff --git a/waitress/tests/test_utilities.py b/waitress/tests/test_utilities.py deleted file mode 100644 index 15cd24f..0000000 --- a/waitress/tests/test_utilities.py +++ /dev/null @@ -1,140 +0,0 @@ -############################################################################## -# -# Copyright (c) 2002 Zope Foundation and Contributors. -# All Rights Reserved. -# -# This software is subject to the provisions of the Zope Public License, -# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. -# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED -# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS -# FOR A PARTICULAR PURPOSE. -# -############################################################################## - -import unittest - - -class Test_parse_http_date(unittest.TestCase): - def _callFUT(self, v): - from waitress.utilities import parse_http_date - - return parse_http_date(v) - - def test_rfc850(self): - val = "Tuesday, 08-Feb-94 14:15:29 GMT" - result = self._callFUT(val) - self.assertEqual(result, 760716929) - - def test_rfc822(self): - val = "Sun, 08 Feb 1994 14:15:29 GMT" - result = self._callFUT(val) - self.assertEqual(result, 760716929) - - def test_neither(self): - val = "" - result = self._callFUT(val) - self.assertEqual(result, 0) - - -class Test_build_http_date(unittest.TestCase): - def test_rountdrip(self): - from waitress.utilities import build_http_date, parse_http_date - from time import time - - t = int(time()) - self.assertEqual(t, parse_http_date(build_http_date(t))) - - -class Test_unpack_rfc850(unittest.TestCase): - def _callFUT(self, val): - from waitress.utilities import unpack_rfc850, rfc850_reg - - return unpack_rfc850(rfc850_reg.match(val.lower())) - - def test_it(self): - val = "Tuesday, 08-Feb-94 14:15:29 GMT" - result = self._callFUT(val) - self.assertEqual(result, (1994, 2, 8, 14, 15, 29, 0, 0, 0)) - - -class Test_unpack_rfc_822(unittest.TestCase): - def _callFUT(self, val): - from waitress.utilities import unpack_rfc822, rfc822_reg - - return unpack_rfc822(rfc822_reg.match(val.lower())) - - def test_it(self): - val = "Sun, 08 Feb 1994 14:15:29 GMT" - result = self._callFUT(val) - self.assertEqual(result, (1994, 2, 8, 14, 15, 29, 0, 0, 0)) - - -class Test_find_double_newline(unittest.TestCase): - def _callFUT(self, val): - from waitress.utilities import find_double_newline - - return find_double_newline(val) - - def test_empty(self): - self.assertEqual(self._callFUT(b""), -1) - - def test_one_linefeed(self): - self.assertEqual(self._callFUT(b"\n"), -1) - - def test_double_linefeed(self): - self.assertEqual(self._callFUT(b"\n\n"), -1) - - def test_one_crlf(self): - self.assertEqual(self._callFUT(b"\r\n"), -1) - - def test_double_crfl(self): - self.assertEqual(self._callFUT(b"\r\n\r\n"), 4) - - def test_mixed(self): - self.assertEqual(self._callFUT(b"\n\n00\r\n\r\n"), 8) - - -class TestBadRequest(unittest.TestCase): - def _makeOne(self): - from waitress.utilities import BadRequest - - return BadRequest(1) - - def test_it(self): - inst = self._makeOne() - self.assertEqual(inst.body, 1) - - -class Test_undquote(unittest.TestCase): - def _callFUT(self, value): - from waitress.utilities import undquote - - return undquote(value) - - def test_empty(self): - self.assertEqual(self._callFUT(""), "") - - def test_quoted(self): - self.assertEqual(self._callFUT('"test"'), "test") - - def test_unquoted(self): - self.assertEqual(self._callFUT("test"), "test") - - def test_quoted_backslash_quote(self): - self.assertEqual(self._callFUT('"\\""'), '"') - - def test_quoted_htab(self): - self.assertEqual(self._callFUT('"\t"'), "\t") - - def test_quoted_backslash_htab(self): - self.assertEqual(self._callFUT('"\\\t"'), "\t") - - def test_quoted_backslash_invalid(self): - self.assertRaises(ValueError, self._callFUT, '"\\"') - - def test_invalid_quoting(self): - self.assertRaises(ValueError, self._callFUT, '"test') - - def test_invalid_quoting_single_quote(self): - self.assertRaises(ValueError, self._callFUT, '"') diff --git a/waitress/tests/test_wasyncore.py b/waitress/tests/test_wasyncore.py deleted file mode 100644 index 9c23509..0000000 --- a/waitress/tests/test_wasyncore.py +++ /dev/null @@ -1,1761 +0,0 @@ -from waitress import wasyncore as asyncore -from waitress import compat -import contextlib -import functools -import gc -import unittest -import select -import os -import socket -import sys -import time -import errno -import re -import struct -import threading -import warnings - -from io import BytesIO - -TIMEOUT = 3 -HAS_UNIX_SOCKETS = hasattr(socket, "AF_UNIX") -HOST = "localhost" -HOSTv4 = "127.0.0.1" -HOSTv6 = "::1" - -# Filename used for testing -if os.name == "java": # pragma: no cover - # Jython disallows @ in module names - TESTFN = "$test" -else: - TESTFN = "@test" - -TESTFN = "{}_{}_tmp".format(TESTFN, os.getpid()) - - -class DummyLogger(object): # pragma: no cover - def __init__(self): - self.messages = [] - - def log(self, severity, message): - self.messages.append((severity, message)) - - -class WarningsRecorder(object): # pragma: no cover - """Convenience wrapper for the warnings list returned on - entry to the warnings.catch_warnings() context manager. - """ - - def __init__(self, warnings_list): - self._warnings = warnings_list - self._last = 0 - - @property - def warnings(self): - return self._warnings[self._last :] - - def reset(self): - self._last = len(self._warnings) - - -def _filterwarnings(filters, quiet=False): # pragma: no cover - """Catch the warnings, then check if all the expected - warnings have been raised and re-raise unexpected warnings. - If 'quiet' is True, only re-raise the unexpected warnings. - """ - # Clear the warning registry of the calling module - # in order to re-raise the warnings. - frame = sys._getframe(2) - registry = frame.f_globals.get("__warningregistry__") - if registry: - registry.clear() - with warnings.catch_warnings(record=True) as w: - # Set filter "always" to record all warnings. Because - # test_warnings swap the module, we need to look up in - # the sys.modules dictionary. - sys.modules["warnings"].simplefilter("always") - yield WarningsRecorder(w) - # Filter the recorded warnings - reraise = list(w) - missing = [] - for msg, cat in filters: - seen = False - for w in reraise[:]: - warning = w.message - # Filter out the matching messages - if re.match(msg, str(warning), re.I) and issubclass(warning.__class__, cat): - seen = True - reraise.remove(w) - if not seen and not quiet: - # This filter caught nothing - missing.append((msg, cat.__name__)) - if reraise: - raise AssertionError("unhandled warning %s" % reraise[0]) - if missing: - raise AssertionError("filter (%r, %s) did not catch any warning" % missing[0]) - - -@contextlib.contextmanager -def check_warnings(*filters, **kwargs): # pragma: no cover - """Context manager to silence warnings. - - Accept 2-tuples as positional arguments: - ("message regexp", WarningCategory) - - Optional argument: - - if 'quiet' is True, it does not fail if a filter catches nothing - (default True without argument, - default False if some filters are defined) - - Without argument, it defaults to: - check_warnings(("", Warning), quiet=True) - """ - quiet = kwargs.get("quiet") - if not filters: - filters = (("", Warning),) - # Preserve backward compatibility - if quiet is None: - quiet = True - return _filterwarnings(filters, quiet) - - -def gc_collect(): # pragma: no cover - """Force as many objects as possible to be collected. - - In non-CPython implementations of Python, this is needed because timely - deallocation is not guaranteed by the garbage collector. (Even in CPython - this can be the case in case of reference cycles.) This means that __del__ - methods may be called later than expected and weakrefs may remain alive for - longer than expected. This function tries its best to force all garbage - objects to disappear. - """ - gc.collect() - if sys.platform.startswith("java"): - time.sleep(0.1) - gc.collect() - gc.collect() - - -def threading_setup(): # pragma: no cover - return (compat.thread._count(), None) - - -def threading_cleanup(*original_values): # pragma: no cover - global environment_altered - - _MAX_COUNT = 100 - - for count in range(_MAX_COUNT): - values = (compat.thread._count(), None) - if values == original_values: - break - - if not count: - # Display a warning at the first iteration - environment_altered = True - sys.stderr.write( - "Warning -- threading_cleanup() failed to cleanup " - "%s threads" % (values[0] - original_values[0]) - ) - sys.stderr.flush() - - values = None - - time.sleep(0.01) - gc_collect() - - -def reap_threads(func): # pragma: no cover - """Use this function when threads are being used. This will - ensure that the threads are cleaned up even when the test fails. - """ - - @functools.wraps(func) - def decorator(*args): - key = threading_setup() - try: - return func(*args) - finally: - threading_cleanup(*key) - - return decorator - - -def join_thread(thread, timeout=30.0): # pragma: no cover - """Join a thread. Raise an AssertionError if the thread is still alive - after timeout seconds. - """ - thread.join(timeout) - if thread.is_alive(): - msg = "failed to join the thread in %.1f seconds" % timeout - raise AssertionError(msg) - - -def bind_port(sock, host=HOST): # pragma: no cover - """Bind the socket to a free port and return the port number. Relies on - ephemeral ports in order to ensure we are using an unbound port. This is - important as many tests may be running simultaneously, especially in a - buildbot environment. This method raises an exception if the sock.family - is AF_INET and sock.type is SOCK_STREAM, *and* the socket has SO_REUSEADDR - or SO_REUSEPORT set on it. Tests should *never* set these socket options - for TCP/IP sockets. The only case for setting these options is testing - multicasting via multiple UDP sockets. - - Additionally, if the SO_EXCLUSIVEADDRUSE socket option is available (i.e. - on Windows), it will be set on the socket. This will prevent anyone else - from bind()'ing to our host/port for the duration of the test. - """ - - if sock.family == socket.AF_INET and sock.type == socket.SOCK_STREAM: - if hasattr(socket, "SO_REUSEADDR"): - if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) == 1: - raise RuntimeError( - "tests should never set the SO_REUSEADDR " - "socket option on TCP/IP sockets!" - ) - if hasattr(socket, "SO_REUSEPORT"): - try: - if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 1: - raise RuntimeError( - "tests should never set the SO_REUSEPORT " - "socket option on TCP/IP sockets!" - ) - except OSError: - # Python's socket module was compiled using modern headers - # thus defining SO_REUSEPORT but this process is running - # under an older kernel that does not support SO_REUSEPORT. - pass - if hasattr(socket, "SO_EXCLUSIVEADDRUSE"): - sock.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1) - - sock.bind((host, 0)) - port = sock.getsockname()[1] - return port - - -@contextlib.contextmanager -def closewrapper(sock): # pragma: no cover - try: - yield sock - finally: - sock.close() - - -class dummysocket: # pragma: no cover - def __init__(self): - self.closed = False - - def close(self): - self.closed = True - - def fileno(self): - return 42 - - def setblocking(self, yesno): - self.isblocking = yesno - - def getpeername(self): - return "peername" - - -class dummychannel: # pragma: no cover - def __init__(self): - self.socket = dummysocket() - - def close(self): - self.socket.close() - - -class exitingdummy: # pragma: no cover - def __init__(self): - pass - - def handle_read_event(self): - raise asyncore.ExitNow() - - handle_write_event = handle_read_event - handle_close = handle_read_event - handle_expt_event = handle_read_event - - -class crashingdummy: - def __init__(self): - self.error_handled = False - - def handle_read_event(self): - raise Exception() - - handle_write_event = handle_read_event - handle_close = handle_read_event - handle_expt_event = handle_read_event - - def handle_error(self): - self.error_handled = True - - -# used when testing senders; just collects what it gets until newline is sent -def capture_server(evt, buf, serv): # pragma no cover - try: - serv.listen(0) - conn, addr = serv.accept() - except socket.timeout: - pass - else: - n = 200 - start = time.time() - while n > 0 and time.time() - start < 3.0: - r, w, e = select.select([conn], [], [], 0.1) - if r: - n -= 1 - data = conn.recv(10) - # keep everything except for the newline terminator - buf.write(data.replace(b"\n", b"")) - if b"\n" in data: - break - time.sleep(0.01) - - conn.close() - finally: - serv.close() - evt.set() - - -def bind_unix_socket(sock, addr): # pragma: no cover - """Bind a unix socket, raising SkipTest if PermissionError is raised.""" - assert sock.family == socket.AF_UNIX - try: - sock.bind(addr) - except PermissionError: - sock.close() - raise unittest.SkipTest("cannot bind AF_UNIX sockets") - - -def bind_af_aware(sock, addr): - """Helper function to bind a socket according to its family.""" - if HAS_UNIX_SOCKETS and sock.family == socket.AF_UNIX: - # Make sure the path doesn't exist. - unlink(addr) - bind_unix_socket(sock, addr) - else: - sock.bind(addr) - - -if sys.platform.startswith("win"): # pragma: no cover - - def _waitfor(func, pathname, waitall=False): - # Perform the operation - func(pathname) - # Now setup the wait loop - if waitall: - dirname = pathname - else: - dirname, name = os.path.split(pathname) - dirname = dirname or "." - # Check for `pathname` to be removed from the filesystem. - # The exponential backoff of the timeout amounts to a total - # of ~1 second after which the deletion is probably an error - # anyway. - # Testing on an i7@4.3GHz shows that usually only 1 iteration is - # required when contention occurs. - timeout = 0.001 - while timeout < 1.0: - # Note we are only testing for the existence of the file(s) in - # the contents of the directory regardless of any security or - # access rights. If we have made it this far, we have sufficient - # permissions to do that much using Python's equivalent of the - # Windows API FindFirstFile. - # Other Windows APIs can fail or give incorrect results when - # dealing with files that are pending deletion. - L = os.listdir(dirname) - if not (L if waitall else name in L): - return - # Increase the timeout and try again - time.sleep(timeout) - timeout *= 2 - warnings.warn( - "tests may fail, delete still pending for " + pathname, - RuntimeWarning, - stacklevel=4, - ) - - def _unlink(filename): - _waitfor(os.unlink, filename) - - -else: - _unlink = os.unlink - - -def unlink(filename): - try: - _unlink(filename) - except OSError: - pass - - -def _is_ipv6_enabled(): # pragma: no cover - """Check whether IPv6 is enabled on this host.""" - if compat.HAS_IPV6: - sock = None - try: - sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - sock.bind(("::1", 0)) - return True - except socket.error: - pass - finally: - if sock: - sock.close() - return False - - -IPV6_ENABLED = _is_ipv6_enabled() - - -class HelperFunctionTests(unittest.TestCase): - def test_readwriteexc(self): - # Check exception handling behavior of read, write and _exception - - # check that ExitNow exceptions in the object handler method - # bubbles all the way up through asyncore read/write/_exception calls - tr1 = exitingdummy() - self.assertRaises(asyncore.ExitNow, asyncore.read, tr1) - self.assertRaises(asyncore.ExitNow, asyncore.write, tr1) - self.assertRaises(asyncore.ExitNow, asyncore._exception, tr1) - - # check that an exception other than ExitNow in the object handler - # method causes the handle_error method to get called - tr2 = crashingdummy() - asyncore.read(tr2) - self.assertEqual(tr2.error_handled, True) - - tr2 = crashingdummy() - asyncore.write(tr2) - self.assertEqual(tr2.error_handled, True) - - tr2 = crashingdummy() - asyncore._exception(tr2) - self.assertEqual(tr2.error_handled, True) - - # asyncore.readwrite uses constants in the select module that - # are not present in Windows systems (see this thread: - # http://mail.python.org/pipermail/python-list/2001-October/109973.html) - # These constants should be present as long as poll is available - - @unittest.skipUnless(hasattr(select, "poll"), "select.poll required") - def test_readwrite(self): - # Check that correct methods are called by readwrite() - - attributes = ("read", "expt", "write", "closed", "error_handled") - - expected = ( - (select.POLLIN, "read"), - (select.POLLPRI, "expt"), - (select.POLLOUT, "write"), - (select.POLLERR, "closed"), - (select.POLLHUP, "closed"), - (select.POLLNVAL, "closed"), - ) - - class testobj: - def __init__(self): - self.read = False - self.write = False - self.closed = False - self.expt = False - self.error_handled = False - - def handle_read_event(self): - self.read = True - - def handle_write_event(self): - self.write = True - - def handle_close(self): - self.closed = True - - def handle_expt_event(self): - self.expt = True - - # def handle_error(self): - # self.error_handled = True - - for flag, expectedattr in expected: - tobj = testobj() - self.assertEqual(getattr(tobj, expectedattr), False) - asyncore.readwrite(tobj, flag) - - # Only the attribute modified by the routine we expect to be - # called should be True. - for attr in attributes: - self.assertEqual(getattr(tobj, attr), attr == expectedattr) - - # check that ExitNow exceptions in the object handler method - # bubbles all the way up through asyncore readwrite call - tr1 = exitingdummy() - self.assertRaises(asyncore.ExitNow, asyncore.readwrite, tr1, flag) - - # check that an exception other than ExitNow in the object handler - # method causes the handle_error method to get called - tr2 = crashingdummy() - self.assertEqual(tr2.error_handled, False) - asyncore.readwrite(tr2, flag) - self.assertEqual(tr2.error_handled, True) - - def test_closeall(self): - self.closeall_check(False) - - def test_closeall_default(self): - self.closeall_check(True) - - def closeall_check(self, usedefault): - # Check that close_all() closes everything in a given map - - l = [] - testmap = {} - for i in range(10): - c = dummychannel() - l.append(c) - self.assertEqual(c.socket.closed, False) - testmap[i] = c - - if usedefault: - socketmap = asyncore.socket_map - try: - asyncore.socket_map = testmap - asyncore.close_all() - finally: - testmap, asyncore.socket_map = asyncore.socket_map, socketmap - else: - asyncore.close_all(testmap) - - self.assertEqual(len(testmap), 0) - - for c in l: - self.assertEqual(c.socket.closed, True) - - def test_compact_traceback(self): - try: - raise Exception("I don't like spam!") - except: - real_t, real_v, real_tb = sys.exc_info() - r = asyncore.compact_traceback() - - (f, function, line), t, v, info = r - self.assertEqual(os.path.split(f)[-1], "test_wasyncore.py") - self.assertEqual(function, "test_compact_traceback") - self.assertEqual(t, real_t) - self.assertEqual(v, real_v) - self.assertEqual(info, "[%s|%s|%s]" % (f, function, line)) - - -class DispatcherTests(unittest.TestCase): - def setUp(self): - pass - - def tearDown(self): - asyncore.close_all() - - def test_basic(self): - d = asyncore.dispatcher() - self.assertEqual(d.readable(), True) - self.assertEqual(d.writable(), True) - - def test_repr(self): - d = asyncore.dispatcher() - self.assertEqual(repr(d), "" % id(d)) - - def test_log_info(self): - import logging - - inst = asyncore.dispatcher(map={}) - logger = DummyLogger() - inst.logger = logger - inst.log_info("message", "warning") - self.assertEqual(logger.messages, [(logging.WARN, "message")]) - - def test_log(self): - import logging - - inst = asyncore.dispatcher() - logger = DummyLogger() - inst.logger = logger - inst.log("message") - self.assertEqual(logger.messages, [(logging.DEBUG, "message")]) - - def test_unhandled(self): - import logging - - inst = asyncore.dispatcher() - logger = DummyLogger() - inst.logger = logger - - inst.handle_expt() - inst.handle_read() - inst.handle_write() - inst.handle_connect() - - expected = [ - (logging.WARN, "unhandled incoming priority event"), - (logging.WARN, "unhandled read event"), - (logging.WARN, "unhandled write event"), - (logging.WARN, "unhandled connect event"), - ] - self.assertEqual(logger.messages, expected) - - def test_strerror(self): - # refers to bug #8573 - err = asyncore._strerror(errno.EPERM) - if hasattr(os, "strerror"): - self.assertEqual(err, os.strerror(errno.EPERM)) - err = asyncore._strerror(-1) - self.assertTrue(err != "") - - -class dispatcherwithsend_noread(asyncore.dispatcher_with_send): # pragma: no cover - def readable(self): - return False - - def handle_connect(self): - pass - - -class DispatcherWithSendTests(unittest.TestCase): - def setUp(self): - pass - - def tearDown(self): - asyncore.close_all() - - @reap_threads - def test_send(self): - evt = threading.Event() - sock = socket.socket() - sock.settimeout(3) - port = bind_port(sock) - - cap = BytesIO() - args = (evt, cap, sock) - t = threading.Thread(target=capture_server, args=args) - t.start() - try: - # wait a little longer for the server to initialize (it sometimes - # refuses connections on slow machines without this wait) - time.sleep(0.2) - - data = b"Suppose there isn't a 16-ton weight?" - d = dispatcherwithsend_noread() - d.create_socket() - d.connect((HOST, port)) - - # give time for socket to connect - time.sleep(0.1) - - d.send(data) - d.send(data) - d.send(b"\n") - - n = 1000 - while d.out_buffer and n > 0: # pragma: no cover - asyncore.poll() - n -= 1 - - evt.wait() - - self.assertEqual(cap.getvalue(), data * 2) - finally: - join_thread(t, timeout=TIMEOUT) - - -@unittest.skipUnless( - hasattr(asyncore, "file_wrapper"), "asyncore.file_wrapper required" -) -class FileWrapperTest(unittest.TestCase): - def setUp(self): - self.d = b"It's not dead, it's sleeping!" - with open(TESTFN, "wb") as file: - file.write(self.d) - - def tearDown(self): - unlink(TESTFN) - - def test_recv(self): - fd = os.open(TESTFN, os.O_RDONLY) - w = asyncore.file_wrapper(fd) - os.close(fd) - - self.assertNotEqual(w.fd, fd) - self.assertNotEqual(w.fileno(), fd) - self.assertEqual(w.recv(13), b"It's not dead") - self.assertEqual(w.read(6), b", it's") - w.close() - self.assertRaises(OSError, w.read, 1) - - def test_send(self): - d1 = b"Come again?" - d2 = b"I want to buy some cheese." - fd = os.open(TESTFN, os.O_WRONLY | os.O_APPEND) - w = asyncore.file_wrapper(fd) - os.close(fd) - - w.write(d1) - w.send(d2) - w.close() - with open(TESTFN, "rb") as file: - self.assertEqual(file.read(), self.d + d1 + d2) - - @unittest.skipUnless( - hasattr(asyncore, "file_dispatcher"), "asyncore.file_dispatcher required" - ) - def test_dispatcher(self): - fd = os.open(TESTFN, os.O_RDONLY) - data = [] - - class FileDispatcher(asyncore.file_dispatcher): - def handle_read(self): - data.append(self.recv(29)) - - FileDispatcher(fd) - os.close(fd) - asyncore.loop(timeout=0.01, use_poll=True, count=2) - self.assertEqual(b"".join(data), self.d) - - def test_resource_warning(self): - # Issue #11453 - got_warning = False - while got_warning is False: - # we try until we get the outcome we want because this - # test is not deterministic (gc_collect() may not - fd = os.open(TESTFN, os.O_RDONLY) - f = asyncore.file_wrapper(fd) - - os.close(fd) - - try: - with check_warnings(("", compat.ResourceWarning)): - f = None - gc_collect() - except AssertionError: # pragma: no cover - pass - else: - got_warning = True - - def test_close_twice(self): - fd = os.open(TESTFN, os.O_RDONLY) - f = asyncore.file_wrapper(fd) - os.close(fd) - - os.close(f.fd) # file_wrapper dupped fd - with self.assertRaises(OSError): - f.close() - - self.assertEqual(f.fd, -1) - # calling close twice should not fail - f.close() - - -class BaseTestHandler(asyncore.dispatcher): # pragma: no cover - def __init__(self, sock=None): - asyncore.dispatcher.__init__(self, sock) - self.flag = False - - def handle_accept(self): - raise Exception("handle_accept not supposed to be called") - - def handle_accepted(self): - raise Exception("handle_accepted not supposed to be called") - - def handle_connect(self): - raise Exception("handle_connect not supposed to be called") - - def handle_expt(self): - raise Exception("handle_expt not supposed to be called") - - def handle_close(self): - raise Exception("handle_close not supposed to be called") - - def handle_error(self): - raise - - -class BaseServer(asyncore.dispatcher): - """A server which listens on an address and dispatches the - connection to a handler. - """ - - def __init__(self, family, addr, handler=BaseTestHandler): - asyncore.dispatcher.__init__(self) - self.create_socket(family) - self.set_reuse_addr() - bind_af_aware(self.socket, addr) - self.listen(5) - self.handler = handler - - @property - def address(self): - return self.socket.getsockname() - - def handle_accepted(self, sock, addr): - self.handler(sock) - - def handle_error(self): # pragma: no cover - raise - - -class BaseClient(BaseTestHandler): - def __init__(self, family, address): - BaseTestHandler.__init__(self) - self.create_socket(family) - self.connect(address) - - def handle_connect(self): - pass - - -class BaseTestAPI: - def tearDown(self): - asyncore.close_all(ignore_all=True) - - def loop_waiting_for_flag(self, instance, timeout=5): # pragma: no cover - timeout = float(timeout) / 100 - count = 100 - while asyncore.socket_map and count > 0: - asyncore.loop(timeout=0.01, count=1, use_poll=self.use_poll) - if instance.flag: - return - count -= 1 - time.sleep(timeout) - self.fail("flag not set") - - def test_handle_connect(self): - # make sure handle_connect is called on connect() - - class TestClient(BaseClient): - def handle_connect(self): - self.flag = True - - server = BaseServer(self.family, self.addr) - client = TestClient(self.family, server.address) - self.loop_waiting_for_flag(client) - - def test_handle_accept(self): - # make sure handle_accept() is called when a client connects - - class TestListener(BaseTestHandler): - def __init__(self, family, addr): - BaseTestHandler.__init__(self) - self.create_socket(family) - bind_af_aware(self.socket, addr) - self.listen(5) - self.address = self.socket.getsockname() - - def handle_accept(self): - self.flag = True - - server = TestListener(self.family, self.addr) - client = BaseClient(self.family, server.address) - self.loop_waiting_for_flag(server) - - def test_handle_accepted(self): - # make sure handle_accepted() is called when a client connects - - class TestListener(BaseTestHandler): - def __init__(self, family, addr): - BaseTestHandler.__init__(self) - self.create_socket(family) - bind_af_aware(self.socket, addr) - self.listen(5) - self.address = self.socket.getsockname() - - def handle_accept(self): - asyncore.dispatcher.handle_accept(self) - - def handle_accepted(self, sock, addr): - sock.close() - self.flag = True - - server = TestListener(self.family, self.addr) - client = BaseClient(self.family, server.address) - self.loop_waiting_for_flag(server) - - def test_handle_read(self): - # make sure handle_read is called on data received - - class TestClient(BaseClient): - def handle_read(self): - self.flag = True - - class TestHandler(BaseTestHandler): - def __init__(self, conn): - BaseTestHandler.__init__(self, conn) - self.send(b"x" * 1024) - - server = BaseServer(self.family, self.addr, TestHandler) - client = TestClient(self.family, server.address) - self.loop_waiting_for_flag(client) - - def test_handle_write(self): - # make sure handle_write is called - - class TestClient(BaseClient): - def handle_write(self): - self.flag = True - - server = BaseServer(self.family, self.addr) - client = TestClient(self.family, server.address) - self.loop_waiting_for_flag(client) - - def test_handle_close(self): - # make sure handle_close is called when the other end closes - # the connection - - class TestClient(BaseClient): - def handle_read(self): - # in order to make handle_close be called we are supposed - # to make at least one recv() call - self.recv(1024) - - def handle_close(self): - self.flag = True - self.close() - - class TestHandler(BaseTestHandler): - def __init__(self, conn): - BaseTestHandler.__init__(self, conn) - self.close() - - server = BaseServer(self.family, self.addr, TestHandler) - client = TestClient(self.family, server.address) - self.loop_waiting_for_flag(client) - - def test_handle_close_after_conn_broken(self): - # Check that ECONNRESET/EPIPE is correctly handled (issues #5661 and - # #11265). - - data = b"\0" * 128 - - class TestClient(BaseClient): - def handle_write(self): - self.send(data) - - def handle_close(self): - self.flag = True - self.close() - - def handle_expt(self): # pragma: no cover - # needs to exist for MacOS testing - self.flag = True - self.close() - - class TestHandler(BaseTestHandler): - def handle_read(self): - self.recv(len(data)) - self.close() - - def writable(self): - return False - - server = BaseServer(self.family, self.addr, TestHandler) - client = TestClient(self.family, server.address) - self.loop_waiting_for_flag(client) - - @unittest.skipIf( - sys.platform.startswith("sunos"), "OOB support is broken on Solaris" - ) - def test_handle_expt(self): - # Make sure handle_expt is called on OOB data received. - # Note: this might fail on some platforms as OOB data is - # tenuously supported and rarely used. - if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX: - self.skipTest("Not applicable to AF_UNIX sockets.") - - if sys.platform == "darwin" and self.use_poll: # pragma: no cover - self.skipTest("poll may fail on macOS; see issue #28087") - - class TestClient(BaseClient): - def handle_expt(self): - self.socket.recv(1024, socket.MSG_OOB) - self.flag = True - - class TestHandler(BaseTestHandler): - def __init__(self, conn): - BaseTestHandler.__init__(self, conn) - self.socket.send(compat.tobytes(chr(244)), socket.MSG_OOB) - - server = BaseServer(self.family, self.addr, TestHandler) - client = TestClient(self.family, server.address) - self.loop_waiting_for_flag(client) - - def test_handle_error(self): - class TestClient(BaseClient): - def handle_write(self): - 1.0 / 0 - - def handle_error(self): - self.flag = True - try: - raise - except ZeroDivisionError: - pass - else: # pragma: no cover - raise Exception("exception not raised") - - server = BaseServer(self.family, self.addr) - client = TestClient(self.family, server.address) - self.loop_waiting_for_flag(client) - - def test_connection_attributes(self): - server = BaseServer(self.family, self.addr) - client = BaseClient(self.family, server.address) - - # we start disconnected - self.assertFalse(server.connected) - self.assertTrue(server.accepting) - # this can't be taken for granted across all platforms - # self.assertFalse(client.connected) - self.assertFalse(client.accepting) - - # execute some loops so that client connects to server - asyncore.loop(timeout=0.01, use_poll=self.use_poll, count=100) - self.assertFalse(server.connected) - self.assertTrue(server.accepting) - self.assertTrue(client.connected) - self.assertFalse(client.accepting) - - # disconnect the client - client.close() - self.assertFalse(server.connected) - self.assertTrue(server.accepting) - self.assertFalse(client.connected) - self.assertFalse(client.accepting) - - # stop serving - server.close() - self.assertFalse(server.connected) - self.assertFalse(server.accepting) - - def test_create_socket(self): - s = asyncore.dispatcher() - s.create_socket(self.family) - # self.assertEqual(s.socket.type, socket.SOCK_STREAM) - self.assertEqual(s.socket.family, self.family) - self.assertEqual(s.socket.gettimeout(), 0) - # self.assertFalse(s.socket.get_inheritable()) - - def test_bind(self): - if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX: - self.skipTest("Not applicable to AF_UNIX sockets.") - s1 = asyncore.dispatcher() - s1.create_socket(self.family) - s1.bind(self.addr) - s1.listen(5) - port = s1.socket.getsockname()[1] - - s2 = asyncore.dispatcher() - s2.create_socket(self.family) - # EADDRINUSE indicates the socket was correctly bound - self.assertRaises(socket.error, s2.bind, (self.addr[0], port)) - - def test_set_reuse_addr(self): # pragma: no cover - if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX: - self.skipTest("Not applicable to AF_UNIX sockets.") - - with closewrapper(socket.socket(self.family)) as sock: - try: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - except OSError: - unittest.skip("SO_REUSEADDR not supported on this platform") - else: - # if SO_REUSEADDR succeeded for sock we expect asyncore - # to do the same - s = asyncore.dispatcher(socket.socket(self.family)) - self.assertFalse( - s.socket.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) - ) - s.socket.close() - s.create_socket(self.family) - s.set_reuse_addr() - self.assertTrue( - s.socket.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) - ) - - @reap_threads - def test_quick_connect(self): # pragma: no cover - # see: http://bugs.python.org/issue10340 - if self.family not in (socket.AF_INET, getattr(socket, "AF_INET6", object())): - self.skipTest("test specific to AF_INET and AF_INET6") - - server = BaseServer(self.family, self.addr) - # run the thread 500 ms: the socket should be connected in 200 ms - t = threading.Thread(target=lambda: asyncore.loop(timeout=0.1, count=5)) - t.start() - try: - sock = socket.socket(self.family, socket.SOCK_STREAM) - with closewrapper(sock) as s: - s.settimeout(0.2) - s.setsockopt( - socket.SOL_SOCKET, socket.SO_LINGER, struct.pack("ii", 1, 0) - ) - - try: - s.connect(server.address) - except OSError: - pass - finally: - join_thread(t, timeout=TIMEOUT) - - -class TestAPI_UseIPv4Sockets(BaseTestAPI): - family = socket.AF_INET - addr = (HOST, 0) - - -@unittest.skipUnless(IPV6_ENABLED, "IPv6 support required") -class TestAPI_UseIPv6Sockets(BaseTestAPI): - family = socket.AF_INET6 - addr = (HOSTv6, 0) - - -@unittest.skipUnless(HAS_UNIX_SOCKETS, "Unix sockets required") -class TestAPI_UseUnixSockets(BaseTestAPI): - if HAS_UNIX_SOCKETS: - family = socket.AF_UNIX - addr = TESTFN - - def tearDown(self): - unlink(self.addr) - BaseTestAPI.tearDown(self) - - -class TestAPI_UseIPv4Select(TestAPI_UseIPv4Sockets, unittest.TestCase): - use_poll = False - - -@unittest.skipUnless(hasattr(select, "poll"), "select.poll required") -class TestAPI_UseIPv4Poll(TestAPI_UseIPv4Sockets, unittest.TestCase): - use_poll = True - - -class TestAPI_UseIPv6Select(TestAPI_UseIPv6Sockets, unittest.TestCase): - use_poll = False - - -@unittest.skipUnless(hasattr(select, "poll"), "select.poll required") -class TestAPI_UseIPv6Poll(TestAPI_UseIPv6Sockets, unittest.TestCase): - use_poll = True - - -class TestAPI_UseUnixSocketsSelect(TestAPI_UseUnixSockets, unittest.TestCase): - use_poll = False - - -@unittest.skipUnless(hasattr(select, "poll"), "select.poll required") -class TestAPI_UseUnixSocketsPoll(TestAPI_UseUnixSockets, unittest.TestCase): - use_poll = True - - -class Test__strerror(unittest.TestCase): - def _callFUT(self, err): - from waitress.wasyncore import _strerror - - return _strerror(err) - - def test_gardenpath(self): - self.assertEqual(self._callFUT(1), "Operation not permitted") - - def test_unknown(self): - self.assertEqual(self._callFUT("wut"), "Unknown error wut") - - -class Test_read(unittest.TestCase): - def _callFUT(self, dispatcher): - from waitress.wasyncore import read - - return read(dispatcher) - - def test_gardenpath(self): - inst = DummyDispatcher() - self._callFUT(inst) - self.assertTrue(inst.read_event_handled) - self.assertFalse(inst.error_handled) - - def test_reraised(self): - from waitress.wasyncore import ExitNow - - inst = DummyDispatcher(ExitNow) - self.assertRaises(ExitNow, self._callFUT, inst) - self.assertTrue(inst.read_event_handled) - self.assertFalse(inst.error_handled) - - def test_non_reraised(self): - inst = DummyDispatcher(OSError) - self._callFUT(inst) - self.assertTrue(inst.read_event_handled) - self.assertTrue(inst.error_handled) - - -class Test_write(unittest.TestCase): - def _callFUT(self, dispatcher): - from waitress.wasyncore import write - - return write(dispatcher) - - def test_gardenpath(self): - inst = DummyDispatcher() - self._callFUT(inst) - self.assertTrue(inst.write_event_handled) - self.assertFalse(inst.error_handled) - - def test_reraised(self): - from waitress.wasyncore import ExitNow - - inst = DummyDispatcher(ExitNow) - self.assertRaises(ExitNow, self._callFUT, inst) - self.assertTrue(inst.write_event_handled) - self.assertFalse(inst.error_handled) - - def test_non_reraised(self): - inst = DummyDispatcher(OSError) - self._callFUT(inst) - self.assertTrue(inst.write_event_handled) - self.assertTrue(inst.error_handled) - - -class Test__exception(unittest.TestCase): - def _callFUT(self, dispatcher): - from waitress.wasyncore import _exception - - return _exception(dispatcher) - - def test_gardenpath(self): - inst = DummyDispatcher() - self._callFUT(inst) - self.assertTrue(inst.expt_event_handled) - self.assertFalse(inst.error_handled) - - def test_reraised(self): - from waitress.wasyncore import ExitNow - - inst = DummyDispatcher(ExitNow) - self.assertRaises(ExitNow, self._callFUT, inst) - self.assertTrue(inst.expt_event_handled) - self.assertFalse(inst.error_handled) - - def test_non_reraised(self): - inst = DummyDispatcher(OSError) - self._callFUT(inst) - self.assertTrue(inst.expt_event_handled) - self.assertTrue(inst.error_handled) - - -@unittest.skipUnless(hasattr(select, "poll"), "select.poll required") -class Test_readwrite(unittest.TestCase): - def _callFUT(self, obj, flags): - from waitress.wasyncore import readwrite - - return readwrite(obj, flags) - - def test_handle_read_event(self): - flags = 0 - flags |= select.POLLIN - inst = DummyDispatcher() - self._callFUT(inst, flags) - self.assertTrue(inst.read_event_handled) - - def test_handle_write_event(self): - flags = 0 - flags |= select.POLLOUT - inst = DummyDispatcher() - self._callFUT(inst, flags) - self.assertTrue(inst.write_event_handled) - - def test_handle_expt_event(self): - flags = 0 - flags |= select.POLLPRI - inst = DummyDispatcher() - self._callFUT(inst, flags) - self.assertTrue(inst.expt_event_handled) - - def test_handle_close(self): - flags = 0 - flags |= select.POLLHUP - inst = DummyDispatcher() - self._callFUT(inst, flags) - self.assertTrue(inst.close_handled) - - def test_socketerror_not_in_disconnected(self): - flags = 0 - flags |= select.POLLIN - inst = DummyDispatcher(socket.error(errno.EALREADY, "EALREADY")) - self._callFUT(inst, flags) - self.assertTrue(inst.read_event_handled) - self.assertTrue(inst.error_handled) - - def test_socketerror_in_disconnected(self): - flags = 0 - flags |= select.POLLIN - inst = DummyDispatcher(socket.error(errno.ECONNRESET, "ECONNRESET")) - self._callFUT(inst, flags) - self.assertTrue(inst.read_event_handled) - self.assertTrue(inst.close_handled) - - def test_exception_in_reraised(self): - from waitress import wasyncore - - flags = 0 - flags |= select.POLLIN - inst = DummyDispatcher(wasyncore.ExitNow) - self.assertRaises(wasyncore.ExitNow, self._callFUT, inst, flags) - self.assertTrue(inst.read_event_handled) - - def test_exception_not_in_reraised(self): - flags = 0 - flags |= select.POLLIN - inst = DummyDispatcher(ValueError) - self._callFUT(inst, flags) - self.assertTrue(inst.error_handled) - - -class Test_poll(unittest.TestCase): - def _callFUT(self, timeout=0.0, map=None): - from waitress.wasyncore import poll - - return poll(timeout, map) - - def test_nothing_writable_nothing_readable_but_map_not_empty(self): - # i read the mock.patch docs. nerp. - dummy_time = DummyTime() - map = {0: DummyDispatcher()} - try: - from waitress import wasyncore - - old_time = wasyncore.time - wasyncore.time = dummy_time - result = self._callFUT(map=map) - finally: - wasyncore.time = old_time - self.assertEqual(result, None) - self.assertEqual(dummy_time.sleepvals, [0.0]) - - def test_select_raises_EINTR(self): - # i read the mock.patch docs. nerp. - dummy_select = DummySelect(select.error(errno.EINTR)) - disp = DummyDispatcher() - disp.readable = lambda: True - map = {0: disp} - try: - from waitress import wasyncore - - old_select = wasyncore.select - wasyncore.select = dummy_select - result = self._callFUT(map=map) - finally: - wasyncore.select = old_select - self.assertEqual(result, None) - self.assertEqual(dummy_select.selected, [([0], [], [0], 0.0)]) - - def test_select_raises_non_EINTR(self): - # i read the mock.patch docs. nerp. - dummy_select = DummySelect(select.error(errno.EBADF)) - disp = DummyDispatcher() - disp.readable = lambda: True - map = {0: disp} - try: - from waitress import wasyncore - - old_select = wasyncore.select - wasyncore.select = dummy_select - self.assertRaises(select.error, self._callFUT, map=map) - finally: - wasyncore.select = old_select - self.assertEqual(dummy_select.selected, [([0], [], [0], 0.0)]) - - -class Test_poll2(unittest.TestCase): - def _callFUT(self, timeout=0.0, map=None): - from waitress.wasyncore import poll2 - - return poll2(timeout, map) - - def test_select_raises_EINTR(self): - # i read the mock.patch docs. nerp. - pollster = DummyPollster(exc=select.error(errno.EINTR)) - dummy_select = DummySelect(pollster=pollster) - disp = DummyDispatcher() - map = {0: disp} - try: - from waitress import wasyncore - - old_select = wasyncore.select - wasyncore.select = dummy_select - self._callFUT(map=map) - finally: - wasyncore.select = old_select - self.assertEqual(pollster.polled, [0.0]) - - def test_select_raises_non_EINTR(self): - # i read the mock.patch docs. nerp. - pollster = DummyPollster(exc=select.error(errno.EBADF)) - dummy_select = DummySelect(pollster=pollster) - disp = DummyDispatcher() - map = {0: disp} - try: - from waitress import wasyncore - - old_select = wasyncore.select - wasyncore.select = dummy_select - self.assertRaises(select.error, self._callFUT, map=map) - finally: - wasyncore.select = old_select - self.assertEqual(pollster.polled, [0.0]) - - -class Test_dispatcher(unittest.TestCase): - def _makeOne(self, sock=None, map=None): - from waitress.wasyncore import dispatcher - - return dispatcher(sock=sock, map=map) - - def test_unexpected_getpeername_exc(self): - sock = dummysocket() - - def getpeername(): - raise socket.error(errno.EBADF) - - map = {} - sock.getpeername = getpeername - self.assertRaises(socket.error, self._makeOne, sock=sock, map=map) - self.assertEqual(map, {}) - - def test___repr__accepting(self): - sock = dummysocket() - map = {} - inst = self._makeOne(sock=sock, map=map) - inst.accepting = True - inst.addr = ("localhost", 8080) - result = repr(inst) - expected = " Date: Sun, 2 Feb 2020 22:48:45 -0800 Subject: Move source code to src folder --- src/waitress/__init__.py | 45 +++ src/waitress/__main__.py | 3 + src/waitress/adjustments.py | 515 +++++++++++++++++++++++++++++++ src/waitress/buffers.py | 308 +++++++++++++++++++ src/waitress/channel.py | 414 +++++++++++++++++++++++++ src/waitress/compat.py | 179 +++++++++++ src/waitress/parser.py | 413 +++++++++++++++++++++++++ src/waitress/proxy_headers.py | 333 ++++++++++++++++++++ src/waitress/receiver.py | 186 ++++++++++++ src/waitress/rfc7230.py | 52 ++++ src/waitress/runner.py | 286 +++++++++++++++++ src/waitress/server.py | 436 ++++++++++++++++++++++++++ src/waitress/task.py | 570 ++++++++++++++++++++++++++++++++++ src/waitress/trigger.py | 203 +++++++++++++ src/waitress/utilities.py | 320 +++++++++++++++++++ src/waitress/wasyncore.py | 693 ++++++++++++++++++++++++++++++++++++++++++ waitress/__init__.py | 45 --- waitress/__main__.py | 3 - waitress/adjustments.py | 515 ------------------------------- waitress/buffers.py | 308 ------------------- waitress/channel.py | 414 ------------------------- waitress/compat.py | 179 ----------- waitress/parser.py | 413 ------------------------- waitress/proxy_headers.py | 333 -------------------- waitress/receiver.py | 186 ------------ waitress/rfc7230.py | 52 ---- waitress/runner.py | 286 ----------------- waitress/server.py | 436 -------------------------- waitress/task.py | 570 ---------------------------------- waitress/trigger.py | 203 ------------- waitress/utilities.py | 320 ------------------- waitress/wasyncore.py | 693 ------------------------------------------ 32 files changed, 4956 insertions(+), 4956 deletions(-) create mode 100644 src/waitress/__init__.py create mode 100644 src/waitress/__main__.py create mode 100644 src/waitress/adjustments.py create mode 100644 src/waitress/buffers.py create mode 100644 src/waitress/channel.py create mode 100644 src/waitress/compat.py create mode 100644 src/waitress/parser.py create mode 100644 src/waitress/proxy_headers.py create mode 100644 src/waitress/receiver.py create mode 100644 src/waitress/rfc7230.py create mode 100644 src/waitress/runner.py create mode 100644 src/waitress/server.py create mode 100644 src/waitress/task.py create mode 100644 src/waitress/trigger.py create mode 100644 src/waitress/utilities.py create mode 100644 src/waitress/wasyncore.py delete mode 100644 waitress/__init__.py delete mode 100644 waitress/__main__.py delete mode 100644 waitress/adjustments.py delete mode 100644 waitress/buffers.py delete mode 100644 waitress/channel.py delete mode 100644 waitress/compat.py delete mode 100644 waitress/parser.py delete mode 100644 waitress/proxy_headers.py delete mode 100644 waitress/receiver.py delete mode 100644 waitress/rfc7230.py delete mode 100644 waitress/runner.py delete mode 100644 waitress/server.py delete mode 100644 waitress/task.py delete mode 100644 waitress/trigger.py delete mode 100644 waitress/utilities.py delete mode 100644 waitress/wasyncore.py diff --git a/src/waitress/__init__.py b/src/waitress/__init__.py new file mode 100644 index 0000000..e6e5911 --- /dev/null +++ b/src/waitress/__init__.py @@ -0,0 +1,45 @@ +from waitress.server import create_server +import logging + + +def serve(app, **kw): + _server = kw.pop("_server", create_server) # test shim + _quiet = kw.pop("_quiet", False) # test shim + _profile = kw.pop("_profile", False) # test shim + if not _quiet: # pragma: no cover + # idempotent if logging has already been set up + logging.basicConfig() + server = _server(app, **kw) + if not _quiet: # pragma: no cover + server.print_listen("Serving on http://{}:{}") + if _profile: # pragma: no cover + profile("server.run()", globals(), locals(), (), False) + else: + server.run() + + +def serve_paste(app, global_conf, **kw): + serve(app, **kw) + return 0 + + +def profile(cmd, globals, locals, sort_order, callers): # pragma: no cover + # runs a command under the profiler and print profiling output at shutdown + import os + import profile + import pstats + import tempfile + + fd, fn = tempfile.mkstemp() + try: + profile.runctx(cmd, globals, locals, fn) + stats = pstats.Stats(fn) + stats.strip_dirs() + # calls,time,cumulative and cumulative,calls,time are useful + stats.sort_stats(*(sort_order or ("cumulative", "calls", "time"))) + if callers: + stats.print_callers(0.3) + else: + stats.print_stats(0.3) + finally: + os.remove(fn) diff --git a/src/waitress/__main__.py b/src/waitress/__main__.py new file mode 100644 index 0000000..9bcd07e --- /dev/null +++ b/src/waitress/__main__.py @@ -0,0 +1,3 @@ +from waitress.runner import run # pragma nocover + +run() # pragma nocover diff --git a/src/waitress/adjustments.py b/src/waitress/adjustments.py new file mode 100644 index 0000000..93439ea --- /dev/null +++ b/src/waitress/adjustments.py @@ -0,0 +1,515 @@ +############################################################################## +# +# Copyright (c) 2002 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## +"""Adjustments are tunable parameters. +""" +import getopt +import socket +import warnings + +from .proxy_headers import PROXY_HEADERS +from .compat import ( + PY2, + WIN, + string_types, + HAS_IPV6, +) + +truthy = frozenset(("t", "true", "y", "yes", "on", "1")) + +KNOWN_PROXY_HEADERS = frozenset( + header.lower().replace("_", "-") for header in PROXY_HEADERS +) + + +def asbool(s): + """ Return the boolean value ``True`` if the case-lowered value of string + input ``s`` is any of ``t``, ``true``, ``y``, ``on``, or ``1``, otherwise + return the boolean value ``False``. If ``s`` is the value ``None``, + return ``False``. If ``s`` is already one of the boolean values ``True`` + or ``False``, return it.""" + if s is None: + return False + if isinstance(s, bool): + return s + s = str(s).strip() + return s.lower() in truthy + + +def asoctal(s): + """Convert the given octal string to an actual number.""" + return int(s, 8) + + +def aslist_cronly(value): + if isinstance(value, string_types): + value = filter(None, [x.strip() for x in value.splitlines()]) + return list(value) + + +def aslist(value): + """ Return a list of strings, separating the input based on newlines + and, if flatten=True (the default), also split on spaces within + each line.""" + values = aslist_cronly(value) + result = [] + for value in values: + subvalues = value.split() + result.extend(subvalues) + return result + + +def asset(value): + return set(aslist(value)) + + +def slash_fixed_str(s): + s = s.strip() + if s: + # always have a leading slash, replace any number of leading slashes + # with a single slash, and strip any trailing slashes + s = "/" + s.lstrip("/").rstrip("/") + return s + + +def str_iftruthy(s): + return str(s) if s else None + + +def as_socket_list(sockets): + """Checks if the elements in the list are of type socket and + removes them if not.""" + return [sock for sock in sockets if isinstance(sock, socket.socket)] + + +class _str_marker(str): + pass + + +class _int_marker(int): + pass + + +class _bool_marker(object): + pass + + +class Adjustments(object): + """This class contains tunable parameters. + """ + + _params = ( + ("host", str), + ("port", int), + ("ipv4", asbool), + ("ipv6", asbool), + ("listen", aslist), + ("threads", int), + ("trusted_proxy", str_iftruthy), + ("trusted_proxy_count", int), + ("trusted_proxy_headers", asset), + ("log_untrusted_proxy_headers", asbool), + ("clear_untrusted_proxy_headers", asbool), + ("url_scheme", str), + ("url_prefix", slash_fixed_str), + ("backlog", int), + ("recv_bytes", int), + ("send_bytes", int), + ("outbuf_overflow", int), + ("outbuf_high_watermark", int), + ("inbuf_overflow", int), + ("connection_limit", int), + ("cleanup_interval", int), + ("channel_timeout", int), + ("log_socket_errors", asbool), + ("max_request_header_size", int), + ("max_request_body_size", int), + ("expose_tracebacks", asbool), + ("ident", str_iftruthy), + ("asyncore_loop_timeout", int), + ("asyncore_use_poll", asbool), + ("unix_socket", str), + ("unix_socket_perms", asoctal), + ("sockets", as_socket_list), + ) + + _param_map = dict(_params) + + # hostname or IP address to listen on + host = _str_marker("0.0.0.0") + + # TCP port to listen on + port = _int_marker(8080) + + listen = ["{}:{}".format(host, port)] + + # number of threads available for tasks + threads = 4 + + # Host allowed to overrid ``wsgi.url_scheme`` via header + trusted_proxy = None + + # How many proxies we trust when chained + # + # X-Forwarded-For: 192.0.2.1, "[2001:db8::1]" + # + # or + # + # Forwarded: for=192.0.2.1, For="[2001:db8::1]" + # + # means there were (potentially), two proxies involved. If we know there is + # only 1 valid proxy, then that initial IP address "192.0.2.1" is not + # trusted and we completely ignore it. If there are two trusted proxies in + # the path, this value should be set to a higher number. + trusted_proxy_count = None + + # Which of the proxy headers should we trust, this is a set where you + # either specify forwarded or one or more of forwarded-host, forwarded-for, + # forwarded-proto, forwarded-port. + trusted_proxy_headers = set() + + # Would you like waitress to log warnings about untrusted proxy headers + # that were encountered while processing the proxy headers? This only makes + # sense to set when you have a trusted_proxy, and you expect the upstream + # proxy server to filter invalid headers + log_untrusted_proxy_headers = False + + # Should waitress clear any proxy headers that are not deemed trusted from + # the environ? Change to True by default in 2.x + clear_untrusted_proxy_headers = _bool_marker + + # default ``wsgi.url_scheme`` value + url_scheme = "http" + + # default ``SCRIPT_NAME`` value, also helps reset ``PATH_INFO`` + # when nonempty + url_prefix = "" + + # server identity (sent in Server: header) + ident = "waitress" + + # backlog is the value waitress passes to pass to socket.listen() This is + # the maximum number of incoming TCP connections that will wait in an OS + # queue for an available channel. From listen(1): "If a connection + # request arrives when the queue is full, the client may receive an error + # with an indication of ECONNREFUSED or, if the underlying protocol + # supports retransmission, the request may be ignored so that a later + # reattempt at connection succeeds." + backlog = 1024 + + # recv_bytes is the argument to pass to socket.recv(). + recv_bytes = 8192 + + # deprecated setting controls how many bytes will be buffered before + # being flushed to the socket + send_bytes = 1 + + # A tempfile should be created if the pending output is larger than + # outbuf_overflow, which is measured in bytes. The default is 1MB. This + # is conservative. + outbuf_overflow = 1048576 + + # The app_iter will pause when pending output is larger than this value + # in bytes. + outbuf_high_watermark = 16777216 + + # A tempfile should be created if the pending input is larger than + # inbuf_overflow, which is measured in bytes. The default is 512K. This + # is conservative. + inbuf_overflow = 524288 + + # Stop creating new channels if too many are already active (integer). + # Each channel consumes at least one file descriptor, and, depending on + # the input and output body sizes, potentially up to three. The default + # is conservative, but you may need to increase the number of file + # descriptors available to the Waitress process on most platforms in + # order to safely change it (see ``ulimit -a`` "open files" setting). + # Note that this doesn't control the maximum number of TCP connections + # that can be waiting for processing; the ``backlog`` argument controls + # that. + connection_limit = 100 + + # Minimum seconds between cleaning up inactive channels. + cleanup_interval = 30 + + # Maximum seconds to leave an inactive connection open. + channel_timeout = 120 + + # Boolean: turn off to not log premature client disconnects. + log_socket_errors = True + + # maximum number of bytes of all request headers combined (256K default) + max_request_header_size = 262144 + + # maximum number of bytes in request body (1GB default) + max_request_body_size = 1073741824 + + # expose tracebacks of uncaught exceptions + expose_tracebacks = False + + # Path to a Unix domain socket to use. + unix_socket = None + + # Path to a Unix domain socket to use. + unix_socket_perms = 0o600 + + # The socket options to set on receiving a connection. It is a list of + # (level, optname, value) tuples. TCP_NODELAY disables the Nagle + # algorithm for writes (Waitress already buffers its writes). + socket_options = [ + (socket.SOL_TCP, socket.TCP_NODELAY, 1), + ] + + # The asyncore.loop timeout value + asyncore_loop_timeout = 1 + + # The asyncore.loop flag to use poll() instead of the default select(). + asyncore_use_poll = False + + # Enable IPv4 by default + ipv4 = True + + # Enable IPv6 by default + ipv6 = True + + # A list of sockets that waitress will use to accept connections. They can + # be used for e.g. socket activation + sockets = [] + + def __init__(self, **kw): + + if "listen" in kw and ("host" in kw or "port" in kw): + raise ValueError("host or port may not be set if listen is set.") + + if "listen" in kw and "sockets" in kw: + raise ValueError("socket may not be set if listen is set.") + + if "sockets" in kw and ("host" in kw or "port" in kw): + raise ValueError("host or port may not be set if sockets is set.") + + if "sockets" in kw and "unix_socket" in kw: + raise ValueError("unix_socket may not be set if sockets is set") + + if "unix_socket" in kw and ("host" in kw or "port" in kw): + raise ValueError("unix_socket may not be set if host or port is set") + + if "unix_socket" in kw and "listen" in kw: + raise ValueError("unix_socket may not be set if listen is set") + + if "send_bytes" in kw: + warnings.warn( + "send_bytes will be removed in a future release", DeprecationWarning, + ) + + for k, v in kw.items(): + if k not in self._param_map: + raise ValueError("Unknown adjustment %r" % k) + setattr(self, k, self._param_map[k](v)) + + if not isinstance(self.host, _str_marker) or not isinstance( + self.port, _int_marker + ): + self.listen = ["{}:{}".format(self.host, self.port)] + + enabled_families = socket.AF_UNSPEC + + if not self.ipv4 and not HAS_IPV6: # pragma: no cover + raise ValueError( + "IPv4 is disabled but IPv6 is not available. Cowardly refusing to start." + ) + + if self.ipv4 and not self.ipv6: + enabled_families = socket.AF_INET + + if not self.ipv4 and self.ipv6 and HAS_IPV6: + enabled_families = socket.AF_INET6 + + wanted_sockets = [] + hp_pairs = [] + for i in self.listen: + if ":" in i: + (host, port) = i.rsplit(":", 1) + + # IPv6 we need to make sure that we didn't split on the address + if "]" in port: # pragma: nocover + (host, port) = (i, str(self.port)) + else: + (host, port) = (i, str(self.port)) + + if WIN and PY2: # pragma: no cover + try: + # Try turning the port into an integer + port = int(port) + + except Exception: + raise ValueError( + "Windows does not support service names instead of port numbers" + ) + + try: + if "[" in host and "]" in host: # pragma: nocover + host = host.strip("[").rstrip("]") + + if host == "*": + host = None + + for s in socket.getaddrinfo( + host, + port, + enabled_families, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + socket.AI_PASSIVE, + ): + (family, socktype, proto, _, sockaddr) = s + + # It seems that getaddrinfo() may sometimes happily return + # the same result multiple times, this of course makes + # bind() very unhappy... + # + # Split on %, and drop the zone-index from the host in the + # sockaddr. Works around a bug in OS X whereby + # getaddrinfo() returns the same link-local interface with + # two different zone-indices (which makes no sense what so + # ever...) yet treats them equally when we attempt to bind(). + if ( + sockaddr[1] == 0 + or (sockaddr[0].split("%", 1)[0], sockaddr[1]) not in hp_pairs + ): + wanted_sockets.append((family, socktype, proto, sockaddr)) + hp_pairs.append((sockaddr[0].split("%", 1)[0], sockaddr[1])) + + except Exception: + raise ValueError("Invalid host/port specified.") + + if self.trusted_proxy_count is not None and self.trusted_proxy is None: + raise ValueError( + "trusted_proxy_count has no meaning without setting " "trusted_proxy" + ) + + elif self.trusted_proxy_count is None: + self.trusted_proxy_count = 1 + + if self.trusted_proxy_headers and self.trusted_proxy is None: + raise ValueError( + "trusted_proxy_headers has no meaning without setting " "trusted_proxy" + ) + + if self.trusted_proxy_headers: + self.trusted_proxy_headers = { + header.lower() for header in self.trusted_proxy_headers + } + + unknown_values = self.trusted_proxy_headers - KNOWN_PROXY_HEADERS + if unknown_values: + raise ValueError( + "Received unknown trusted_proxy_headers value (%s) expected one " + "of %s" + % (", ".join(unknown_values), ", ".join(KNOWN_PROXY_HEADERS)) + ) + + if ( + "forwarded" in self.trusted_proxy_headers + and self.trusted_proxy_headers - {"forwarded"} + ): + raise ValueError( + "The Forwarded proxy header and the " + "X-Forwarded-{By,Host,Proto,Port,For} headers are mutually " + "exclusive. Can't trust both!" + ) + + elif self.trusted_proxy is not None: + warnings.warn( + "No proxy headers were marked as trusted, but trusted_proxy was set. " + "Implicitly trusting X-Forwarded-Proto for backwards compatibility. " + "This will be removed in future versions of waitress.", + DeprecationWarning, + ) + self.trusted_proxy_headers = {"x-forwarded-proto"} + + if self.clear_untrusted_proxy_headers is _bool_marker: + warnings.warn( + "In future versions of Waitress clear_untrusted_proxy_headers will be " + "set to True by default. You may opt-out by setting this value to " + "False, or opt-in explicitly by setting this to True.", + DeprecationWarning, + ) + self.clear_untrusted_proxy_headers = False + + self.listen = wanted_sockets + + self.check_sockets(self.sockets) + + @classmethod + def parse_args(cls, argv): + """Pre-parse command line arguments for input into __init__. Note that + this does not cast values into adjustment types, it just creates a + dictionary suitable for passing into __init__, where __init__ does the + casting. + """ + long_opts = ["help", "call"] + for opt, cast in cls._params: + opt = opt.replace("_", "-") + if cast is asbool: + long_opts.append(opt) + long_opts.append("no-" + opt) + else: + long_opts.append(opt + "=") + + kw = { + "help": False, + "call": False, + } + + opts, args = getopt.getopt(argv, "", long_opts) + for opt, value in opts: + param = opt.lstrip("-").replace("-", "_") + + if param == "listen": + kw["listen"] = "{} {}".format(kw.get("listen", ""), value) + continue + + if param.startswith("no_"): + param = param[3:] + kw[param] = "false" + elif param in ("help", "call"): + kw[param] = True + elif cls._param_map[param] is asbool: + kw[param] = "true" + else: + kw[param] = value + + return kw, args + + @classmethod + def check_sockets(cls, sockets): + has_unix_socket = False + has_inet_socket = False + has_unsupported_socket = False + for sock in sockets: + if ( + sock.family == socket.AF_INET or sock.family == socket.AF_INET6 + ) and sock.type == socket.SOCK_STREAM: + has_inet_socket = True + elif ( + hasattr(socket, "AF_UNIX") + and sock.family == socket.AF_UNIX + and sock.type == socket.SOCK_STREAM + ): + has_unix_socket = True + else: + has_unsupported_socket = True + if has_unix_socket and has_inet_socket: + raise ValueError("Internet and UNIX sockets may not be mixed.") + if has_unsupported_socket: + raise ValueError("Only Internet or UNIX stream sockets may be used.") diff --git a/src/waitress/buffers.py b/src/waitress/buffers.py new file mode 100644 index 0000000..04f6b42 --- /dev/null +++ b/src/waitress/buffers.py @@ -0,0 +1,308 @@ +############################################################################## +# +# Copyright (c) 2001-2004 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## +"""Buffers +""" +from io import BytesIO + +# copy_bytes controls the size of temp. strings for shuffling data around. +COPY_BYTES = 1 << 18 # 256K + +# The maximum number of bytes to buffer in a simple string. +STRBUF_LIMIT = 8192 + + +class FileBasedBuffer(object): + + remain = 0 + + def __init__(self, file, from_buffer=None): + self.file = file + if from_buffer is not None: + from_file = from_buffer.getfile() + read_pos = from_file.tell() + from_file.seek(0) + while True: + data = from_file.read(COPY_BYTES) + if not data: + break + file.write(data) + self.remain = int(file.tell() - read_pos) + from_file.seek(read_pos) + file.seek(read_pos) + + def __len__(self): + return self.remain + + def __nonzero__(self): + return True + + __bool__ = __nonzero__ # py3 + + def append(self, s): + file = self.file + read_pos = file.tell() + file.seek(0, 2) + file.write(s) + file.seek(read_pos) + self.remain = self.remain + len(s) + + def get(self, numbytes=-1, skip=False): + file = self.file + if not skip: + read_pos = file.tell() + if numbytes < 0: + # Read all + res = file.read() + else: + res = file.read(numbytes) + if skip: + self.remain -= len(res) + else: + file.seek(read_pos) + return res + + def skip(self, numbytes, allow_prune=0): + if self.remain < numbytes: + raise ValueError( + "Can't skip %d bytes in buffer of %d bytes" % (numbytes, self.remain) + ) + self.file.seek(numbytes, 1) + self.remain = self.remain - numbytes + + def newfile(self): + raise NotImplementedError() + + def prune(self): + file = self.file + if self.remain == 0: + read_pos = file.tell() + file.seek(0, 2) + sz = file.tell() + file.seek(read_pos) + if sz == 0: + # Nothing to prune. + return + nf = self.newfile() + while True: + data = file.read(COPY_BYTES) + if not data: + break + nf.write(data) + self.file = nf + + def getfile(self): + return self.file + + def close(self): + if hasattr(self.file, "close"): + self.file.close() + self.remain = 0 + + +class TempfileBasedBuffer(FileBasedBuffer): + def __init__(self, from_buffer=None): + FileBasedBuffer.__init__(self, self.newfile(), from_buffer) + + def newfile(self): + from tempfile import TemporaryFile + + return TemporaryFile("w+b") + + +class BytesIOBasedBuffer(FileBasedBuffer): + def __init__(self, from_buffer=None): + if from_buffer is not None: + FileBasedBuffer.__init__(self, BytesIO(), from_buffer) + else: + # Shortcut. :-) + self.file = BytesIO() + + def newfile(self): + return BytesIO() + + +def _is_seekable(fp): + if hasattr(fp, "seekable"): + return fp.seekable() + return hasattr(fp, "seek") and hasattr(fp, "tell") + + +class ReadOnlyFileBasedBuffer(FileBasedBuffer): + # used as wsgi.file_wrapper + + def __init__(self, file, block_size=32768): + self.file = file + self.block_size = block_size # for __iter__ + + def prepare(self, size=None): + if _is_seekable(self.file): + start_pos = self.file.tell() + self.file.seek(0, 2) + end_pos = self.file.tell() + self.file.seek(start_pos) + fsize = end_pos - start_pos + if size is None: + self.remain = fsize + else: + self.remain = min(fsize, size) + return self.remain + + def get(self, numbytes=-1, skip=False): + # never read more than self.remain (it can be user-specified) + if numbytes == -1 or numbytes > self.remain: + numbytes = self.remain + file = self.file + if not skip: + read_pos = file.tell() + res = file.read(numbytes) + if skip: + self.remain -= len(res) + else: + file.seek(read_pos) + return res + + def __iter__(self): # called by task if self.filelike has no seek/tell + return self + + def next(self): + val = self.file.read(self.block_size) + if not val: + raise StopIteration + return val + + __next__ = next # py3 + + def append(self, s): + raise NotImplementedError + + +class OverflowableBuffer(object): + """ + This buffer implementation has four stages: + - No data + - Bytes-based buffer + - BytesIO-based buffer + - Temporary file storage + The first two stages are fastest for simple transfers. + """ + + overflowed = False + buf = None + strbuf = b"" # Bytes-based buffer. + + def __init__(self, overflow): + # overflow is the maximum to be stored in a StringIO buffer. + self.overflow = overflow + + def __len__(self): + buf = self.buf + if buf is not None: + # use buf.__len__ rather than len(buf) FBO of not getting + # OverflowError on Python 2 + return buf.__len__() + else: + return self.strbuf.__len__() + + def __nonzero__(self): + # use self.__len__ rather than len(self) FBO of not getting + # OverflowError on Python 2 + return self.__len__() > 0 + + __bool__ = __nonzero__ # py3 + + def _create_buffer(self): + strbuf = self.strbuf + if len(strbuf) >= self.overflow: + self._set_large_buffer() + else: + self._set_small_buffer() + buf = self.buf + if strbuf: + buf.append(self.strbuf) + self.strbuf = b"" + return buf + + def _set_small_buffer(self): + self.buf = BytesIOBasedBuffer(self.buf) + self.overflowed = False + + def _set_large_buffer(self): + self.buf = TempfileBasedBuffer(self.buf) + self.overflowed = True + + def append(self, s): + buf = self.buf + if buf is None: + strbuf = self.strbuf + if len(strbuf) + len(s) < STRBUF_LIMIT: + self.strbuf = strbuf + s + return + buf = self._create_buffer() + buf.append(s) + # use buf.__len__ rather than len(buf) FBO of not getting + # OverflowError on Python 2 + sz = buf.__len__() + if not self.overflowed: + if sz >= self.overflow: + self._set_large_buffer() + + def get(self, numbytes=-1, skip=False): + buf = self.buf + if buf is None: + strbuf = self.strbuf + if not skip: + return strbuf + buf = self._create_buffer() + return buf.get(numbytes, skip) + + def skip(self, numbytes, allow_prune=False): + buf = self.buf + if buf is None: + if allow_prune and numbytes == len(self.strbuf): + # We could slice instead of converting to + # a buffer, but that would eat up memory in + # large transfers. + self.strbuf = b"" + return + buf = self._create_buffer() + buf.skip(numbytes, allow_prune) + + def prune(self): + """ + A potentially expensive operation that removes all data + already retrieved from the buffer. + """ + buf = self.buf + if buf is None: + self.strbuf = b"" + return + buf.prune() + if self.overflowed: + # use buf.__len__ rather than len(buf) FBO of not getting + # OverflowError on Python 2 + sz = buf.__len__() + if sz < self.overflow: + # Revert to a faster buffer. + self._set_small_buffer() + + def getfile(self): + buf = self.buf + if buf is None: + buf = self._create_buffer() + return buf.getfile() + + def close(self): + buf = self.buf + if buf is not None: + buf.close() diff --git a/src/waitress/channel.py b/src/waitress/channel.py new file mode 100644 index 0000000..a8bc76f --- /dev/null +++ b/src/waitress/channel.py @@ -0,0 +1,414 @@ +############################################################################## +# +# Copyright (c) 2001, 2002 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## +import socket +import threading +import time +import traceback + +from waitress.buffers import ( + OverflowableBuffer, + ReadOnlyFileBasedBuffer, +) + +from waitress.parser import HTTPRequestParser + +from waitress.task import ( + ErrorTask, + WSGITask, +) + +from waitress.utilities import InternalServerError + +from . import wasyncore + + +class ClientDisconnected(Exception): + """ Raised when attempting to write to a closed socket.""" + + +class HTTPChannel(wasyncore.dispatcher, object): + """ + Setting self.requests = [somerequest] prevents more requests from being + received until the out buffers have been flushed. + + Setting self.requests = [] allows more requests to be received. + """ + + task_class = WSGITask + error_task_class = ErrorTask + parser_class = HTTPRequestParser + + request = None # A request parser instance + last_activity = 0 # Time of last activity + will_close = False # set to True to close the socket. + close_when_flushed = False # set to True to close the socket when flushed + requests = () # currently pending requests + sent_continue = False # used as a latch after sending 100 continue + total_outbufs_len = 0 # total bytes ready to send + current_outbuf_count = 0 # total bytes written to current outbuf + + # + # ASYNCHRONOUS METHODS (including __init__) + # + + def __init__( + self, server, sock, addr, adj, map=None, + ): + self.server = server + self.adj = adj + self.outbufs = [OverflowableBuffer(adj.outbuf_overflow)] + self.creation_time = self.last_activity = time.time() + self.sendbuf_len = sock.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF) + + # task_lock used to push/pop requests + self.task_lock = threading.Lock() + # outbuf_lock used to access any outbuf (expected to use an RLock) + self.outbuf_lock = threading.Condition() + + wasyncore.dispatcher.__init__(self, sock, map=map) + + # Don't let wasyncore.dispatcher throttle self.addr on us. + self.addr = addr + + def writable(self): + # if there's data in the out buffer or we've been instructed to close + # the channel (possibly by our server maintenance logic), run + # handle_write + return self.total_outbufs_len or self.will_close or self.close_when_flushed + + def handle_write(self): + # Precondition: there's data in the out buffer to be sent, or + # there's a pending will_close request + if not self.connected: + # we dont want to close the channel twice + return + + # try to flush any pending output + if not self.requests: + # 1. There are no running tasks, so we don't need to try to lock + # the outbuf before sending + # 2. The data in the out buffer should be sent as soon as possible + # because it's either data left over from task output + # or a 100 Continue line sent within "received". + flush = self._flush_some + elif self.total_outbufs_len >= self.adj.send_bytes: + # 1. There's a running task, so we need to try to lock + # the outbuf before sending + # 2. Only try to send if the data in the out buffer is larger + # than self.adj_bytes to avoid TCP fragmentation + flush = self._flush_some_if_lockable + else: + # 1. There's not enough data in the out buffer to bother to send + # right now. + flush = None + + if flush: + try: + flush() + except socket.error: + if self.adj.log_socket_errors: + self.logger.exception("Socket error") + self.will_close = True + except Exception: + self.logger.exception("Unexpected exception when flushing") + self.will_close = True + + if self.close_when_flushed and not self.total_outbufs_len: + self.close_when_flushed = False + self.will_close = True + + if self.will_close: + self.handle_close() + + def readable(self): + # We might want to create a new task. We can only do this if: + # 1. We're not already about to close the connection. + # 2. There's no already currently running task(s). + # 3. There's no data in the output buffer that needs to be sent + # before we potentially create a new task. + return not (self.will_close or self.requests or self.total_outbufs_len) + + def handle_read(self): + try: + data = self.recv(self.adj.recv_bytes) + except socket.error: + if self.adj.log_socket_errors: + self.logger.exception("Socket error") + self.handle_close() + return + if data: + self.last_activity = time.time() + self.received(data) + + def received(self, data): + """ + Receives input asynchronously and assigns one or more requests to the + channel. + """ + # Preconditions: there's no task(s) already running + request = self.request + requests = [] + + if not data: + return False + + while data: + if request is None: + request = self.parser_class(self.adj) + n = request.received(data) + if request.expect_continue and request.headers_finished: + # guaranteed by parser to be a 1.1 request + request.expect_continue = False + if not self.sent_continue: + # there's no current task, so we don't need to try to + # lock the outbuf to append to it. + outbuf_payload = b"HTTP/1.1 100 Continue\r\n\r\n" + self.outbufs[-1].append(outbuf_payload) + self.current_outbuf_count += len(outbuf_payload) + self.total_outbufs_len += len(outbuf_payload) + self.sent_continue = True + self._flush_some() + request.completed = False + if request.completed: + # The request (with the body) is ready to use. + self.request = None + if not request.empty: + requests.append(request) + request = None + else: + self.request = request + if n >= len(data): + break + data = data[n:] + + if requests: + self.requests = requests + self.server.add_task(self) + + return True + + def _flush_some_if_lockable(self): + # Since our task may be appending to the outbuf, we try to acquire + # the lock, but we don't block if we can't. + if self.outbuf_lock.acquire(False): + try: + self._flush_some() + + if self.total_outbufs_len < self.adj.outbuf_high_watermark: + self.outbuf_lock.notify() + finally: + self.outbuf_lock.release() + + def _flush_some(self): + # Send as much data as possible to our client + + sent = 0 + dobreak = False + + while True: + outbuf = self.outbufs[0] + # use outbuf.__len__ rather than len(outbuf) FBO of not getting + # OverflowError on 32-bit Python + outbuflen = outbuf.__len__() + while outbuflen > 0: + chunk = outbuf.get(self.sendbuf_len) + num_sent = self.send(chunk) + if num_sent: + outbuf.skip(num_sent, True) + outbuflen -= num_sent + sent += num_sent + self.total_outbufs_len -= num_sent + else: + # failed to write anything, break out entirely + dobreak = True + break + else: + # self.outbufs[-1] must always be a writable outbuf + if len(self.outbufs) > 1: + toclose = self.outbufs.pop(0) + try: + toclose.close() + except Exception: + self.logger.exception("Unexpected error when closing an outbuf") + else: + # caught up, done flushing for now + dobreak = True + + if dobreak: + break + + if sent: + self.last_activity = time.time() + return True + + return False + + def handle_close(self): + with self.outbuf_lock: + for outbuf in self.outbufs: + try: + outbuf.close() + except Exception: + self.logger.exception( + "Unknown exception while trying to close outbuf" + ) + self.total_outbufs_len = 0 + self.connected = False + self.outbuf_lock.notify() + wasyncore.dispatcher.close(self) + + def add_channel(self, map=None): + """See wasyncore.dispatcher + + This hook keeps track of opened channels. + """ + wasyncore.dispatcher.add_channel(self, map) + self.server.active_channels[self._fileno] = self + + def del_channel(self, map=None): + """See wasyncore.dispatcher + + This hook keeps track of closed channels. + """ + fd = self._fileno # next line sets this to None + wasyncore.dispatcher.del_channel(self, map) + ac = self.server.active_channels + if fd in ac: + del ac[fd] + + # + # SYNCHRONOUS METHODS + # + + def write_soon(self, data): + if not self.connected: + # if the socket is closed then interrupt the task so that it + # can cleanup possibly before the app_iter is exhausted + raise ClientDisconnected + if data: + # the async mainloop might be popping data off outbuf; we can + # block here waiting for it because we're in a task thread + with self.outbuf_lock: + self._flush_outbufs_below_high_watermark() + if not self.connected: + raise ClientDisconnected + num_bytes = len(data) + if data.__class__ is ReadOnlyFileBasedBuffer: + # they used wsgi.file_wrapper + self.outbufs.append(data) + nextbuf = OverflowableBuffer(self.adj.outbuf_overflow) + self.outbufs.append(nextbuf) + self.current_outbuf_count = 0 + else: + if self.current_outbuf_count > self.adj.outbuf_high_watermark: + # rotate to a new buffer if the current buffer has hit + # the watermark to avoid it growing unbounded + nextbuf = OverflowableBuffer(self.adj.outbuf_overflow) + self.outbufs.append(nextbuf) + self.current_outbuf_count = 0 + self.outbufs[-1].append(data) + self.current_outbuf_count += num_bytes + self.total_outbufs_len += num_bytes + if self.total_outbufs_len >= self.adj.send_bytes: + self.server.pull_trigger() + return num_bytes + return 0 + + def _flush_outbufs_below_high_watermark(self): + # check first to avoid locking if possible + if self.total_outbufs_len > self.adj.outbuf_high_watermark: + with self.outbuf_lock: + while ( + self.connected + and self.total_outbufs_len > self.adj.outbuf_high_watermark + ): + self.server.pull_trigger() + self.outbuf_lock.wait() + + def service(self): + """Execute all pending requests """ + with self.task_lock: + while self.requests: + request = self.requests[0] + if request.error: + task = self.error_task_class(self, request) + else: + task = self.task_class(self, request) + try: + task.service() + except ClientDisconnected: + self.logger.info( + "Client disconnected while serving %s" % task.request.path + ) + task.close_on_finish = True + except Exception: + self.logger.exception( + "Exception while serving %s" % task.request.path + ) + if not task.wrote_header: + if self.adj.expose_tracebacks: + body = traceback.format_exc() + else: + body = ( + "The server encountered an unexpected " + "internal server error" + ) + req_version = request.version + req_headers = request.headers + request = self.parser_class(self.adj) + request.error = InternalServerError(body) + # copy some original request attributes to fulfill + # HTTP 1.1 requirements + request.version = req_version + try: + request.headers["CONNECTION"] = req_headers["CONNECTION"] + except KeyError: + pass + task = self.error_task_class(self, request) + try: + task.service() # must not fail + except ClientDisconnected: + task.close_on_finish = True + else: + task.close_on_finish = True + # we cannot allow self.requests to drop to empty til + # here; otherwise the mainloop gets confused + if task.close_on_finish: + self.close_when_flushed = True + for request in self.requests: + request.close() + self.requests = [] + else: + # before processing a new request, ensure there is not too + # much data in the outbufs waiting to be flushed + # NB: currently readable() returns False while we are + # flushing data so we know no new requests will come in + # that we need to account for, otherwise it'd be better + # to do this check at the start of the request instead of + # at the end to account for consecutive service() calls + if len(self.requests) > 1: + self._flush_outbufs_below_high_watermark() + request = self.requests.pop(0) + request.close() + + if self.connected: + self.server.pull_trigger() + self.last_activity = time.time() + + def cancel(self): + """ Cancels all pending / active requests """ + self.will_close = True + self.connected = False + self.last_activity = time.time() + self.requests = [] diff --git a/src/waitress/compat.py b/src/waitress/compat.py new file mode 100644 index 0000000..fe72a76 --- /dev/null +++ b/src/waitress/compat.py @@ -0,0 +1,179 @@ +import os +import sys +import types +import platform +import warnings + +try: + import urlparse +except ImportError: # pragma: no cover + from urllib import parse as urlparse + +try: + import fcntl +except ImportError: # pragma: no cover + fcntl = None # windows + +# True if we are running on Python 3. +PY2 = sys.version_info[0] == 2 +PY3 = sys.version_info[0] == 3 + +# True if we are running on Windows +WIN = platform.system() == "Windows" + +if PY3: # pragma: no cover + string_types = (str,) + integer_types = (int,) + class_types = (type,) + text_type = str + binary_type = bytes + long = int +else: + string_types = (basestring,) + integer_types = (int, long) + class_types = (type, types.ClassType) + text_type = unicode + binary_type = str + long = long + +if PY3: # pragma: no cover + from urllib.parse import unquote_to_bytes + + def unquote_bytes_to_wsgi(bytestring): + return unquote_to_bytes(bytestring).decode("latin-1") + + +else: + from urlparse import unquote as unquote_to_bytes + + def unquote_bytes_to_wsgi(bytestring): + return unquote_to_bytes(bytestring) + + +def text_(s, encoding="latin-1", errors="strict"): + """ If ``s`` is an instance of ``binary_type``, return + ``s.decode(encoding, errors)``, otherwise return ``s``""" + if isinstance(s, binary_type): + return s.decode(encoding, errors) + return s # pragma: no cover + + +if PY3: # pragma: no cover + + def tostr(s): + if isinstance(s, text_type): + s = s.encode("latin-1") + return str(s, "latin-1", "strict") + + def tobytes(s): + return bytes(s, "latin-1") + + +else: + tostr = str + + def tobytes(s): + return s + + +if PY3: # pragma: no cover + import builtins + + exec_ = getattr(builtins, "exec") + + def reraise(tp, value, tb=None): + if value is None: + value = tp + if value.__traceback__ is not tb: + raise value.with_traceback(tb) + raise value + + del builtins + +else: # pragma: no cover + + def exec_(code, globs=None, locs=None): + """Execute code in a namespace.""" + if globs is None: + frame = sys._getframe(1) + globs = frame.f_globals + if locs is None: + locs = frame.f_locals + del frame + elif locs is None: + locs = globs + exec("""exec code in globs, locs""") + + exec_( + """def reraise(tp, value, tb=None): + raise tp, value, tb +""" + ) + +try: + from StringIO import StringIO as NativeIO +except ImportError: # pragma: no cover + from io import StringIO as NativeIO + +try: + import httplib +except ImportError: # pragma: no cover + from http import client as httplib + +try: + MAXINT = sys.maxint +except AttributeError: # pragma: no cover + MAXINT = sys.maxsize + + +# Fix for issue reported in https://github.com/Pylons/waitress/issues/138, +# Python on Windows may not define IPPROTO_IPV6 in socket. +import socket + +HAS_IPV6 = socket.has_ipv6 + +if hasattr(socket, "IPPROTO_IPV6") and hasattr(socket, "IPV6_V6ONLY"): + IPPROTO_IPV6 = socket.IPPROTO_IPV6 + IPV6_V6ONLY = socket.IPV6_V6ONLY +else: # pragma: no cover + if WIN: + IPPROTO_IPV6 = 41 + IPV6_V6ONLY = 27 + else: + warnings.warn( + "OS does not support required IPv6 socket flags. This is requirement " + "for Waitress. Please open an issue at https://github.com/Pylons/waitress. " + "IPv6 support has been disabled.", + RuntimeWarning, + ) + HAS_IPV6 = False + + +def set_nonblocking(fd): # pragma: no cover + if PY3 and sys.version_info[1] >= 5: + os.set_blocking(fd, False) + elif fcntl is None: + raise RuntimeError("no fcntl module present") + else: + flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0) + flags = flags | os.O_NONBLOCK + fcntl.fcntl(fd, fcntl.F_SETFL, flags) + + +if PY3: + ResourceWarning = ResourceWarning +else: + ResourceWarning = UserWarning + + +def qualname(cls): + if PY3: + return cls.__qualname__ + return cls.__name__ + + +try: + import thread +except ImportError: + # py3 + import _thread as thread diff --git a/src/waitress/parser.py b/src/waitress/parser.py new file mode 100644 index 0000000..53072b5 --- /dev/null +++ b/src/waitress/parser.py @@ -0,0 +1,413 @@ +############################################################################## +# +# Copyright (c) 2001, 2002 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## +"""HTTP Request Parser + +This server uses asyncore to accept connections and do initial +processing but threads to do work. +""" +import re +from io import BytesIO + +from waitress.buffers import OverflowableBuffer +from waitress.compat import tostr, unquote_bytes_to_wsgi, urlparse +from waitress.receiver import ChunkedReceiver, FixedStreamReceiver +from waitress.utilities import ( + BadRequest, + RequestEntityTooLarge, + RequestHeaderFieldsTooLarge, + ServerNotImplemented, + find_double_newline, +) +from .rfc7230 import HEADER_FIELD + + +class ParsingError(Exception): + pass + + +class TransferEncodingNotImplemented(Exception): + pass + +class HTTPRequestParser(object): + """A structure that collects the HTTP request. + + Once the stream is completed, the instance is passed to + a server task constructor. + """ + + completed = False # Set once request is completed. + empty = False # Set if no request was made. + expect_continue = False # client sent "Expect: 100-continue" header + headers_finished = False # True when headers have been read + header_plus = b"" + chunked = False + content_length = 0 + header_bytes_received = 0 + body_bytes_received = 0 + body_rcv = None + version = "1.0" + error = None + connection_close = False + + # Other attributes: first_line, header, headers, command, uri, version, + # path, query, fragment + + def __init__(self, adj): + """ + adj is an Adjustments object. + """ + # headers is a mapping containing keys translated to uppercase + # with dashes turned into underscores. + self.headers = {} + self.adj = adj + + def received(self, data): + """ + Receives the HTTP stream for one request. Returns the number of + bytes consumed. Sets the completed flag once both the header and the + body have been received. + """ + if self.completed: + return 0 # Can't consume any more. + + datalen = len(data) + br = self.body_rcv + if br is None: + # In header. + max_header = self.adj.max_request_header_size + + s = self.header_plus + data + index = find_double_newline(s) + consumed = 0 + + if index >= 0: + # If the headers have ended, and we also have part of the body + # message in data we still want to validate we aren't going + # over our limit for received headers. + self.header_bytes_received += index + consumed = datalen - (len(s) - index) + else: + self.header_bytes_received += datalen + consumed = datalen + + # If the first line + headers is over the max length, we return a + # RequestHeaderFieldsTooLarge error rather than continuing to + # attempt to parse the headers. + if self.header_bytes_received >= max_header: + self.parse_header(b"GET / HTTP/1.0\r\n") + self.error = RequestHeaderFieldsTooLarge( + "exceeds max_header of %s" % max_header + ) + self.completed = True + return consumed + + if index >= 0: + # Header finished. + header_plus = s[:index] + + # Remove preceeding blank lines. This is suggested by + # https://tools.ietf.org/html/rfc7230#section-3.5 to support + # clients sending an extra CR LF after another request when + # using HTTP pipelining + header_plus = header_plus.lstrip() + + if not header_plus: + self.empty = True + self.completed = True + else: + try: + self.parse_header(header_plus) + except ParsingError as e: + self.error = BadRequest(e.args[0]) + self.completed = True + except TransferEncodingNotImplemented as e: + self.error = ServerNotImplemented(e.args[0]) + self.completed = True + else: + if self.body_rcv is None: + # no content-length header and not a t-e: chunked + # request + self.completed = True + + if self.content_length > 0: + max_body = self.adj.max_request_body_size + # we won't accept this request if the content-length + # is too large + + if self.content_length >= max_body: + self.error = RequestEntityTooLarge( + "exceeds max_body of %s" % max_body + ) + self.completed = True + self.headers_finished = True + + return consumed + + # Header not finished yet. + self.header_plus = s + + return datalen + else: + # In body. + consumed = br.received(data) + self.body_bytes_received += consumed + max_body = self.adj.max_request_body_size + + if self.body_bytes_received >= max_body: + # this will only be raised during t-e: chunked requests + self.error = RequestEntityTooLarge("exceeds max_body of %s" % max_body) + self.completed = True + elif br.error: + # garbage in chunked encoding input probably + self.error = br.error + self.completed = True + elif br.completed: + # The request (with the body) is ready to use. + self.completed = True + + if self.chunked: + # We've converted the chunked transfer encoding request + # body into a normal request body, so we know its content + # length; set the header here. We already popped the + # TRANSFER_ENCODING header in parse_header, so this will + # appear to the client to be an entirely non-chunked HTTP + # request with a valid content-length. + self.headers["CONTENT_LENGTH"] = str(br.__len__()) + + return consumed + + def parse_header(self, header_plus): + """ + Parses the header_plus block of text (the headers plus the + first line of the request). + """ + index = header_plus.find(b"\r\n") + if index >= 0: + first_line = header_plus[:index].rstrip() + header = header_plus[index + 2 :] + else: + raise ParsingError("HTTP message header invalid") + + if b"\r" in first_line or b"\n" in first_line: + raise ParsingError("Bare CR or LF found in HTTP message") + + self.first_line = first_line # for testing + + lines = get_header_lines(header) + + headers = self.headers + for line in lines: + header = HEADER_FIELD.match(line) + + if not header: + raise ParsingError("Invalid header") + + key, value = header.group("name", "value") + + if b"_" in key: + # TODO(xistence): Should we drop this request instead? + continue + + # Only strip off whitespace that is considered valid whitespace by + # RFC7230, don't strip the rest + value = value.strip(b" \t") + key1 = tostr(key.upper().replace(b"-", b"_")) + # If a header already exists, we append subsequent values + # separated by a comma. Applications already need to handle + # the comma separated values, as HTTP front ends might do + # the concatenation for you (behavior specified in RFC2616). + try: + headers[key1] += tostr(b", " + value) + except KeyError: + headers[key1] = tostr(value) + + # command, uri, version will be bytes + command, uri, version = crack_first_line(first_line) + version = tostr(version) + command = tostr(command) + self.command = command + self.version = version + ( + self.proxy_scheme, + self.proxy_netloc, + self.path, + self.query, + self.fragment, + ) = split_uri(uri) + self.url_scheme = self.adj.url_scheme + connection = headers.get("CONNECTION", "") + + if version == "1.0": + if connection.lower() != "keep-alive": + self.connection_close = True + + if version == "1.1": + # since the server buffers data from chunked transfers and clients + # never need to deal with chunked requests, downstream clients + # should not see the HTTP_TRANSFER_ENCODING header; we pop it + # here + te = headers.pop("TRANSFER_ENCODING", "") + + # NB: We can not just call bare strip() here because it will also + # remove other non-printable characters that we explicitly do not + # want removed so that if someone attempts to smuggle a request + # with these characters we don't fall prey to it. + # + # For example \x85 is stripped by default, but it is not considered + # valid whitespace to be stripped by RFC7230. + encodings = [ + encoding.strip(" \t").lower() for encoding in te.split(",") if encoding + ] + + for encoding in encodings: + # Out of the transfer-codings listed in + # https://tools.ietf.org/html/rfc7230#section-4 we only support + # chunked at this time. + + # Note: the identity transfer-coding was removed in RFC7230: + # https://tools.ietf.org/html/rfc7230#appendix-A.2 and is thus + # not supported + if encoding not in {"chunked"}: + raise TransferEncodingNotImplemented( + "Transfer-Encoding requested is not supported." + ) + + if encodings and encodings[-1] == "chunked": + self.chunked = True + buf = OverflowableBuffer(self.adj.inbuf_overflow) + self.body_rcv = ChunkedReceiver(buf) + elif encodings: # pragma: nocover + raise TransferEncodingNotImplemented( + "Transfer-Encoding requested is not supported." + ) + + expect = headers.get("EXPECT", "").lower() + self.expect_continue = expect == "100-continue" + if connection.lower() == "close": + self.connection_close = True + + if not self.chunked: + try: + cl = int(headers.get("CONTENT_LENGTH", 0)) + except ValueError: + raise ParsingError("Content-Length is invalid") + + self.content_length = cl + if cl > 0: + buf = OverflowableBuffer(self.adj.inbuf_overflow) + self.body_rcv = FixedStreamReceiver(cl, buf) + + def get_body_stream(self): + body_rcv = self.body_rcv + if body_rcv is not None: + return body_rcv.getfile() + else: + return BytesIO() + + def close(self): + body_rcv = self.body_rcv + if body_rcv is not None: + body_rcv.getbuf().close() + + +def split_uri(uri): + # urlsplit handles byte input by returning bytes on py3, so + # scheme, netloc, path, query, and fragment are bytes + + scheme = netloc = path = query = fragment = b"" + + # urlsplit below will treat this as a scheme-less netloc, thereby losing + # the original intent of the request. Here we shamelessly stole 4 lines of + # code from the CPython stdlib to parse out the fragment and query but + # leave the path alone. See + # https://github.com/python/cpython/blob/8c9e9b0cd5b24dfbf1424d1f253d02de80e8f5ef/Lib/urllib/parse.py#L465-L468 + # and https://github.com/Pylons/waitress/issues/260 + + if uri[:2] == b"//": + path = uri + + if b"#" in path: + path, fragment = path.split(b"#", 1) + + if b"?" in path: + path, query = path.split(b"?", 1) + else: + try: + scheme, netloc, path, query, fragment = urlparse.urlsplit(uri) + except UnicodeError: + raise ParsingError("Bad URI") + + return ( + tostr(scheme), + tostr(netloc), + unquote_bytes_to_wsgi(path), + tostr(query), + tostr(fragment), + ) + + +def get_header_lines(header): + """ + Splits the header into lines, putting multi-line headers together. + """ + r = [] + lines = header.split(b"\r\n") + for line in lines: + if not line: + continue + + if b"\r" in line or b"\n" in line: + raise ParsingError('Bare CR or LF found in header line "%s"' % tostr(line)) + + if line.startswith((b" ", b"\t")): + if not r: + # https://corte.si/posts/code/pathod/pythonservers/index.html + raise ParsingError('Malformed header line "%s"' % tostr(line)) + r[-1] += line + else: + r.append(line) + return r + + +first_line_re = re.compile( + b"([^ ]+) " + b"((?:[^ :?#]+://[^ ?#/]*(?:[0-9]{1,5})?)?[^ ]+)" + b"(( HTTP/([0-9.]+))$|$)" +) + + +def crack_first_line(line): + m = first_line_re.match(line) + if m is not None and m.end() == len(line): + if m.group(3): + version = m.group(5) + else: + version = b"" + method = m.group(1) + + # the request methods that are currently defined are all uppercase: + # https://www.iana.org/assignments/http-methods/http-methods.xhtml and + # the request method is case sensitive according to + # https://tools.ietf.org/html/rfc7231#section-4.1 + + # By disallowing anything but uppercase methods we save poor + # unsuspecting souls from sending lowercase HTTP methods to waitress + # and having the request complete, while servers like nginx drop the + # request onto the floor. + if method != method.upper(): + raise ParsingError('Malformed HTTP method "%s"' % tostr(method)) + uri = m.group(2) + return method, uri, version + else: + return b"", b"", b"" diff --git a/src/waitress/proxy_headers.py b/src/waitress/proxy_headers.py new file mode 100644 index 0000000..1df8b8e --- /dev/null +++ b/src/waitress/proxy_headers.py @@ -0,0 +1,333 @@ +from collections import namedtuple + +from .utilities import logger, undquote, BadRequest + + +PROXY_HEADERS = frozenset( + { + "X_FORWARDED_FOR", + "X_FORWARDED_HOST", + "X_FORWARDED_PROTO", + "X_FORWARDED_PORT", + "X_FORWARDED_BY", + "FORWARDED", + } +) + +Forwarded = namedtuple("Forwarded", ["by", "for_", "host", "proto"]) + + +class MalformedProxyHeader(Exception): + def __init__(self, header, reason, value): + self.header = header + self.reason = reason + self.value = value + super(MalformedProxyHeader, self).__init__(header, reason, value) + + +def proxy_headers_middleware( + app, + trusted_proxy=None, + trusted_proxy_count=1, + trusted_proxy_headers=None, + clear_untrusted=True, + log_untrusted=False, + logger=logger, +): + def translate_proxy_headers(environ, start_response): + untrusted_headers = PROXY_HEADERS + remote_peer = environ["REMOTE_ADDR"] + if trusted_proxy == "*" or remote_peer == trusted_proxy: + try: + untrusted_headers = parse_proxy_headers( + environ, + trusted_proxy_count=trusted_proxy_count, + trusted_proxy_headers=trusted_proxy_headers, + logger=logger, + ) + except MalformedProxyHeader as ex: + logger.warning( + 'Malformed proxy header "%s" from "%s": %s value: %s', + ex.header, + remote_peer, + ex.reason, + ex.value, + ) + error = BadRequest('Header "{0}" malformed.'.format(ex.header)) + return error.wsgi_response(environ, start_response) + + # Clear out the untrusted proxy headers + if clear_untrusted: + clear_untrusted_headers( + environ, untrusted_headers, log_warning=log_untrusted, logger=logger, + ) + + return app(environ, start_response) + + return translate_proxy_headers + + +def parse_proxy_headers( + environ, trusted_proxy_count, trusted_proxy_headers, logger=logger, +): + if trusted_proxy_headers is None: + trusted_proxy_headers = set() + + forwarded_for = [] + forwarded_host = forwarded_proto = forwarded_port = forwarded = "" + client_addr = None + untrusted_headers = set(PROXY_HEADERS) + + def raise_for_multiple_values(): + raise ValueError("Unspecified behavior for multiple values found in header",) + + if "x-forwarded-for" in trusted_proxy_headers and "HTTP_X_FORWARDED_FOR" in environ: + try: + forwarded_for = [] + + for forward_hop in environ["HTTP_X_FORWARDED_FOR"].split(","): + forward_hop = forward_hop.strip() + forward_hop = undquote(forward_hop) + + # Make sure that all IPv6 addresses are surrounded by brackets, + # this is assuming that the IPv6 representation here does not + # include a port number. + + if "." not in forward_hop and ( + ":" in forward_hop and forward_hop[-1] != "]" + ): + forwarded_for.append("[{}]".format(forward_hop)) + else: + forwarded_for.append(forward_hop) + + forwarded_for = forwarded_for[-trusted_proxy_count:] + client_addr = forwarded_for[0] + + untrusted_headers.remove("X_FORWARDED_FOR") + except Exception as ex: + raise MalformedProxyHeader( + "X-Forwarded-For", str(ex), environ["HTTP_X_FORWARDED_FOR"], + ) + + if ( + "x-forwarded-host" in trusted_proxy_headers + and "HTTP_X_FORWARDED_HOST" in environ + ): + try: + forwarded_host_multiple = [] + + for forward_host in environ["HTTP_X_FORWARDED_HOST"].split(","): + forward_host = forward_host.strip() + forward_host = undquote(forward_host) + forwarded_host_multiple.append(forward_host) + + forwarded_host_multiple = forwarded_host_multiple[-trusted_proxy_count:] + forwarded_host = forwarded_host_multiple[0] + + untrusted_headers.remove("X_FORWARDED_HOST") + except Exception as ex: + raise MalformedProxyHeader( + "X-Forwarded-Host", str(ex), environ["HTTP_X_FORWARDED_HOST"], + ) + + if "x-forwarded-proto" in trusted_proxy_headers: + try: + forwarded_proto = undquote(environ.get("HTTP_X_FORWARDED_PROTO", "")) + if "," in forwarded_proto: + raise_for_multiple_values() + untrusted_headers.remove("X_FORWARDED_PROTO") + except Exception as ex: + raise MalformedProxyHeader( + "X-Forwarded-Proto", str(ex), environ["HTTP_X_FORWARDED_PROTO"], + ) + + if "x-forwarded-port" in trusted_proxy_headers: + try: + forwarded_port = undquote(environ.get("HTTP_X_FORWARDED_PORT", "")) + if "," in forwarded_port: + raise_for_multiple_values() + untrusted_headers.remove("X_FORWARDED_PORT") + except Exception as ex: + raise MalformedProxyHeader( + "X-Forwarded-Port", str(ex), environ["HTTP_X_FORWARDED_PORT"], + ) + + if "x-forwarded-by" in trusted_proxy_headers: + # Waitress itself does not use X-Forwarded-By, but we can not + # remove it so it can get set in the environ + untrusted_headers.remove("X_FORWARDED_BY") + + if "forwarded" in trusted_proxy_headers: + forwarded = environ.get("HTTP_FORWARDED", None) + untrusted_headers = PROXY_HEADERS - {"FORWARDED"} + + # If the Forwarded header exists, it gets priority + if forwarded: + proxies = [] + try: + for forwarded_element in forwarded.split(","): + # Remove whitespace that may have been introduced when + # appending a new entry + forwarded_element = forwarded_element.strip() + + forwarded_for = forwarded_host = forwarded_proto = "" + forwarded_port = forwarded_by = "" + + for pair in forwarded_element.split(";"): + pair = pair.lower() + + if not pair: + continue + + token, equals, value = pair.partition("=") + + if equals != "=": + raise ValueError('Invalid forwarded-pair missing "="') + + if token.strip() != token: + raise ValueError("Token may not be surrounded by whitespace") + + if value.strip() != value: + raise ValueError("Value may not be surrounded by whitespace") + + if token == "by": + forwarded_by = undquote(value) + + elif token == "for": + forwarded_for = undquote(value) + + elif token == "host": + forwarded_host = undquote(value) + + elif token == "proto": + forwarded_proto = undquote(value) + + else: + logger.warning("Unknown Forwarded token: %s" % token) + + proxies.append( + Forwarded( + forwarded_by, forwarded_for, forwarded_host, forwarded_proto + ) + ) + except Exception as ex: + raise MalformedProxyHeader( + "Forwarded", str(ex), environ["HTTP_FORWARDED"], + ) + + proxies = proxies[-trusted_proxy_count:] + + # Iterate backwards and fill in some values, the oldest entry that + # contains the information we expect is the one we use. We expect + # that intermediate proxies may re-write the host header or proto, + # but the oldest entry is the one that contains the information the + # client expects when generating URL's + # + # Forwarded: for="[2001:db8::1]";host="example.com:8443";proto="https" + # Forwarded: for=192.0.2.1;host="example.internal:8080" + # + # (After HTTPS header folding) should mean that we use as values: + # + # Host: example.com + # Protocol: https + # Port: 8443 + + for proxy in proxies[::-1]: + client_addr = proxy.for_ or client_addr + forwarded_host = proxy.host or forwarded_host + forwarded_proto = proxy.proto or forwarded_proto + + if forwarded_proto: + forwarded_proto = forwarded_proto.lower() + + if forwarded_proto not in {"http", "https"}: + raise MalformedProxyHeader( + "Forwarded Proto=" if forwarded else "X-Forwarded-Proto", + "unsupported proto value", + forwarded_proto, + ) + + # Set the URL scheme to the proxy provided proto + environ["wsgi.url_scheme"] = forwarded_proto + + if not forwarded_port: + if forwarded_proto == "http": + forwarded_port = "80" + + if forwarded_proto == "https": + forwarded_port = "443" + + if forwarded_host: + if ":" in forwarded_host and forwarded_host[-1] != "]": + host, port = forwarded_host.rsplit(":", 1) + host, port = host.strip(), str(port) + + # We trust the port in the Forwarded Host/X-Forwarded-Host over + # X-Forwarded-Port, or whatever we got from Forwarded + # Proto/X-Forwarded-Proto. + + if forwarded_port != port: + forwarded_port = port + + # We trust the proxy server's forwarded Host + environ["SERVER_NAME"] = host + environ["HTTP_HOST"] = forwarded_host + else: + # We trust the proxy server's forwarded Host + environ["SERVER_NAME"] = forwarded_host + environ["HTTP_HOST"] = forwarded_host + + if forwarded_port: + if forwarded_port not in {"443", "80"}: + environ["HTTP_HOST"] = "{}:{}".format( + forwarded_host, forwarded_port + ) + elif forwarded_port == "80" and environ["wsgi.url_scheme"] != "http": + environ["HTTP_HOST"] = "{}:{}".format( + forwarded_host, forwarded_port + ) + elif forwarded_port == "443" and environ["wsgi.url_scheme"] != "https": + environ["HTTP_HOST"] = "{}:{}".format( + forwarded_host, forwarded_port + ) + + if forwarded_port: + environ["SERVER_PORT"] = str(forwarded_port) + + if client_addr: + if ":" in client_addr and client_addr[-1] != "]": + addr, port = client_addr.rsplit(":", 1) + environ["REMOTE_ADDR"] = strip_brackets(addr.strip()) + environ["REMOTE_PORT"] = port.strip() + else: + environ["REMOTE_ADDR"] = strip_brackets(client_addr.strip()) + environ["REMOTE_HOST"] = environ["REMOTE_ADDR"] + + return untrusted_headers + + +def strip_brackets(addr): + if addr[0] == "[" and addr[-1] == "]": + return addr[1:-1] + return addr + + +def clear_untrusted_headers( + environ, untrusted_headers, log_warning=False, logger=logger +): + untrusted_headers_removed = [ + header + for header in untrusted_headers + if environ.pop("HTTP_" + header, False) is not False + ] + + if log_warning and untrusted_headers_removed: + untrusted_headers_removed = [ + "-".join(x.capitalize() for x in header.split("_")) + for header in untrusted_headers_removed + ] + logger.warning( + "Removed untrusted headers (%s). Waitress recommends these be " + "removed upstream.", + ", ".join(untrusted_headers_removed), + ) diff --git a/src/waitress/receiver.py b/src/waitress/receiver.py new file mode 100644 index 0000000..5d1568d --- /dev/null +++ b/src/waitress/receiver.py @@ -0,0 +1,186 @@ +############################################################################## +# +# Copyright (c) 2001, 2002 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## +"""Data Chunk Receiver +""" + +from waitress.utilities import BadRequest, find_double_newline + + +class FixedStreamReceiver(object): + + # See IStreamConsumer + completed = False + error = None + + def __init__(self, cl, buf): + self.remain = cl + self.buf = buf + + def __len__(self): + return self.buf.__len__() + + def received(self, data): + "See IStreamConsumer" + rm = self.remain + + if rm < 1: + self.completed = True # Avoid any chance of spinning + + return 0 + datalen = len(data) + + if rm <= datalen: + self.buf.append(data[:rm]) + self.remain = 0 + self.completed = True + + return rm + else: + self.buf.append(data) + self.remain -= datalen + + return datalen + + def getfile(self): + return self.buf.getfile() + + def getbuf(self): + return self.buf + + +class ChunkedReceiver(object): + + chunk_remainder = 0 + validate_chunk_end = False + control_line = b"" + chunk_end = b"" + all_chunks_received = False + trailer = b"" + completed = False + error = None + + # max_control_line = 1024 + # max_trailer = 65536 + + def __init__(self, buf): + self.buf = buf + + def __len__(self): + return self.buf.__len__() + + def received(self, s): + # Returns the number of bytes consumed. + + if self.completed: + return 0 + orig_size = len(s) + + while s: + rm = self.chunk_remainder + + if rm > 0: + # Receive the remainder of a chunk. + to_write = s[:rm] + self.buf.append(to_write) + written = len(to_write) + s = s[written:] + + self.chunk_remainder -= written + + if self.chunk_remainder == 0: + self.validate_chunk_end = True + elif self.validate_chunk_end: + s = self.chunk_end + s + + pos = s.find(b"\r\n") + + if pos < 0 and len(s) < 2: + self.chunk_end = s + s = b"" + else: + self.chunk_end = b"" + if pos == 0: + # Chop off the terminating CR LF from the chunk + s = s[2:] + else: + self.error = BadRequest("Chunk not properly terminated") + self.all_chunks_received = True + + # Always exit this loop + self.validate_chunk_end = False + elif not self.all_chunks_received: + # Receive a control line. + s = self.control_line + s + pos = s.find(b"\r\n") + + if pos < 0: + # Control line not finished. + self.control_line = s + s = b"" + else: + # Control line finished. + line = s[:pos] + s = s[pos + 2 :] + self.control_line = b"" + line = line.strip() + + if line: + # Begin a new chunk. + semi = line.find(b";") + + if semi >= 0: + # discard extension info. + line = line[:semi] + try: + sz = int(line.strip(), 16) # hexadecimal + except ValueError: # garbage in input + self.error = BadRequest("garbage in chunked encoding input") + sz = 0 + + if sz > 0: + # Start a new chunk. + self.chunk_remainder = sz + else: + # Finished chunks. + self.all_chunks_received = True + # else expect a control line. + else: + # Receive the trailer. + trailer = self.trailer + s + + if trailer.startswith(b"\r\n"): + # No trailer. + self.completed = True + + return orig_size - (len(trailer) - 2) + pos = find_double_newline(trailer) + + if pos < 0: + # Trailer not finished. + self.trailer = trailer + s = b"" + else: + # Finished the trailer. + self.completed = True + self.trailer = trailer[:pos] + + return orig_size - (len(trailer) - pos) + + return orig_size + + def getfile(self): + return self.buf.getfile() + + def getbuf(self): + return self.buf diff --git a/src/waitress/rfc7230.py b/src/waitress/rfc7230.py new file mode 100644 index 0000000..cd33c90 --- /dev/null +++ b/src/waitress/rfc7230.py @@ -0,0 +1,52 @@ +""" +This contains a bunch of RFC7230 definitions and regular expressions that are +needed to properly parse HTTP messages. +""" + +import re + +from .compat import tobytes + +WS = "[ \t]" +OWS = WS + "{0,}?" +RWS = WS + "{1,}?" +BWS = OWS + +# RFC 7230 Section 3.2.6 "Field Value Components": +# tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" +# / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~" +# / DIGIT / ALPHA +# obs-text = %x80-FF +TCHAR = r"[!#$%&'*+\-.^_`|~0-9A-Za-z]" +OBS_TEXT = r"\x80-\xff" + +TOKEN = TCHAR + "{1,}" + +# RFC 5234 Appendix B.1 "Core Rules": +# VCHAR = %x21-7E +# ; visible (printing) characters +VCHAR = r"\x21-\x7e" + +# header-field = field-name ":" OWS field-value OWS +# field-name = token +# field-value = *( field-content / obs-fold ) +# field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ] +# field-vchar = VCHAR / obs-text + +# Errata from: https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189 +# changes field-content to: +# +# field-content = field-vchar [ 1*( SP / HTAB / field-vchar ) +# field-vchar ] + +FIELD_VCHAR = "[" + VCHAR + OBS_TEXT + "]" +# Field content is more greedy than the ABNF, in that it will match the whole value +FIELD_CONTENT = FIELD_VCHAR + "+(?:[ \t]+" + FIELD_VCHAR + "+)*" +# Which allows the field value here to just see if there is even a value in the first place +FIELD_VALUE = "(?:" + FIELD_CONTENT + ")?" + +HEADER_FIELD = re.compile( + tobytes( + "^(?P" + TOKEN + "):" + OWS + "(?P" + FIELD_VALUE + ")" + OWS + "$" + ) +) diff --git a/src/waitress/runner.py b/src/waitress/runner.py new file mode 100644 index 0000000..2495084 --- /dev/null +++ b/src/waitress/runner.py @@ -0,0 +1,286 @@ +############################################################################## +# +# Copyright (c) 2013 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## +"""Command line runner. +""" + +from __future__ import print_function, unicode_literals + +import getopt +import os +import os.path +import re +import sys + +from waitress import serve +from waitress.adjustments import Adjustments + +HELP = """\ +Usage: + + {0} [OPTS] MODULE:OBJECT + +Standard options: + + --help + Show this information. + + --call + Call the given object to get the WSGI application. + + --host=ADDR + Hostname or IP address on which to listen, default is '0.0.0.0', + which means "all IP addresses on this host". + + Note: May not be used together with --listen + + --port=PORT + TCP port on which to listen, default is '8080' + + Note: May not be used together with --listen + + --listen=ip:port + Tell waitress to listen on an ip port combination. + + Example: + + --listen=127.0.0.1:8080 + --listen=[::1]:8080 + --listen=*:8080 + + This option may be used multiple times to listen on multiple sockets. + A wildcard for the hostname is also supported and will bind to both + IPv4/IPv6 depending on whether they are enabled or disabled. + + --[no-]ipv4 + Toggle on/off IPv4 support. + + Example: + + --no-ipv4 + + This will disable IPv4 socket support. This affects wildcard matching + when generating the list of sockets. + + --[no-]ipv6 + Toggle on/off IPv6 support. + + Example: + + --no-ipv6 + + This will turn on IPv6 socket support. This affects wildcard matching + when generating a list of sockets. + + --unix-socket=PATH + Path of Unix socket. If a socket path is specified, a Unix domain + socket is made instead of the usual inet domain socket. + + Not available on Windows. + + --unix-socket-perms=PERMS + Octal permissions to use for the Unix domain socket, default is + '600'. + + --url-scheme=STR + Default wsgi.url_scheme value, default is 'http'. + + --url-prefix=STR + The ``SCRIPT_NAME`` WSGI environment value. Setting this to anything + except the empty string will cause the WSGI ``SCRIPT_NAME`` value to be + the value passed minus any trailing slashes you add, and it will cause + the ``PATH_INFO`` of any request which is prefixed with this value to + be stripped of the prefix. Default is the empty string. + + --ident=STR + Server identity used in the 'Server' header in responses. Default + is 'waitress'. + +Tuning options: + + --threads=INT + Number of threads used to process application logic, default is 4. + + --backlog=INT + Connection backlog for the server. Default is 1024. + + --recv-bytes=INT + Number of bytes to request when calling socket.recv(). Default is + 8192. + + --send-bytes=INT + Number of bytes to send to socket.send(). Default is 18000. + Multiples of 9000 should avoid partly-filled TCP packets. + + --outbuf-overflow=INT + A temporary file should be created if the pending output is larger + than this. Default is 1048576 (1MB). + + --outbuf-high-watermark=INT + The app_iter will pause when pending output is larger than this value + and will resume once enough data is written to the socket to fall below + this threshold. Default is 16777216 (16MB). + + --inbuf-overflow=INT + A temporary file should be created if the pending input is larger + than this. Default is 524288 (512KB). + + --connection-limit=INT + Stop creating new channels if too many are already active. + Default is 100. + + --cleanup-interval=INT + Minimum seconds between cleaning up inactive channels. Default + is 30. See '--channel-timeout'. + + --channel-timeout=INT + Maximum number of seconds to leave inactive connections open. + Default is 120. 'Inactive' is defined as 'has received no data + from the client and has sent no data to the client'. + + --[no-]log-socket-errors + Toggle whether premature client disconnect tracebacks ought to be + logged. On by default. + + --max-request-header-size=INT + Maximum size of all request headers combined. Default is 262144 + (256KB). + + --max-request-body-size=INT + Maximum size of request body. Default is 1073741824 (1GB). + + --[no-]expose-tracebacks + Toggle whether to expose tracebacks of unhandled exceptions to the + client. Off by default. + + --asyncore-loop-timeout=INT + The timeout value in seconds passed to asyncore.loop(). Default is 1. + + --asyncore-use-poll + The use_poll argument passed to ``asyncore.loop()``. Helps overcome + open file descriptors limit. Default is False. + +""" + +RUNNER_PATTERN = re.compile( + r""" + ^ + (?P + [a-z_][a-z0-9_]*(?:\.[a-z_][a-z0-9_]*)* + ) + : + (?P + [a-z_][a-z0-9_]*(?:\.[a-z_][a-z0-9_]*)* + ) + $ + """, + re.I | re.X, +) + + +def match(obj_name): + matches = RUNNER_PATTERN.match(obj_name) + if not matches: + raise ValueError("Malformed application '{0}'".format(obj_name)) + return matches.group("module"), matches.group("object") + + +def resolve(module_name, object_name): + """Resolve a named object in a module.""" + # We cast each segments due to an issue that has been found to manifest + # in Python 2.6.6, but not 2.6.8, and may affect other revisions of Python + # 2.6 and 2.7, whereby ``__import__`` chokes if the list passed in the + # ``fromlist`` argument are unicode strings rather than 8-bit strings. + # The error triggered is "TypeError: Item in ``fromlist '' not a string". + # My guess is that this was fixed by checking against ``basestring`` + # rather than ``str`` sometime between the release of 2.6.6 and 2.6.8, + # but I've yet to go over the commits. I know, however, that the NEWS + # file makes no mention of such a change to the behaviour of + # ``__import__``. + segments = [str(segment) for segment in object_name.split(".")] + obj = __import__(module_name, fromlist=segments[:1]) + for segment in segments: + obj = getattr(obj, segment) + return obj + + +def show_help(stream, name, error=None): # pragma: no cover + if error is not None: + print("Error: {0}\n".format(error), file=stream) + print(HELP.format(name), file=stream) + + +def show_exception(stream): + exc_type, exc_value = sys.exc_info()[:2] + args = getattr(exc_value, "args", None) + print( + ("There was an exception ({0}) importing your module.\n").format( + exc_type.__name__, + ), + file=stream, + ) + if args: + print("It had these arguments: ", file=stream) + for idx, arg in enumerate(args, start=1): + print("{0}. {1}\n".format(idx, arg), file=stream) + else: + print("It had no arguments.", file=stream) + + +def run(argv=sys.argv, _serve=serve): + """Command line runner.""" + name = os.path.basename(argv[0]) + + try: + kw, args = Adjustments.parse_args(argv[1:]) + except getopt.GetoptError as exc: + show_help(sys.stderr, name, str(exc)) + return 1 + + if kw["help"]: + show_help(sys.stdout, name) + return 0 + + if len(args) != 1: + show_help(sys.stderr, name, "Specify one application only") + return 1 + + try: + module, obj_name = match(args[0]) + except ValueError as exc: + show_help(sys.stderr, name, str(exc)) + show_exception(sys.stderr) + return 1 + + # Add the current directory onto sys.path + sys.path.append(os.getcwd()) + + # Get the WSGI function. + try: + app = resolve(module, obj_name) + except ImportError: + show_help(sys.stderr, name, "Bad module '{0}'".format(module)) + show_exception(sys.stderr) + return 1 + except AttributeError: + show_help(sys.stderr, name, "Bad object name '{0}'".format(obj_name)) + show_exception(sys.stderr) + return 1 + if kw["call"]: + app = app() + + # These arguments are specific to the runner, not waitress itself. + del kw["call"], kw["help"] + + _serve(app, **kw) + return 0 diff --git a/src/waitress/server.py b/src/waitress/server.py new file mode 100644 index 0000000..ae56699 --- /dev/null +++ b/src/waitress/server.py @@ -0,0 +1,436 @@ +############################################################################## +# +# Copyright (c) 2001, 2002 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## + +import os +import os.path +import socket +import time + +from waitress import trigger +from waitress.adjustments import Adjustments +from waitress.channel import HTTPChannel +from waitress.task import ThreadedTaskDispatcher +from waitress.utilities import cleanup_unix_socket + +from waitress.compat import ( + IPPROTO_IPV6, + IPV6_V6ONLY, +) +from . import wasyncore +from .proxy_headers import proxy_headers_middleware + + +def create_server( + application, + map=None, + _start=True, # test shim + _sock=None, # test shim + _dispatcher=None, # test shim + **kw # adjustments +): + """ + if __name__ == '__main__': + server = create_server(app) + server.run() + """ + if application is None: + raise ValueError( + 'The "app" passed to ``create_server`` was ``None``. You forgot ' + "to return a WSGI app within your application." + ) + adj = Adjustments(**kw) + + if map is None: # pragma: nocover + map = {} + + dispatcher = _dispatcher + if dispatcher is None: + dispatcher = ThreadedTaskDispatcher() + dispatcher.set_thread_count(adj.threads) + + if adj.unix_socket and hasattr(socket, "AF_UNIX"): + sockinfo = (socket.AF_UNIX, socket.SOCK_STREAM, None, None) + return UnixWSGIServer( + application, + map, + _start, + _sock, + dispatcher=dispatcher, + adj=adj, + sockinfo=sockinfo, + ) + + effective_listen = [] + last_serv = None + if not adj.sockets: + for sockinfo in adj.listen: + # When TcpWSGIServer is called, it registers itself in the map. This + # side-effect is all we need it for, so we don't store a reference to + # or return it to the user. + last_serv = TcpWSGIServer( + application, + map, + _start, + _sock, + dispatcher=dispatcher, + adj=adj, + sockinfo=sockinfo, + ) + effective_listen.append( + (last_serv.effective_host, last_serv.effective_port) + ) + + for sock in adj.sockets: + sockinfo = (sock.family, sock.type, sock.proto, sock.getsockname()) + if sock.family == socket.AF_INET or sock.family == socket.AF_INET6: + last_serv = TcpWSGIServer( + application, + map, + _start, + sock, + dispatcher=dispatcher, + adj=adj, + bind_socket=False, + sockinfo=sockinfo, + ) + effective_listen.append( + (last_serv.effective_host, last_serv.effective_port) + ) + elif hasattr(socket, "AF_UNIX") and sock.family == socket.AF_UNIX: + last_serv = UnixWSGIServer( + application, + map, + _start, + sock, + dispatcher=dispatcher, + adj=adj, + bind_socket=False, + sockinfo=sockinfo, + ) + effective_listen.append( + (last_serv.effective_host, last_serv.effective_port) + ) + + # We are running a single server, so we can just return the last server, + # saves us from having to create one more object + if len(effective_listen) == 1: + # In this case we have no need to use a MultiSocketServer + return last_serv + + # Return a class that has a utility function to print out the sockets it's + # listening on, and has a .run() function. All of the TcpWSGIServers + # registered themselves in the map above. + return MultiSocketServer(map, adj, effective_listen, dispatcher) + + +# This class is only ever used if we have multiple listen sockets. It allows +# the serve() API to call .run() which starts the wasyncore loop, and catches +# SystemExit/KeyboardInterrupt so that it can atempt to cleanly shut down. +class MultiSocketServer(object): + asyncore = wasyncore # test shim + + def __init__( + self, map=None, adj=None, effective_listen=None, dispatcher=None, + ): + self.adj = adj + self.map = map + self.effective_listen = effective_listen + self.task_dispatcher = dispatcher + + def print_listen(self, format_str): # pragma: nocover + for l in self.effective_listen: + l = list(l) + + if ":" in l[0]: + l[0] = "[{}]".format(l[0]) + + print(format_str.format(*l)) + + def run(self): + try: + self.asyncore.loop( + timeout=self.adj.asyncore_loop_timeout, + map=self.map, + use_poll=self.adj.asyncore_use_poll, + ) + except (SystemExit, KeyboardInterrupt): + self.close() + + def close(self): + self.task_dispatcher.shutdown() + wasyncore.close_all(self.map) + + +class BaseWSGIServer(wasyncore.dispatcher, object): + + channel_class = HTTPChannel + next_channel_cleanup = 0 + socketmod = socket # test shim + asyncore = wasyncore # test shim + + def __init__( + self, + application, + map=None, + _start=True, # test shim + _sock=None, # test shim + dispatcher=None, # dispatcher + adj=None, # adjustments + sockinfo=None, # opaque object + bind_socket=True, + **kw + ): + if adj is None: + adj = Adjustments(**kw) + + if adj.trusted_proxy or adj.clear_untrusted_proxy_headers: + # wrap the application to deal with proxy headers + # we wrap it here because webtest subclasses the TcpWSGIServer + # directly and thus doesn't run any code that's in create_server + application = proxy_headers_middleware( + application, + trusted_proxy=adj.trusted_proxy, + trusted_proxy_count=adj.trusted_proxy_count, + trusted_proxy_headers=adj.trusted_proxy_headers, + clear_untrusted=adj.clear_untrusted_proxy_headers, + log_untrusted=adj.log_untrusted_proxy_headers, + logger=self.logger, + ) + + if map is None: + # use a nonglobal socket map by default to hopefully prevent + # conflicts with apps and libs that use the wasyncore global socket + # map ala https://github.com/Pylons/waitress/issues/63 + map = {} + if sockinfo is None: + sockinfo = adj.listen[0] + + self.sockinfo = sockinfo + self.family = sockinfo[0] + self.socktype = sockinfo[1] + self.application = application + self.adj = adj + self.trigger = trigger.trigger(map) + if dispatcher is None: + dispatcher = ThreadedTaskDispatcher() + dispatcher.set_thread_count(self.adj.threads) + + self.task_dispatcher = dispatcher + self.asyncore.dispatcher.__init__(self, _sock, map=map) + if _sock is None: + self.create_socket(self.family, self.socktype) + if self.family == socket.AF_INET6: # pragma: nocover + self.socket.setsockopt(IPPROTO_IPV6, IPV6_V6ONLY, 1) + + self.set_reuse_addr() + + if bind_socket: + self.bind_server_socket() + + self.effective_host, self.effective_port = self.getsockname() + self.server_name = self.get_server_name(self.effective_host) + self.active_channels = {} + if _start: + self.accept_connections() + + def bind_server_socket(self): + raise NotImplementedError # pragma: no cover + + def get_server_name(self, ip): + """Given an IP or hostname, try to determine the server name.""" + + if not ip: + raise ValueError("Requires an IP to get the server name") + + server_name = str(ip) + + # If we are bound to all IP's, just return the current hostname, only + # fall-back to "localhost" if we fail to get the hostname + if server_name == "0.0.0.0" or server_name == "::": + try: + return str(self.socketmod.gethostname()) + except (socket.error, UnicodeDecodeError): # pragma: no cover + # We also deal with UnicodeDecodeError in case of Windows with + # non-ascii hostname + return "localhost" + + # Now let's try and convert the IP address to a proper hostname + try: + server_name = self.socketmod.gethostbyaddr(server_name)[0] + except (socket.error, UnicodeDecodeError): # pragma: no cover + # We also deal with UnicodeDecodeError in case of Windows with + # non-ascii hostname + pass + + # If it contains an IPv6 literal, make sure to surround it with + # brackets + if ":" in server_name and "[" not in server_name: + server_name = "[{}]".format(server_name) + + return server_name + + def getsockname(self): + raise NotImplementedError # pragma: no cover + + def accept_connections(self): + self.accepting = True + self.socket.listen(self.adj.backlog) # Get around asyncore NT limit + + def add_task(self, task): + self.task_dispatcher.add_task(task) + + def readable(self): + now = time.time() + if now >= self.next_channel_cleanup: + self.next_channel_cleanup = now + self.adj.cleanup_interval + self.maintenance(now) + return self.accepting and len(self._map) < self.adj.connection_limit + + def writable(self): + return False + + def handle_read(self): + pass + + def handle_connect(self): + pass + + def handle_accept(self): + try: + v = self.accept() + if v is None: + return + conn, addr = v + except socket.error: + # Linux: On rare occasions we get a bogus socket back from + # accept. socketmodule.c:makesockaddr complains that the + # address family is unknown. We don't want the whole server + # to shut down because of this. + if self.adj.log_socket_errors: + self.logger.warning("server accept() threw an exception", exc_info=True) + return + self.set_socket_options(conn) + addr = self.fix_addr(addr) + self.channel_class(self, conn, addr, self.adj, map=self._map) + + def run(self): + try: + self.asyncore.loop( + timeout=self.adj.asyncore_loop_timeout, + map=self._map, + use_poll=self.adj.asyncore_use_poll, + ) + except (SystemExit, KeyboardInterrupt): + self.task_dispatcher.shutdown() + + def pull_trigger(self): + self.trigger.pull_trigger() + + def set_socket_options(self, conn): + pass + + def fix_addr(self, addr): + return addr + + def maintenance(self, now): + """ + Closes channels that have not had any activity in a while. + + The timeout is configured through adj.channel_timeout (seconds). + """ + cutoff = now - self.adj.channel_timeout + for channel in self.active_channels.values(): + if (not channel.requests) and channel.last_activity < cutoff: + channel.will_close = True + + def print_listen(self, format_str): # pragma: nocover + print(format_str.format(self.effective_host, self.effective_port)) + + def close(self): + self.trigger.close() + return wasyncore.dispatcher.close(self) + + +class TcpWSGIServer(BaseWSGIServer): + def bind_server_socket(self): + (_, _, _, sockaddr) = self.sockinfo + self.bind(sockaddr) + + def getsockname(self): + try: + return self.socketmod.getnameinfo( + self.socket.getsockname(), self.socketmod.NI_NUMERICSERV + ) + except: # pragma: no cover + # This only happens on Linux because a DNS issue is considered a + # temporary failure that will raise (even when NI_NAMEREQD is not + # set). Instead we try again, but this time we just ask for the + # numerichost and the numericserv (port) and return those. It is + # better than nothing. + return self.socketmod.getnameinfo( + self.socket.getsockname(), + self.socketmod.NI_NUMERICHOST | self.socketmod.NI_NUMERICSERV, + ) + + def set_socket_options(self, conn): + for (level, optname, value) in self.adj.socket_options: + conn.setsockopt(level, optname, value) + + +if hasattr(socket, "AF_UNIX"): + + class UnixWSGIServer(BaseWSGIServer): + def __init__( + self, + application, + map=None, + _start=True, # test shim + _sock=None, # test shim + dispatcher=None, # dispatcher + adj=None, # adjustments + sockinfo=None, # opaque object + **kw + ): + if sockinfo is None: + sockinfo = (socket.AF_UNIX, socket.SOCK_STREAM, None, None) + + super(UnixWSGIServer, self).__init__( + application, + map=map, + _start=_start, + _sock=_sock, + dispatcher=dispatcher, + adj=adj, + sockinfo=sockinfo, + **kw + ) + + def bind_server_socket(self): + cleanup_unix_socket(self.adj.unix_socket) + self.bind(self.adj.unix_socket) + if os.path.exists(self.adj.unix_socket): + os.chmod(self.adj.unix_socket, self.adj.unix_socket_perms) + + def getsockname(self): + return ("unix", self.socket.getsockname()) + + def fix_addr(self, addr): + return ("localhost", None) + + def get_server_name(self, ip): + return "localhost" + + +# Compatibility alias. +WSGIServer = TcpWSGIServer diff --git a/src/waitress/task.py b/src/waitress/task.py new file mode 100644 index 0000000..8e7ab18 --- /dev/null +++ b/src/waitress/task.py @@ -0,0 +1,570 @@ +############################################################################## +# +# Copyright (c) 2001, 2002 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## + +import socket +import sys +import threading +import time +from collections import deque + +from .buffers import ReadOnlyFileBasedBuffer +from .compat import reraise, tobytes +from .utilities import build_http_date, logger, queue_logger + +rename_headers = { # or keep them without the HTTP_ prefix added + "CONTENT_LENGTH": "CONTENT_LENGTH", + "CONTENT_TYPE": "CONTENT_TYPE", +} + +hop_by_hop = frozenset( + ( + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailers", + "transfer-encoding", + "upgrade", + ) +) + + +class ThreadedTaskDispatcher(object): + """A Task Dispatcher that creates a thread for each task. + """ + + stop_count = 0 # Number of threads that will stop soon. + active_count = 0 # Number of currently active threads + logger = logger + queue_logger = queue_logger + + def __init__(self): + self.threads = set() + self.queue = deque() + self.lock = threading.Lock() + self.queue_cv = threading.Condition(self.lock) + self.thread_exit_cv = threading.Condition(self.lock) + + def start_new_thread(self, target, args): + t = threading.Thread(target=target, name="waitress", args=args) + t.daemon = True + t.start() + + def handler_thread(self, thread_no): + while True: + with self.lock: + while not self.queue and self.stop_count == 0: + # Mark ourselves as idle before waiting to be + # woken up, then we will once again be active + self.active_count -= 1 + self.queue_cv.wait() + self.active_count += 1 + + if self.stop_count > 0: + self.active_count -= 1 + self.stop_count -= 1 + self.threads.discard(thread_no) + self.thread_exit_cv.notify() + break + + task = self.queue.popleft() + try: + task.service() + except BaseException: + self.logger.exception("Exception when servicing %r", task) + + def set_thread_count(self, count): + with self.lock: + threads = self.threads + thread_no = 0 + running = len(threads) - self.stop_count + while running < count: + # Start threads. + while thread_no in threads: + thread_no = thread_no + 1 + threads.add(thread_no) + running += 1 + self.start_new_thread(self.handler_thread, (thread_no,)) + self.active_count += 1 + thread_no = thread_no + 1 + if running > count: + # Stop threads. + self.stop_count += running - count + self.queue_cv.notify_all() + + def add_task(self, task): + with self.lock: + self.queue.append(task) + self.queue_cv.notify() + queue_size = len(self.queue) + idle_threads = len(self.threads) - self.stop_count - self.active_count + if queue_size > idle_threads: + self.queue_logger.warning( + "Task queue depth is %d", queue_size - idle_threads + ) + + def shutdown(self, cancel_pending=True, timeout=5): + self.set_thread_count(0) + # Ensure the threads shut down. + threads = self.threads + expiration = time.time() + timeout + with self.lock: + while threads: + if time.time() >= expiration: + self.logger.warning("%d thread(s) still running", len(threads)) + break + self.thread_exit_cv.wait(0.1) + if cancel_pending: + # Cancel remaining tasks. + queue = self.queue + if len(queue) > 0: + self.logger.warning("Canceling %d pending task(s)", len(queue)) + while queue: + task = queue.popleft() + task.cancel() + self.queue_cv.notify_all() + return True + return False + + +class Task(object): + close_on_finish = False + status = "200 OK" + wrote_header = False + start_time = 0 + content_length = None + content_bytes_written = 0 + logged_write_excess = False + logged_write_no_body = False + complete = False + chunked_response = False + logger = logger + + def __init__(self, channel, request): + self.channel = channel + self.request = request + self.response_headers = [] + version = request.version + if version not in ("1.0", "1.1"): + # fall back to a version we support. + version = "1.0" + self.version = version + + def service(self): + try: + try: + self.start() + self.execute() + self.finish() + except socket.error: + self.close_on_finish = True + if self.channel.adj.log_socket_errors: + raise + finally: + pass + + @property + def has_body(self): + return not ( + self.status.startswith("1") + or self.status.startswith("204") + or self.status.startswith("304") + ) + + def build_response_header(self): + version = self.version + # Figure out whether the connection should be closed. + connection = self.request.headers.get("CONNECTION", "").lower() + response_headers = [] + content_length_header = None + date_header = None + server_header = None + connection_close_header = None + + for (headername, headerval) in self.response_headers: + headername = "-".join([x.capitalize() for x in headername.split("-")]) + + if headername == "Content-Length": + if self.has_body: + content_length_header = headerval + else: + continue # pragma: no cover + + if headername == "Date": + date_header = headerval + + if headername == "Server": + server_header = headerval + + if headername == "Connection": + connection_close_header = headerval.lower() + # replace with properly capitalized version + response_headers.append((headername, headerval)) + + if ( + content_length_header is None + and self.content_length is not None + and self.has_body + ): + content_length_header = str(self.content_length) + response_headers.append(("Content-Length", content_length_header)) + + def close_on_finish(): + if connection_close_header is None: + response_headers.append(("Connection", "close")) + self.close_on_finish = True + + if version == "1.0": + if connection == "keep-alive": + if not content_length_header: + close_on_finish() + else: + response_headers.append(("Connection", "Keep-Alive")) + else: + close_on_finish() + + elif version == "1.1": + if connection == "close": + close_on_finish() + + if not content_length_header: + # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length + # for any response with a status code of 1xx, 204 or 304. + + if self.has_body: + response_headers.append(("Transfer-Encoding", "chunked")) + self.chunked_response = True + + if not self.close_on_finish: + close_on_finish() + + # under HTTP 1.1 keep-alive is default, no need to set the header + else: + raise AssertionError("neither HTTP/1.0 or HTTP/1.1") + + # Set the Server and Date field, if not yet specified. This is needed + # if the server is used as a proxy. + ident = self.channel.server.adj.ident + + if not server_header: + if ident: + response_headers.append(("Server", ident)) + else: + response_headers.append(("Via", ident or "waitress")) + + if not date_header: + response_headers.append(("Date", build_http_date(self.start_time))) + + self.response_headers = response_headers + + first_line = "HTTP/%s %s" % (self.version, self.status) + # NB: sorting headers needs to preserve same-named-header order + # as per RFC 2616 section 4.2; thus the key=lambda x: x[0] here; + # rely on stable sort to keep relative position of same-named headers + next_lines = [ + "%s: %s" % hv for hv in sorted(self.response_headers, key=lambda x: x[0]) + ] + lines = [first_line] + next_lines + res = "%s\r\n\r\n" % "\r\n".join(lines) + + return tobytes(res) + + def remove_content_length_header(self): + response_headers = [] + + for header_name, header_value in self.response_headers: + if header_name.lower() == "content-length": + continue # pragma: nocover + response_headers.append((header_name, header_value)) + + self.response_headers = response_headers + + def start(self): + self.start_time = time.time() + + def finish(self): + if not self.wrote_header: + self.write(b"") + if self.chunked_response: + # not self.write, it will chunk it! + self.channel.write_soon(b"0\r\n\r\n") + + def write(self, data): + if not self.complete: + raise RuntimeError("start_response was not called before body written") + channel = self.channel + if not self.wrote_header: + rh = self.build_response_header() + channel.write_soon(rh) + self.wrote_header = True + + if data and self.has_body: + towrite = data + cl = self.content_length + if self.chunked_response: + # use chunked encoding response + towrite = tobytes(hex(len(data))[2:].upper()) + b"\r\n" + towrite += data + b"\r\n" + elif cl is not None: + towrite = data[: cl - self.content_bytes_written] + self.content_bytes_written += len(towrite) + if towrite != data and not self.logged_write_excess: + self.logger.warning( + "application-written content exceeded the number of " + "bytes specified by Content-Length header (%s)" % cl + ) + self.logged_write_excess = True + if towrite: + channel.write_soon(towrite) + elif data: + # Cheat, and tell the application we have written all of the bytes, + # even though the response shouldn't have a body and we are + # ignoring it entirely. + self.content_bytes_written += len(data) + + if not self.logged_write_no_body: + self.logger.warning( + "application-written content was ignored due to HTTP " + "response that may not contain a message-body: (%s)" % self.status + ) + self.logged_write_no_body = True + + +class ErrorTask(Task): + """ An error task produces an error response + """ + + complete = True + + def execute(self): + e = self.request.error + status, headers, body = e.to_response() + self.status = status + self.response_headers.extend(headers) + # We need to explicitly tell the remote client we are closing the + # connection, because self.close_on_finish is set, and we are going to + # slam the door in the clients face. + self.response_headers.append(("Connection", "close")) + self.close_on_finish = True + self.content_length = len(body) + self.write(tobytes(body)) + + +class WSGITask(Task): + """A WSGI task produces a response from a WSGI application. + """ + + environ = None + + def execute(self): + environ = self.get_environment() + + def start_response(status, headers, exc_info=None): + if self.complete and not exc_info: + raise AssertionError( + "start_response called a second time without providing exc_info." + ) + if exc_info: + try: + if self.wrote_header: + # higher levels will catch and handle raised exception: + # 1. "service" method in task.py + # 2. "service" method in channel.py + # 3. "handler_thread" method in task.py + reraise(exc_info[0], exc_info[1], exc_info[2]) + else: + # As per WSGI spec existing headers must be cleared + self.response_headers = [] + finally: + exc_info = None + + self.complete = True + + if not status.__class__ is str: + raise AssertionError("status %s is not a string" % status) + if "\n" in status or "\r" in status: + raise ValueError( + "carriage return/line feed character present in status" + ) + + self.status = status + + # Prepare the headers for output + for k, v in headers: + if not k.__class__ is str: + raise AssertionError( + "Header name %r is not a string in %r" % (k, (k, v)) + ) + if not v.__class__ is str: + raise AssertionError( + "Header value %r is not a string in %r" % (v, (k, v)) + ) + + if "\n" in v or "\r" in v: + raise ValueError( + "carriage return/line feed character present in header value" + ) + if "\n" in k or "\r" in k: + raise ValueError( + "carriage return/line feed character present in header name" + ) + + kl = k.lower() + if kl == "content-length": + self.content_length = int(v) + elif kl in hop_by_hop: + raise AssertionError( + '%s is a "hop-by-hop" header; it cannot be used by ' + "a WSGI application (see PEP 3333)" % k + ) + + self.response_headers.extend(headers) + + # Return a method used to write the response data. + return self.write + + # Call the application to handle the request and write a response + app_iter = self.channel.server.application(environ, start_response) + + can_close_app_iter = True + try: + if app_iter.__class__ is ReadOnlyFileBasedBuffer: + cl = self.content_length + size = app_iter.prepare(cl) + if size: + if cl != size: + if cl is not None: + self.remove_content_length_header() + self.content_length = size + self.write(b"") # generate headers + # if the write_soon below succeeds then the channel will + # take over closing the underlying file via the channel's + # _flush_some or handle_close so we intentionally avoid + # calling close in the finally block + self.channel.write_soon(app_iter) + can_close_app_iter = False + return + + first_chunk_len = None + for chunk in app_iter: + if first_chunk_len is None: + first_chunk_len = len(chunk) + # Set a Content-Length header if one is not supplied. + # start_response may not have been called until first + # iteration as per PEP, so we must reinterrogate + # self.content_length here + if self.content_length is None: + app_iter_len = None + if hasattr(app_iter, "__len__"): + app_iter_len = len(app_iter) + if app_iter_len == 1: + self.content_length = first_chunk_len + # transmit headers only after first iteration of the iterable + # that returns a non-empty bytestring (PEP 3333) + if chunk: + self.write(chunk) + + cl = self.content_length + if cl is not None: + if self.content_bytes_written != cl: + # close the connection so the client isn't sitting around + # waiting for more data when there are too few bytes + # to service content-length + self.close_on_finish = True + if self.request.command != "HEAD": + self.logger.warning( + "application returned too few bytes (%s) " + "for specified Content-Length (%s) via app_iter" + % (self.content_bytes_written, cl), + ) + finally: + if can_close_app_iter and hasattr(app_iter, "close"): + app_iter.close() + + def get_environment(self): + """Returns a WSGI environment.""" + environ = self.environ + if environ is not None: + # Return the cached copy. + return environ + + request = self.request + path = request.path + channel = self.channel + server = channel.server + url_prefix = server.adj.url_prefix + + if path.startswith("/"): + # strip extra slashes at the beginning of a path that starts + # with any number of slashes + path = "/" + path.lstrip("/") + + if url_prefix: + # NB: url_prefix is guaranteed by the configuration machinery to + # be either the empty string or a string that starts with a single + # slash and ends without any slashes + if path == url_prefix: + # if the path is the same as the url prefix, the SCRIPT_NAME + # should be the url_prefix and PATH_INFO should be empty + path = "" + else: + # if the path starts with the url prefix plus a slash, + # the SCRIPT_NAME should be the url_prefix and PATH_INFO should + # the value of path from the slash until its end + url_prefix_with_trailing_slash = url_prefix + "/" + if path.startswith(url_prefix_with_trailing_slash): + path = path[len(url_prefix) :] + + environ = { + "REMOTE_ADDR": channel.addr[0], + # Nah, we aren't actually going to look up the reverse DNS for + # REMOTE_ADDR, but we will happily set this environment variable + # for the WSGI application. Spec says we can just set this to + # REMOTE_ADDR, so we do. + "REMOTE_HOST": channel.addr[0], + # try and set the REMOTE_PORT to something useful, but maybe None + "REMOTE_PORT": str(channel.addr[1]), + "REQUEST_METHOD": request.command.upper(), + "SERVER_PORT": str(server.effective_port), + "SERVER_NAME": server.server_name, + "SERVER_SOFTWARE": server.adj.ident, + "SERVER_PROTOCOL": "HTTP/%s" % self.version, + "SCRIPT_NAME": url_prefix, + "PATH_INFO": path, + "QUERY_STRING": request.query, + "wsgi.url_scheme": request.url_scheme, + # the following environment variables are required by the WSGI spec + "wsgi.version": (1, 0), + # apps should use the logging module + "wsgi.errors": sys.stderr, + "wsgi.multithread": True, + "wsgi.multiprocess": False, + "wsgi.run_once": False, + "wsgi.input": request.get_body_stream(), + "wsgi.file_wrapper": ReadOnlyFileBasedBuffer, + "wsgi.input_terminated": True, # wsgi.input is EOF terminated + } + + for key, value in dict(request.headers).items(): + value = value.strip() + mykey = rename_headers.get(key, None) + if mykey is None: + mykey = "HTTP_" + key + if mykey not in environ: + environ[mykey] = value + + # cache the environ for this request + self.environ = environ + return environ diff --git a/src/waitress/trigger.py b/src/waitress/trigger.py new file mode 100644 index 0000000..6a57c12 --- /dev/null +++ b/src/waitress/trigger.py @@ -0,0 +1,203 @@ +############################################################################## +# +# Copyright (c) 2001-2005 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE +# +############################################################################## + +import os +import socket +import errno +import threading + +from . import wasyncore + +# Wake up a call to select() running in the main thread. +# +# This is useful in a context where you are using Medusa's I/O +# subsystem to deliver data, but the data is generated by another +# thread. Normally, if Medusa is in the middle of a call to +# select(), new output data generated by another thread will have +# to sit until the call to select() either times out or returns. +# If the trigger is 'pulled' by another thread, it should immediately +# generate a READ event on the trigger object, which will force the +# select() invocation to return. +# +# A common use for this facility: letting Medusa manage I/O for a +# large number of connections; but routing each request through a +# thread chosen from a fixed-size thread pool. When a thread is +# acquired, a transaction is performed, but output data is +# accumulated into buffers that will be emptied more efficiently +# by Medusa. [picture a server that can process database queries +# rapidly, but doesn't want to tie up threads waiting to send data +# to low-bandwidth connections] +# +# The other major feature provided by this class is the ability to +# move work back into the main thread: if you call pull_trigger() +# with a thunk argument, when select() wakes up and receives the +# event it will call your thunk from within that thread. The main +# purpose of this is to remove the need to wrap thread locks around +# Medusa's data structures, which normally do not need them. [To see +# why this is true, imagine this scenario: A thread tries to push some +# new data onto a channel's outgoing data queue at the same time that +# the main thread is trying to remove some] + + +class _triggerbase(object): + """OS-independent base class for OS-dependent trigger class.""" + + kind = None # subclass must set to "pipe" or "loopback"; used by repr + + def __init__(self): + self._closed = False + + # `lock` protects the `thunks` list from being traversed and + # appended to simultaneously. + self.lock = threading.Lock() + + # List of no-argument callbacks to invoke when the trigger is + # pulled. These run in the thread running the wasyncore mainloop, + # regardless of which thread pulls the trigger. + self.thunks = [] + + def readable(self): + return True + + def writable(self): + return False + + def handle_connect(self): + pass + + def handle_close(self): + self.close() + + # Override the wasyncore close() method, because it doesn't know about + # (so can't close) all the gimmicks we have open. Subclass must + # supply a _close() method to do platform-specific closing work. _close() + # will be called iff we're not already closed. + def close(self): + if not self._closed: + self._closed = True + self.del_channel() + self._close() # subclass does OS-specific stuff + + def pull_trigger(self, thunk=None): + if thunk: + with self.lock: + self.thunks.append(thunk) + self._physical_pull() + + def handle_read(self): + try: + self.recv(8192) + except (OSError, socket.error): + return + with self.lock: + for thunk in self.thunks: + try: + thunk() + except: + nil, t, v, tbinfo = wasyncore.compact_traceback() + self.log_info( + "exception in trigger thunk: (%s:%s %s)" % (t, v, tbinfo) + ) + self.thunks = [] + + +if os.name == "posix": + + class trigger(_triggerbase, wasyncore.file_dispatcher): + kind = "pipe" + + def __init__(self, map): + _triggerbase.__init__(self) + r, self.trigger = self._fds = os.pipe() + wasyncore.file_dispatcher.__init__(self, r, map=map) + + def _close(self): + for fd in self._fds: + os.close(fd) + self._fds = [] + wasyncore.file_dispatcher.close(self) + + def _physical_pull(self): + os.write(self.trigger, b"x") + + +else: # pragma: no cover + # Windows version; uses just sockets, because a pipe isn't select'able + # on Windows. + + class trigger(_triggerbase, wasyncore.dispatcher): + kind = "loopback" + + def __init__(self, map): + _triggerbase.__init__(self) + + # Get a pair of connected sockets. The trigger is the 'w' + # end of the pair, which is connected to 'r'. 'r' is put + # in the wasyncore socket map. "pulling the trigger" then + # means writing something on w, which will wake up r. + + w = socket.socket() + # Disable buffering -- pulling the trigger sends 1 byte, + # and we want that sent immediately, to wake up wasyncore's + # select() ASAP. + w.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + count = 0 + while True: + count += 1 + # Bind to a local port; for efficiency, let the OS pick + # a free port for us. + # Unfortunately, stress tests showed that we may not + # be able to connect to that port ("Address already in + # use") despite that the OS picked it. This appears + # to be a race bug in the Windows socket implementation. + # So we loop until a connect() succeeds (almost always + # on the first try). See the long thread at + # http://mail.zope.org/pipermail/zope/2005-July/160433.html + # for hideous details. + a = socket.socket() + a.bind(("127.0.0.1", 0)) + connect_address = a.getsockname() # assigned (host, port) pair + a.listen(1) + try: + w.connect(connect_address) + break # success + except socket.error as detail: + if detail[0] != errno.WSAEADDRINUSE: + # "Address already in use" is the only error + # I've seen on two WinXP Pro SP2 boxes, under + # Pythons 2.3.5 and 2.4.1. + raise + # (10048, 'Address already in use') + # assert count <= 2 # never triggered in Tim's tests + if count >= 10: # I've never seen it go above 2 + a.close() + w.close() + raise RuntimeError("Cannot bind trigger!") + # Close `a` and try again. Note: I originally put a short + # sleep() here, but it didn't appear to help or hurt. + a.close() + + r, addr = a.accept() # r becomes wasyncore's (self.)socket + a.close() + self.trigger = w + wasyncore.dispatcher.__init__(self, r, map=map) + + def _close(self): + # self.socket is r, and self.trigger is w, from __init__ + self.socket.close() + self.trigger.close() + + def _physical_pull(self): + self.trigger.send(b"x") diff --git a/src/waitress/utilities.py b/src/waitress/utilities.py new file mode 100644 index 0000000..556bed2 --- /dev/null +++ b/src/waitress/utilities.py @@ -0,0 +1,320 @@ +############################################################################## +# +# Copyright (c) 2004 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## +"""Utility functions +""" + +import calendar +import errno +import logging +import os +import re +import stat +import time + +from .rfc7230 import OBS_TEXT, VCHAR + +logger = logging.getLogger("waitress") +queue_logger = logging.getLogger("waitress.queue") + + +def find_double_newline(s): + """Returns the position just after a double newline in the given string.""" + pos = s.find(b"\r\n\r\n") + + if pos >= 0: + pos += 4 + + return pos + + +def concat(*args): + return "".join(args) + + +def join(seq, field=" "): + return field.join(seq) + + +def group(s): + return "(" + s + ")" + + +short_days = ["sun", "mon", "tue", "wed", "thu", "fri", "sat"] +long_days = [ + "sunday", + "monday", + "tuesday", + "wednesday", + "thursday", + "friday", + "saturday", +] + +short_day_reg = group(join(short_days, "|")) +long_day_reg = group(join(long_days, "|")) + +daymap = {} + +for i in range(7): + daymap[short_days[i]] = i + daymap[long_days[i]] = i + +hms_reg = join(3 * [group("[0-9][0-9]")], ":") + +months = [ + "jan", + "feb", + "mar", + "apr", + "may", + "jun", + "jul", + "aug", + "sep", + "oct", + "nov", + "dec", +] + +monmap = {} + +for i in range(12): + monmap[months[i]] = i + 1 + +months_reg = group(join(months, "|")) + +# From draft-ietf-http-v11-spec-07.txt/3.3.1 +# Sun, 06 Nov 1994 08:49:37 GMT ; RFC 822, updated by RFC 1123 +# Sunday, 06-Nov-94 08:49:37 GMT ; RFC 850, obsoleted by RFC 1036 +# Sun Nov 6 08:49:37 1994 ; ANSI C's asctime() format + +# rfc822 format +rfc822_date = join( + [ + concat(short_day_reg, ","), # day + group("[0-9][0-9]?"), # date + months_reg, # month + group("[0-9]+"), # year + hms_reg, # hour minute second + "gmt", + ], + " ", +) + +rfc822_reg = re.compile(rfc822_date) + + +def unpack_rfc822(m): + g = m.group + + return ( + int(g(4)), # year + monmap[g(3)], # month + int(g(2)), # day + int(g(5)), # hour + int(g(6)), # minute + int(g(7)), # second + 0, + 0, + 0, + ) + + +# rfc850 format +rfc850_date = join( + [ + concat(long_day_reg, ","), + join([group("[0-9][0-9]?"), months_reg, group("[0-9]+")], "-"), + hms_reg, + "gmt", + ], + " ", +) + +rfc850_reg = re.compile(rfc850_date) +# they actually unpack the same way +def unpack_rfc850(m): + g = m.group + yr = g(4) + + if len(yr) == 2: + yr = "19" + yr + + return ( + int(yr), # year + monmap[g(3)], # month + int(g(2)), # day + int(g(5)), # hour + int(g(6)), # minute + int(g(7)), # second + 0, + 0, + 0, + ) + + +# parsdate.parsedate - ~700/sec. +# parse_http_date - ~1333/sec. + +weekdayname = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"] +monthname = [ + None, + "Jan", + "Feb", + "Mar", + "Apr", + "May", + "Jun", + "Jul", + "Aug", + "Sep", + "Oct", + "Nov", + "Dec", +] + + +def build_http_date(when): + year, month, day, hh, mm, ss, wd, y, z = time.gmtime(when) + + return "%s, %02d %3s %4d %02d:%02d:%02d GMT" % ( + weekdayname[wd], + day, + monthname[month], + year, + hh, + mm, + ss, + ) + + +def parse_http_date(d): + d = d.lower() + m = rfc850_reg.match(d) + + if m and m.end() == len(d): + retval = int(calendar.timegm(unpack_rfc850(m))) + else: + m = rfc822_reg.match(d) + + if m and m.end() == len(d): + retval = int(calendar.timegm(unpack_rfc822(m))) + else: + return 0 + + return retval + + +# RFC 5234 Appendix B.1 "Core Rules": +# VCHAR = %x21-7E +# ; visible (printing) characters +vchar_re = VCHAR + +# RFC 7230 Section 3.2.6 "Field Value Components": +# quoted-string = DQUOTE *( qdtext / quoted-pair ) DQUOTE +# qdtext = HTAB / SP /%x21 / %x23-5B / %x5D-7E / obs-text +# obs-text = %x80-FF +# quoted-pair = "\" ( HTAB / SP / VCHAR / obs-text ) +obs_text_re = OBS_TEXT + +# The '\\' between \x5b and \x5d is needed to escape \x5d (']') +qdtext_re = "[\t \x21\x23-\x5b\\\x5d-\x7e" + obs_text_re + "]" + +quoted_pair_re = r"\\" + "([\t " + vchar_re + obs_text_re + "])" +quoted_string_re = '"(?:(?:' + qdtext_re + ")|(?:" + quoted_pair_re + '))*"' + +quoted_string = re.compile(quoted_string_re) +quoted_pair = re.compile(quoted_pair_re) + + +def undquote(value): + if value.startswith('"') and value.endswith('"'): + # So it claims to be DQUOTE'ed, let's validate that + matches = quoted_string.match(value) + + if matches and matches.end() == len(value): + # Remove the DQUOTE's from the value + value = value[1:-1] + + # Remove all backslashes that are followed by a valid vchar or + # obs-text + value = quoted_pair.sub(r"\1", value) + + return value + elif not value.startswith('"') and not value.endswith('"'): + return value + + raise ValueError("Invalid quoting in value") + + +def cleanup_unix_socket(path): + try: + st = os.stat(path) + except OSError as exc: + if exc.errno != errno.ENOENT: + raise # pragma: no cover + else: + if stat.S_ISSOCK(st.st_mode): + try: + os.remove(path) + except OSError: # pragma: no cover + # avoid race condition error during tests + pass + + +class Error(object): + code = 500 + reason = "Internal Server Error" + + def __init__(self, body): + self.body = body + + def to_response(self): + status = "%s %s" % (self.code, self.reason) + body = "%s\r\n\r\n%s" % (self.reason, self.body) + tag = "\r\n\r\n(generated by waitress)" + body = body + tag + headers = [("Content-Type", "text/plain")] + + return status, headers, body + + def wsgi_response(self, environ, start_response): + status, headers, body = self.to_response() + start_response(status, headers) + yield body + + +class BadRequest(Error): + code = 400 + reason = "Bad Request" + + +class RequestHeaderFieldsTooLarge(BadRequest): + code = 431 + reason = "Request Header Fields Too Large" + + +class RequestEntityTooLarge(BadRequest): + code = 413 + reason = "Request Entity Too Large" + + +class InternalServerError(Error): + code = 500 + reason = "Internal Server Error" + + +class ServerNotImplemented(Error): + code = 501 + reason = "Not Implemented" diff --git a/src/waitress/wasyncore.py b/src/waitress/wasyncore.py new file mode 100644 index 0000000..09bcafa --- /dev/null +++ b/src/waitress/wasyncore.py @@ -0,0 +1,693 @@ +# -*- Mode: Python -*- +# Id: asyncore.py,v 2.51 2000/09/07 22:29:26 rushing Exp +# Author: Sam Rushing + +# ====================================================================== +# Copyright 1996 by Sam Rushing +# +# All Rights Reserved +# +# Permission to use, copy, modify, and distribute this software and +# its documentation for any purpose and without fee is hereby +# granted, provided that the above copyright notice appear in all +# copies and that both that copyright notice and this permission +# notice appear in supporting documentation, and that the name of Sam +# Rushing not be used in advertising or publicity pertaining to +# distribution of the software without specific, written prior +# permission. +# +# SAM RUSHING DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, +# INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN +# NO EVENT SHALL SAM RUSHING BE LIABLE FOR ANY SPECIAL, INDIRECT OR +# CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS +# OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, +# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN +# CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +# ====================================================================== + +"""Basic infrastructure for asynchronous socket service clients and servers. + +There are only two ways to have a program on a single processor do "more +than one thing at a time". Multi-threaded programming is the simplest and +most popular way to do it, but there is another very different technique, +that lets you have nearly all the advantages of multi-threading, without +actually using multiple threads. it's really only practical if your program +is largely I/O bound. If your program is CPU bound, then pre-emptive +scheduled threads are probably what you really need. Network servers are +rarely CPU-bound, however. + +If your operating system supports the select() system call in its I/O +library (and nearly all do), then you can use it to juggle multiple +communication channels at once; doing other work while your I/O is taking +place in the "background." Although this strategy can seem strange and +complex, especially at first, it is in many ways easier to understand and +control than multi-threaded programming. The module documented here solves +many of the difficult problems for you, making the task of building +sophisticated high-performance network servers and clients a snap. + +NB: this is a fork of asyncore from the stdlib that we've (the waitress +developers) named 'wasyncore' to ensure forward compatibility, as asyncore +in the stdlib will be dropped soon. It is neither a copy of the 2.7 asyncore +nor the 3.X asyncore; it is a version compatible with either 2.7 or 3.X. +""" + +from . import compat +from . import utilities + +import logging +import select +import socket +import sys +import time +import warnings + +import os +from errno import ( + EALREADY, + EINPROGRESS, + EWOULDBLOCK, + ECONNRESET, + EINVAL, + ENOTCONN, + ESHUTDOWN, + EISCONN, + EBADF, + ECONNABORTED, + EPIPE, + EAGAIN, + EINTR, + errorcode, +) + +_DISCONNECTED = frozenset({ECONNRESET, ENOTCONN, ESHUTDOWN, ECONNABORTED, EPIPE, EBADF}) + +try: + socket_map +except NameError: + socket_map = {} + + +def _strerror(err): + try: + return os.strerror(err) + except (TypeError, ValueError, OverflowError, NameError): + return "Unknown error %s" % err + + +class ExitNow(Exception): + pass + + +_reraised_exceptions = (ExitNow, KeyboardInterrupt, SystemExit) + + +def read(obj): + try: + obj.handle_read_event() + except _reraised_exceptions: + raise + except: + obj.handle_error() + + +def write(obj): + try: + obj.handle_write_event() + except _reraised_exceptions: + raise + except: + obj.handle_error() + + +def _exception(obj): + try: + obj.handle_expt_event() + except _reraised_exceptions: + raise + except: + obj.handle_error() + + +def readwrite(obj, flags): + try: + if flags & select.POLLIN: + obj.handle_read_event() + if flags & select.POLLOUT: + obj.handle_write_event() + if flags & select.POLLPRI: + obj.handle_expt_event() + if flags & (select.POLLHUP | select.POLLERR | select.POLLNVAL): + obj.handle_close() + except socket.error as e: + if e.args[0] not in _DISCONNECTED: + obj.handle_error() + else: + obj.handle_close() + except _reraised_exceptions: + raise + except: + obj.handle_error() + + +def poll(timeout=0.0, map=None): + if map is None: # pragma: no cover + map = socket_map + if map: + r = [] + w = [] + e = [] + for fd, obj in list(map.items()): # list() call FBO py3 + is_r = obj.readable() + is_w = obj.writable() + if is_r: + r.append(fd) + # accepting sockets should not be writable + if is_w and not obj.accepting: + w.append(fd) + if is_r or is_w: + e.append(fd) + if [] == r == w == e: + time.sleep(timeout) + return + + try: + r, w, e = select.select(r, w, e, timeout) + except select.error as err: + if err.args[0] != EINTR: + raise + else: + return + + for fd in r: + obj = map.get(fd) + if obj is None: # pragma: no cover + continue + read(obj) + + for fd in w: + obj = map.get(fd) + if obj is None: # pragma: no cover + continue + write(obj) + + for fd in e: + obj = map.get(fd) + if obj is None: # pragma: no cover + continue + _exception(obj) + + +def poll2(timeout=0.0, map=None): + # Use the poll() support added to the select module in Python 2.0 + if map is None: # pragma: no cover + map = socket_map + if timeout is not None: + # timeout is in milliseconds + timeout = int(timeout * 1000) + pollster = select.poll() + if map: + for fd, obj in list(map.items()): + flags = 0 + if obj.readable(): + flags |= select.POLLIN | select.POLLPRI + # accepting sockets should not be writable + if obj.writable() and not obj.accepting: + flags |= select.POLLOUT + if flags: + pollster.register(fd, flags) + + try: + r = pollster.poll(timeout) + except select.error as err: + if err.args[0] != EINTR: + raise + r = [] + + for fd, flags in r: + obj = map.get(fd) + if obj is None: # pragma: no cover + continue + readwrite(obj, flags) + + +poll3 = poll2 # Alias for backward compatibility + + +def loop(timeout=30.0, use_poll=False, map=None, count=None): + if map is None: # pragma: no cover + map = socket_map + + if use_poll and hasattr(select, "poll"): + poll_fun = poll2 + else: + poll_fun = poll + + if count is None: # pragma: no cover + while map: + poll_fun(timeout, map) + + else: + while map and count > 0: + poll_fun(timeout, map) + count = count - 1 + + +def compact_traceback(): + t, v, tb = sys.exc_info() + tbinfo = [] + if not tb: # pragma: no cover + raise AssertionError("traceback does not exist") + while tb: + tbinfo.append( + ( + tb.tb_frame.f_code.co_filename, + tb.tb_frame.f_code.co_name, + str(tb.tb_lineno), + ) + ) + tb = tb.tb_next + + # just to be safe + del tb + + file, function, line = tbinfo[-1] + info = " ".join(["[%s|%s|%s]" % x for x in tbinfo]) + return (file, function, line), t, v, info + + +class dispatcher: + + debug = False + connected = False + accepting = False + connecting = False + closing = False + addr = None + ignore_log_types = frozenset({"warning"}) + logger = utilities.logger + compact_traceback = staticmethod(compact_traceback) # for testing + + def __init__(self, sock=None, map=None): + if map is None: # pragma: no cover + self._map = socket_map + else: + self._map = map + + self._fileno = None + + if sock: + # Set to nonblocking just to make sure for cases where we + # get a socket from a blocking source. + sock.setblocking(0) + self.set_socket(sock, map) + self.connected = True + # The constructor no longer requires that the socket + # passed be connected. + try: + self.addr = sock.getpeername() + except socket.error as err: + if err.args[0] in (ENOTCONN, EINVAL): + # To handle the case where we got an unconnected + # socket. + self.connected = False + else: + # The socket is broken in some unknown way, alert + # the user and remove it from the map (to prevent + # polling of broken sockets). + self.del_channel(map) + raise + else: + self.socket = None + + def __repr__(self): + status = [self.__class__.__module__ + "." + compat.qualname(self.__class__)] + if self.accepting and self.addr: + status.append("listening") + elif self.connected: + status.append("connected") + if self.addr is not None: + try: + status.append("%s:%d" % self.addr) + except TypeError: # pragma: no cover + status.append(repr(self.addr)) + return "<%s at %#x>" % (" ".join(status), id(self)) + + __str__ = __repr__ + + def add_channel(self, map=None): + # self.log_info('adding channel %s' % self) + if map is None: + map = self._map + map[self._fileno] = self + + def del_channel(self, map=None): + fd = self._fileno + if map is None: + map = self._map + if fd in map: + # self.log_info('closing channel %d:%s' % (fd, self)) + del map[fd] + self._fileno = None + + def create_socket(self, family=socket.AF_INET, type=socket.SOCK_STREAM): + self.family_and_type = family, type + sock = socket.socket(family, type) + sock.setblocking(0) + self.set_socket(sock) + + def set_socket(self, sock, map=None): + self.socket = sock + self._fileno = sock.fileno() + self.add_channel(map) + + def set_reuse_addr(self): + # try to re-use a server port if possible + try: + self.socket.setsockopt( + socket.SOL_SOCKET, + socket.SO_REUSEADDR, + self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) | 1, + ) + except socket.error: + pass + + # ================================================== + # predicates for select() + # these are used as filters for the lists of sockets + # to pass to select(). + # ================================================== + + def readable(self): + return True + + def writable(self): + return True + + # ================================================== + # socket object methods. + # ================================================== + + def listen(self, num): + self.accepting = True + if os.name == "nt" and num > 5: # pragma: no cover + num = 5 + return self.socket.listen(num) + + def bind(self, addr): + self.addr = addr + return self.socket.bind(addr) + + def connect(self, address): + self.connected = False + self.connecting = True + err = self.socket.connect_ex(address) + if ( + err in (EINPROGRESS, EALREADY, EWOULDBLOCK) + or err == EINVAL + and os.name == "nt" + ): # pragma: no cover + self.addr = address + return + if err in (0, EISCONN): + self.addr = address + self.handle_connect_event() + else: + raise socket.error(err, errorcode[err]) + + def accept(self): + # XXX can return either an address pair or None + try: + conn, addr = self.socket.accept() + except TypeError: + return None + except socket.error as why: + if why.args[0] in (EWOULDBLOCK, ECONNABORTED, EAGAIN): + return None + else: + raise + else: + return conn, addr + + def send(self, data): + try: + result = self.socket.send(data) + return result + except socket.error as why: + if why.args[0] == EWOULDBLOCK: + return 0 + elif why.args[0] in _DISCONNECTED: + self.handle_close() + return 0 + else: + raise + + def recv(self, buffer_size): + try: + data = self.socket.recv(buffer_size) + if not data: + # a closed connection is indicated by signaling + # a read condition, and having recv() return 0. + self.handle_close() + return b"" + else: + return data + except socket.error as why: + # winsock sometimes raises ENOTCONN + if why.args[0] in _DISCONNECTED: + self.handle_close() + return b"" + else: + raise + + def close(self): + self.connected = False + self.accepting = False + self.connecting = False + self.del_channel() + if self.socket is not None: + try: + self.socket.close() + except socket.error as why: + if why.args[0] not in (ENOTCONN, EBADF): + raise + + # log and log_info may be overridden to provide more sophisticated + # logging and warning methods. In general, log is for 'hit' logging + # and 'log_info' is for informational, warning and error logging. + + def log(self, message): + self.logger.log(logging.DEBUG, message) + + def log_info(self, message, type="info"): + severity = { + "info": logging.INFO, + "warning": logging.WARN, + "error": logging.ERROR, + } + self.logger.log(severity.get(type, logging.INFO), message) + + def handle_read_event(self): + if self.accepting: + # accepting sockets are never connected, they "spawn" new + # sockets that are connected + self.handle_accept() + elif not self.connected: + if self.connecting: + self.handle_connect_event() + self.handle_read() + else: + self.handle_read() + + def handle_connect_event(self): + err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + raise socket.error(err, _strerror(err)) + self.handle_connect() + self.connected = True + self.connecting = False + + def handle_write_event(self): + if self.accepting: + # Accepting sockets shouldn't get a write event. + # We will pretend it didn't happen. + return + + if not self.connected: + if self.connecting: + self.handle_connect_event() + self.handle_write() + + def handle_expt_event(self): + # handle_expt_event() is called if there might be an error on the + # socket, or if there is OOB data + # check for the error condition first + err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + # we can get here when select.select() says that there is an + # exceptional condition on the socket + # since there is an error, we'll go ahead and close the socket + # like we would in a subclassed handle_read() that received no + # data + self.handle_close() + else: + self.handle_expt() + + def handle_error(self): + nil, t, v, tbinfo = self.compact_traceback() + + # sometimes a user repr method will crash. + try: + self_repr = repr(self) + except: # pragma: no cover + self_repr = "<__repr__(self) failed for object at %0x>" % id(self) + + self.log_info( + "uncaptured python exception, closing channel %s (%s:%s %s)" + % (self_repr, t, v, tbinfo), + "error", + ) + self.handle_close() + + def handle_expt(self): + self.log_info("unhandled incoming priority event", "warning") + + def handle_read(self): + self.log_info("unhandled read event", "warning") + + def handle_write(self): + self.log_info("unhandled write event", "warning") + + def handle_connect(self): + self.log_info("unhandled connect event", "warning") + + def handle_accept(self): + pair = self.accept() + if pair is not None: + self.handle_accepted(*pair) + + def handle_accepted(self, sock, addr): + sock.close() + self.log_info("unhandled accepted event", "warning") + + def handle_close(self): + self.log_info("unhandled close event", "warning") + self.close() + + +# --------------------------------------------------------------------------- +# adds simple buffered output capability, useful for simple clients. +# [for more sophisticated usage use asynchat.async_chat] +# --------------------------------------------------------------------------- + + +class dispatcher_with_send(dispatcher): + def __init__(self, sock=None, map=None): + dispatcher.__init__(self, sock, map) + self.out_buffer = b"" + + def initiate_send(self): + num_sent = 0 + num_sent = dispatcher.send(self, self.out_buffer[:65536]) + self.out_buffer = self.out_buffer[num_sent:] + + handle_write = initiate_send + + def writable(self): + return (not self.connected) or len(self.out_buffer) + + def send(self, data): + if self.debug: # pragma: no cover + self.log_info("sending %s" % repr(data)) + self.out_buffer = self.out_buffer + data + self.initiate_send() + + +def close_all(map=None, ignore_all=False): + if map is None: # pragma: no cover + map = socket_map + for x in list(map.values()): # list() FBO py3 + try: + x.close() + except socket.error as x: + if x.args[0] == EBADF: + pass + elif not ignore_all: + raise + except _reraised_exceptions: + raise + except: + if not ignore_all: + raise + map.clear() + + +# Asynchronous File I/O: +# +# After a little research (reading man pages on various unixen, and +# digging through the linux kernel), I've determined that select() +# isn't meant for doing asynchronous file i/o. +# Heartening, though - reading linux/mm/filemap.c shows that linux +# supports asynchronous read-ahead. So _MOST_ of the time, the data +# will be sitting in memory for us already when we go to read it. +# +# What other OS's (besides NT) support async file i/o? [VMS?] +# +# Regardless, this is useful for pipes, and stdin/stdout... + +if os.name == "posix": + + class file_wrapper: + # Here we override just enough to make a file + # look like a socket for the purposes of asyncore. + # The passed fd is automatically os.dup()'d + + def __init__(self, fd): + self.fd = os.dup(fd) + + def __del__(self): + if self.fd >= 0: + warnings.warn("unclosed file %r" % self, compat.ResourceWarning) + self.close() + + def recv(self, *args): + return os.read(self.fd, *args) + + def send(self, *args): + return os.write(self.fd, *args) + + def getsockopt(self, level, optname, buflen=None): # pragma: no cover + if level == socket.SOL_SOCKET and optname == socket.SO_ERROR and not buflen: + return 0 + raise NotImplementedError( + "Only asyncore specific behaviour " "implemented." + ) + + read = recv + write = send + + def close(self): + if self.fd < 0: + return + fd = self.fd + self.fd = -1 + os.close(fd) + + def fileno(self): + return self.fd + + class file_dispatcher(dispatcher): + def __init__(self, fd, map=None): + dispatcher.__init__(self, None, map) + self.connected = True + try: + fd = fd.fileno() + except AttributeError: + pass + self.set_file(fd) + # set it to non-blocking mode + compat.set_nonblocking(fd) + + def set_file(self, fd): + self.socket = file_wrapper(fd) + self._fileno = self.socket.fileno() + self.add_channel() diff --git a/waitress/__init__.py b/waitress/__init__.py deleted file mode 100644 index e6e5911..0000000 --- a/waitress/__init__.py +++ /dev/null @@ -1,45 +0,0 @@ -from waitress.server import create_server -import logging - - -def serve(app, **kw): - _server = kw.pop("_server", create_server) # test shim - _quiet = kw.pop("_quiet", False) # test shim - _profile = kw.pop("_profile", False) # test shim - if not _quiet: # pragma: no cover - # idempotent if logging has already been set up - logging.basicConfig() - server = _server(app, **kw) - if not _quiet: # pragma: no cover - server.print_listen("Serving on http://{}:{}") - if _profile: # pragma: no cover - profile("server.run()", globals(), locals(), (), False) - else: - server.run() - - -def serve_paste(app, global_conf, **kw): - serve(app, **kw) - return 0 - - -def profile(cmd, globals, locals, sort_order, callers): # pragma: no cover - # runs a command under the profiler and print profiling output at shutdown - import os - import profile - import pstats - import tempfile - - fd, fn = tempfile.mkstemp() - try: - profile.runctx(cmd, globals, locals, fn) - stats = pstats.Stats(fn) - stats.strip_dirs() - # calls,time,cumulative and cumulative,calls,time are useful - stats.sort_stats(*(sort_order or ("cumulative", "calls", "time"))) - if callers: - stats.print_callers(0.3) - else: - stats.print_stats(0.3) - finally: - os.remove(fn) diff --git a/waitress/__main__.py b/waitress/__main__.py deleted file mode 100644 index 9bcd07e..0000000 --- a/waitress/__main__.py +++ /dev/null @@ -1,3 +0,0 @@ -from waitress.runner import run # pragma nocover - -run() # pragma nocover diff --git a/waitress/adjustments.py b/waitress/adjustments.py deleted file mode 100644 index 93439ea..0000000 --- a/waitress/adjustments.py +++ /dev/null @@ -1,515 +0,0 @@ -############################################################################## -# -# Copyright (c) 2002 Zope Foundation and Contributors. -# All Rights Reserved. -# -# This software is subject to the provisions of the Zope Public License, -# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. -# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED -# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS -# FOR A PARTICULAR PURPOSE. -# -############################################################################## -"""Adjustments are tunable parameters. -""" -import getopt -import socket -import warnings - -from .proxy_headers import PROXY_HEADERS -from .compat import ( - PY2, - WIN, - string_types, - HAS_IPV6, -) - -truthy = frozenset(("t", "true", "y", "yes", "on", "1")) - -KNOWN_PROXY_HEADERS = frozenset( - header.lower().replace("_", "-") for header in PROXY_HEADERS -) - - -def asbool(s): - """ Return the boolean value ``True`` if the case-lowered value of string - input ``s`` is any of ``t``, ``true``, ``y``, ``on``, or ``1``, otherwise - return the boolean value ``False``. If ``s`` is the value ``None``, - return ``False``. If ``s`` is already one of the boolean values ``True`` - or ``False``, return it.""" - if s is None: - return False - if isinstance(s, bool): - return s - s = str(s).strip() - return s.lower() in truthy - - -def asoctal(s): - """Convert the given octal string to an actual number.""" - return int(s, 8) - - -def aslist_cronly(value): - if isinstance(value, string_types): - value = filter(None, [x.strip() for x in value.splitlines()]) - return list(value) - - -def aslist(value): - """ Return a list of strings, separating the input based on newlines - and, if flatten=True (the default), also split on spaces within - each line.""" - values = aslist_cronly(value) - result = [] - for value in values: - subvalues = value.split() - result.extend(subvalues) - return result - - -def asset(value): - return set(aslist(value)) - - -def slash_fixed_str(s): - s = s.strip() - if s: - # always have a leading slash, replace any number of leading slashes - # with a single slash, and strip any trailing slashes - s = "/" + s.lstrip("/").rstrip("/") - return s - - -def str_iftruthy(s): - return str(s) if s else None - - -def as_socket_list(sockets): - """Checks if the elements in the list are of type socket and - removes them if not.""" - return [sock for sock in sockets if isinstance(sock, socket.socket)] - - -class _str_marker(str): - pass - - -class _int_marker(int): - pass - - -class _bool_marker(object): - pass - - -class Adjustments(object): - """This class contains tunable parameters. - """ - - _params = ( - ("host", str), - ("port", int), - ("ipv4", asbool), - ("ipv6", asbool), - ("listen", aslist), - ("threads", int), - ("trusted_proxy", str_iftruthy), - ("trusted_proxy_count", int), - ("trusted_proxy_headers", asset), - ("log_untrusted_proxy_headers", asbool), - ("clear_untrusted_proxy_headers", asbool), - ("url_scheme", str), - ("url_prefix", slash_fixed_str), - ("backlog", int), - ("recv_bytes", int), - ("send_bytes", int), - ("outbuf_overflow", int), - ("outbuf_high_watermark", int), - ("inbuf_overflow", int), - ("connection_limit", int), - ("cleanup_interval", int), - ("channel_timeout", int), - ("log_socket_errors", asbool), - ("max_request_header_size", int), - ("max_request_body_size", int), - ("expose_tracebacks", asbool), - ("ident", str_iftruthy), - ("asyncore_loop_timeout", int), - ("asyncore_use_poll", asbool), - ("unix_socket", str), - ("unix_socket_perms", asoctal), - ("sockets", as_socket_list), - ) - - _param_map = dict(_params) - - # hostname or IP address to listen on - host = _str_marker("0.0.0.0") - - # TCP port to listen on - port = _int_marker(8080) - - listen = ["{}:{}".format(host, port)] - - # number of threads available for tasks - threads = 4 - - # Host allowed to overrid ``wsgi.url_scheme`` via header - trusted_proxy = None - - # How many proxies we trust when chained - # - # X-Forwarded-For: 192.0.2.1, "[2001:db8::1]" - # - # or - # - # Forwarded: for=192.0.2.1, For="[2001:db8::1]" - # - # means there were (potentially), two proxies involved. If we know there is - # only 1 valid proxy, then that initial IP address "192.0.2.1" is not - # trusted and we completely ignore it. If there are two trusted proxies in - # the path, this value should be set to a higher number. - trusted_proxy_count = None - - # Which of the proxy headers should we trust, this is a set where you - # either specify forwarded or one or more of forwarded-host, forwarded-for, - # forwarded-proto, forwarded-port. - trusted_proxy_headers = set() - - # Would you like waitress to log warnings about untrusted proxy headers - # that were encountered while processing the proxy headers? This only makes - # sense to set when you have a trusted_proxy, and you expect the upstream - # proxy server to filter invalid headers - log_untrusted_proxy_headers = False - - # Should waitress clear any proxy headers that are not deemed trusted from - # the environ? Change to True by default in 2.x - clear_untrusted_proxy_headers = _bool_marker - - # default ``wsgi.url_scheme`` value - url_scheme = "http" - - # default ``SCRIPT_NAME`` value, also helps reset ``PATH_INFO`` - # when nonempty - url_prefix = "" - - # server identity (sent in Server: header) - ident = "waitress" - - # backlog is the value waitress passes to pass to socket.listen() This is - # the maximum number of incoming TCP connections that will wait in an OS - # queue for an available channel. From listen(1): "If a connection - # request arrives when the queue is full, the client may receive an error - # with an indication of ECONNREFUSED or, if the underlying protocol - # supports retransmission, the request may be ignored so that a later - # reattempt at connection succeeds." - backlog = 1024 - - # recv_bytes is the argument to pass to socket.recv(). - recv_bytes = 8192 - - # deprecated setting controls how many bytes will be buffered before - # being flushed to the socket - send_bytes = 1 - - # A tempfile should be created if the pending output is larger than - # outbuf_overflow, which is measured in bytes. The default is 1MB. This - # is conservative. - outbuf_overflow = 1048576 - - # The app_iter will pause when pending output is larger than this value - # in bytes. - outbuf_high_watermark = 16777216 - - # A tempfile should be created if the pending input is larger than - # inbuf_overflow, which is measured in bytes. The default is 512K. This - # is conservative. - inbuf_overflow = 524288 - - # Stop creating new channels if too many are already active (integer). - # Each channel consumes at least one file descriptor, and, depending on - # the input and output body sizes, potentially up to three. The default - # is conservative, but you may need to increase the number of file - # descriptors available to the Waitress process on most platforms in - # order to safely change it (see ``ulimit -a`` "open files" setting). - # Note that this doesn't control the maximum number of TCP connections - # that can be waiting for processing; the ``backlog`` argument controls - # that. - connection_limit = 100 - - # Minimum seconds between cleaning up inactive channels. - cleanup_interval = 30 - - # Maximum seconds to leave an inactive connection open. - channel_timeout = 120 - - # Boolean: turn off to not log premature client disconnects. - log_socket_errors = True - - # maximum number of bytes of all request headers combined (256K default) - max_request_header_size = 262144 - - # maximum number of bytes in request body (1GB default) - max_request_body_size = 1073741824 - - # expose tracebacks of uncaught exceptions - expose_tracebacks = False - - # Path to a Unix domain socket to use. - unix_socket = None - - # Path to a Unix domain socket to use. - unix_socket_perms = 0o600 - - # The socket options to set on receiving a connection. It is a list of - # (level, optname, value) tuples. TCP_NODELAY disables the Nagle - # algorithm for writes (Waitress already buffers its writes). - socket_options = [ - (socket.SOL_TCP, socket.TCP_NODELAY, 1), - ] - - # The asyncore.loop timeout value - asyncore_loop_timeout = 1 - - # The asyncore.loop flag to use poll() instead of the default select(). - asyncore_use_poll = False - - # Enable IPv4 by default - ipv4 = True - - # Enable IPv6 by default - ipv6 = True - - # A list of sockets that waitress will use to accept connections. They can - # be used for e.g. socket activation - sockets = [] - - def __init__(self, **kw): - - if "listen" in kw and ("host" in kw or "port" in kw): - raise ValueError("host or port may not be set if listen is set.") - - if "listen" in kw and "sockets" in kw: - raise ValueError("socket may not be set if listen is set.") - - if "sockets" in kw and ("host" in kw or "port" in kw): - raise ValueError("host or port may not be set if sockets is set.") - - if "sockets" in kw and "unix_socket" in kw: - raise ValueError("unix_socket may not be set if sockets is set") - - if "unix_socket" in kw and ("host" in kw or "port" in kw): - raise ValueError("unix_socket may not be set if host or port is set") - - if "unix_socket" in kw and "listen" in kw: - raise ValueError("unix_socket may not be set if listen is set") - - if "send_bytes" in kw: - warnings.warn( - "send_bytes will be removed in a future release", DeprecationWarning, - ) - - for k, v in kw.items(): - if k not in self._param_map: - raise ValueError("Unknown adjustment %r" % k) - setattr(self, k, self._param_map[k](v)) - - if not isinstance(self.host, _str_marker) or not isinstance( - self.port, _int_marker - ): - self.listen = ["{}:{}".format(self.host, self.port)] - - enabled_families = socket.AF_UNSPEC - - if not self.ipv4 and not HAS_IPV6: # pragma: no cover - raise ValueError( - "IPv4 is disabled but IPv6 is not available. Cowardly refusing to start." - ) - - if self.ipv4 and not self.ipv6: - enabled_families = socket.AF_INET - - if not self.ipv4 and self.ipv6 and HAS_IPV6: - enabled_families = socket.AF_INET6 - - wanted_sockets = [] - hp_pairs = [] - for i in self.listen: - if ":" in i: - (host, port) = i.rsplit(":", 1) - - # IPv6 we need to make sure that we didn't split on the address - if "]" in port: # pragma: nocover - (host, port) = (i, str(self.port)) - else: - (host, port) = (i, str(self.port)) - - if WIN and PY2: # pragma: no cover - try: - # Try turning the port into an integer - port = int(port) - - except Exception: - raise ValueError( - "Windows does not support service names instead of port numbers" - ) - - try: - if "[" in host and "]" in host: # pragma: nocover - host = host.strip("[").rstrip("]") - - if host == "*": - host = None - - for s in socket.getaddrinfo( - host, - port, - enabled_families, - socket.SOCK_STREAM, - socket.IPPROTO_TCP, - socket.AI_PASSIVE, - ): - (family, socktype, proto, _, sockaddr) = s - - # It seems that getaddrinfo() may sometimes happily return - # the same result multiple times, this of course makes - # bind() very unhappy... - # - # Split on %, and drop the zone-index from the host in the - # sockaddr. Works around a bug in OS X whereby - # getaddrinfo() returns the same link-local interface with - # two different zone-indices (which makes no sense what so - # ever...) yet treats them equally when we attempt to bind(). - if ( - sockaddr[1] == 0 - or (sockaddr[0].split("%", 1)[0], sockaddr[1]) not in hp_pairs - ): - wanted_sockets.append((family, socktype, proto, sockaddr)) - hp_pairs.append((sockaddr[0].split("%", 1)[0], sockaddr[1])) - - except Exception: - raise ValueError("Invalid host/port specified.") - - if self.trusted_proxy_count is not None and self.trusted_proxy is None: - raise ValueError( - "trusted_proxy_count has no meaning without setting " "trusted_proxy" - ) - - elif self.trusted_proxy_count is None: - self.trusted_proxy_count = 1 - - if self.trusted_proxy_headers and self.trusted_proxy is None: - raise ValueError( - "trusted_proxy_headers has no meaning without setting " "trusted_proxy" - ) - - if self.trusted_proxy_headers: - self.trusted_proxy_headers = { - header.lower() for header in self.trusted_proxy_headers - } - - unknown_values = self.trusted_proxy_headers - KNOWN_PROXY_HEADERS - if unknown_values: - raise ValueError( - "Received unknown trusted_proxy_headers value (%s) expected one " - "of %s" - % (", ".join(unknown_values), ", ".join(KNOWN_PROXY_HEADERS)) - ) - - if ( - "forwarded" in self.trusted_proxy_headers - and self.trusted_proxy_headers - {"forwarded"} - ): - raise ValueError( - "The Forwarded proxy header and the " - "X-Forwarded-{By,Host,Proto,Port,For} headers are mutually " - "exclusive. Can't trust both!" - ) - - elif self.trusted_proxy is not None: - warnings.warn( - "No proxy headers were marked as trusted, but trusted_proxy was set. " - "Implicitly trusting X-Forwarded-Proto for backwards compatibility. " - "This will be removed in future versions of waitress.", - DeprecationWarning, - ) - self.trusted_proxy_headers = {"x-forwarded-proto"} - - if self.clear_untrusted_proxy_headers is _bool_marker: - warnings.warn( - "In future versions of Waitress clear_untrusted_proxy_headers will be " - "set to True by default. You may opt-out by setting this value to " - "False, or opt-in explicitly by setting this to True.", - DeprecationWarning, - ) - self.clear_untrusted_proxy_headers = False - - self.listen = wanted_sockets - - self.check_sockets(self.sockets) - - @classmethod - def parse_args(cls, argv): - """Pre-parse command line arguments for input into __init__. Note that - this does not cast values into adjustment types, it just creates a - dictionary suitable for passing into __init__, where __init__ does the - casting. - """ - long_opts = ["help", "call"] - for opt, cast in cls._params: - opt = opt.replace("_", "-") - if cast is asbool: - long_opts.append(opt) - long_opts.append("no-" + opt) - else: - long_opts.append(opt + "=") - - kw = { - "help": False, - "call": False, - } - - opts, args = getopt.getopt(argv, "", long_opts) - for opt, value in opts: - param = opt.lstrip("-").replace("-", "_") - - if param == "listen": - kw["listen"] = "{} {}".format(kw.get("listen", ""), value) - continue - - if param.startswith("no_"): - param = param[3:] - kw[param] = "false" - elif param in ("help", "call"): - kw[param] = True - elif cls._param_map[param] is asbool: - kw[param] = "true" - else: - kw[param] = value - - return kw, args - - @classmethod - def check_sockets(cls, sockets): - has_unix_socket = False - has_inet_socket = False - has_unsupported_socket = False - for sock in sockets: - if ( - sock.family == socket.AF_INET or sock.family == socket.AF_INET6 - ) and sock.type == socket.SOCK_STREAM: - has_inet_socket = True - elif ( - hasattr(socket, "AF_UNIX") - and sock.family == socket.AF_UNIX - and sock.type == socket.SOCK_STREAM - ): - has_unix_socket = True - else: - has_unsupported_socket = True - if has_unix_socket and has_inet_socket: - raise ValueError("Internet and UNIX sockets may not be mixed.") - if has_unsupported_socket: - raise ValueError("Only Internet or UNIX stream sockets may be used.") diff --git a/waitress/buffers.py b/waitress/buffers.py deleted file mode 100644 index 04f6b42..0000000 --- a/waitress/buffers.py +++ /dev/null @@ -1,308 +0,0 @@ -############################################################################## -# -# Copyright (c) 2001-2004 Zope Foundation and Contributors. -# All Rights Reserved. -# -# This software is subject to the provisions of the Zope Public License, -# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. -# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED -# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS -# FOR A PARTICULAR PURPOSE. -# -############################################################################## -"""Buffers -""" -from io import BytesIO - -# copy_bytes controls the size of temp. strings for shuffling data around. -COPY_BYTES = 1 << 18 # 256K - -# The maximum number of bytes to buffer in a simple string. -STRBUF_LIMIT = 8192 - - -class FileBasedBuffer(object): - - remain = 0 - - def __init__(self, file, from_buffer=None): - self.file = file - if from_buffer is not None: - from_file = from_buffer.getfile() - read_pos = from_file.tell() - from_file.seek(0) - while True: - data = from_file.read(COPY_BYTES) - if not data: - break - file.write(data) - self.remain = int(file.tell() - read_pos) - from_file.seek(read_pos) - file.seek(read_pos) - - def __len__(self): - return self.remain - - def __nonzero__(self): - return True - - __bool__ = __nonzero__ # py3 - - def append(self, s): - file = self.file - read_pos = file.tell() - file.seek(0, 2) - file.write(s) - file.seek(read_pos) - self.remain = self.remain + len(s) - - def get(self, numbytes=-1, skip=False): - file = self.file - if not skip: - read_pos = file.tell() - if numbytes < 0: - # Read all - res = file.read() - else: - res = file.read(numbytes) - if skip: - self.remain -= len(res) - else: - file.seek(read_pos) - return res - - def skip(self, numbytes, allow_prune=0): - if self.remain < numbytes: - raise ValueError( - "Can't skip %d bytes in buffer of %d bytes" % (numbytes, self.remain) - ) - self.file.seek(numbytes, 1) - self.remain = self.remain - numbytes - - def newfile(self): - raise NotImplementedError() - - def prune(self): - file = self.file - if self.remain == 0: - read_pos = file.tell() - file.seek(0, 2) - sz = file.tell() - file.seek(read_pos) - if sz == 0: - # Nothing to prune. - return - nf = self.newfile() - while True: - data = file.read(COPY_BYTES) - if not data: - break - nf.write(data) - self.file = nf - - def getfile(self): - return self.file - - def close(self): - if hasattr(self.file, "close"): - self.file.close() - self.remain = 0 - - -class TempfileBasedBuffer(FileBasedBuffer): - def __init__(self, from_buffer=None): - FileBasedBuffer.__init__(self, self.newfile(), from_buffer) - - def newfile(self): - from tempfile import TemporaryFile - - return TemporaryFile("w+b") - - -class BytesIOBasedBuffer(FileBasedBuffer): - def __init__(self, from_buffer=None): - if from_buffer is not None: - FileBasedBuffer.__init__(self, BytesIO(), from_buffer) - else: - # Shortcut. :-) - self.file = BytesIO() - - def newfile(self): - return BytesIO() - - -def _is_seekable(fp): - if hasattr(fp, "seekable"): - return fp.seekable() - return hasattr(fp, "seek") and hasattr(fp, "tell") - - -class ReadOnlyFileBasedBuffer(FileBasedBuffer): - # used as wsgi.file_wrapper - - def __init__(self, file, block_size=32768): - self.file = file - self.block_size = block_size # for __iter__ - - def prepare(self, size=None): - if _is_seekable(self.file): - start_pos = self.file.tell() - self.file.seek(0, 2) - end_pos = self.file.tell() - self.file.seek(start_pos) - fsize = end_pos - start_pos - if size is None: - self.remain = fsize - else: - self.remain = min(fsize, size) - return self.remain - - def get(self, numbytes=-1, skip=False): - # never read more than self.remain (it can be user-specified) - if numbytes == -1 or numbytes > self.remain: - numbytes = self.remain - file = self.file - if not skip: - read_pos = file.tell() - res = file.read(numbytes) - if skip: - self.remain -= len(res) - else: - file.seek(read_pos) - return res - - def __iter__(self): # called by task if self.filelike has no seek/tell - return self - - def next(self): - val = self.file.read(self.block_size) - if not val: - raise StopIteration - return val - - __next__ = next # py3 - - def append(self, s): - raise NotImplementedError - - -class OverflowableBuffer(object): - """ - This buffer implementation has four stages: - - No data - - Bytes-based buffer - - BytesIO-based buffer - - Temporary file storage - The first two stages are fastest for simple transfers. - """ - - overflowed = False - buf = None - strbuf = b"" # Bytes-based buffer. - - def __init__(self, overflow): - # overflow is the maximum to be stored in a StringIO buffer. - self.overflow = overflow - - def __len__(self): - buf = self.buf - if buf is not None: - # use buf.__len__ rather than len(buf) FBO of not getting - # OverflowError on Python 2 - return buf.__len__() - else: - return self.strbuf.__len__() - - def __nonzero__(self): - # use self.__len__ rather than len(self) FBO of not getting - # OverflowError on Python 2 - return self.__len__() > 0 - - __bool__ = __nonzero__ # py3 - - def _create_buffer(self): - strbuf = self.strbuf - if len(strbuf) >= self.overflow: - self._set_large_buffer() - else: - self._set_small_buffer() - buf = self.buf - if strbuf: - buf.append(self.strbuf) - self.strbuf = b"" - return buf - - def _set_small_buffer(self): - self.buf = BytesIOBasedBuffer(self.buf) - self.overflowed = False - - def _set_large_buffer(self): - self.buf = TempfileBasedBuffer(self.buf) - self.overflowed = True - - def append(self, s): - buf = self.buf - if buf is None: - strbuf = self.strbuf - if len(strbuf) + len(s) < STRBUF_LIMIT: - self.strbuf = strbuf + s - return - buf = self._create_buffer() - buf.append(s) - # use buf.__len__ rather than len(buf) FBO of not getting - # OverflowError on Python 2 - sz = buf.__len__() - if not self.overflowed: - if sz >= self.overflow: - self._set_large_buffer() - - def get(self, numbytes=-1, skip=False): - buf = self.buf - if buf is None: - strbuf = self.strbuf - if not skip: - return strbuf - buf = self._create_buffer() - return buf.get(numbytes, skip) - - def skip(self, numbytes, allow_prune=False): - buf = self.buf - if buf is None: - if allow_prune and numbytes == len(self.strbuf): - # We could slice instead of converting to - # a buffer, but that would eat up memory in - # large transfers. - self.strbuf = b"" - return - buf = self._create_buffer() - buf.skip(numbytes, allow_prune) - - def prune(self): - """ - A potentially expensive operation that removes all data - already retrieved from the buffer. - """ - buf = self.buf - if buf is None: - self.strbuf = b"" - return - buf.prune() - if self.overflowed: - # use buf.__len__ rather than len(buf) FBO of not getting - # OverflowError on Python 2 - sz = buf.__len__() - if sz < self.overflow: - # Revert to a faster buffer. - self._set_small_buffer() - - def getfile(self): - buf = self.buf - if buf is None: - buf = self._create_buffer() - return buf.getfile() - - def close(self): - buf = self.buf - if buf is not None: - buf.close() diff --git a/waitress/channel.py b/waitress/channel.py deleted file mode 100644 index a8bc76f..0000000 --- a/waitress/channel.py +++ /dev/null @@ -1,414 +0,0 @@ -############################################################################## -# -# Copyright (c) 2001, 2002 Zope Foundation and Contributors. -# All Rights Reserved. -# -# This software is subject to the provisions of the Zope Public License, -# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. -# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED -# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS -# FOR A PARTICULAR PURPOSE. -# -############################################################################## -import socket -import threading -import time -import traceback - -from waitress.buffers import ( - OverflowableBuffer, - ReadOnlyFileBasedBuffer, -) - -from waitress.parser import HTTPRequestParser - -from waitress.task import ( - ErrorTask, - WSGITask, -) - -from waitress.utilities import InternalServerError - -from . import wasyncore - - -class ClientDisconnected(Exception): - """ Raised when attempting to write to a closed socket.""" - - -class HTTPChannel(wasyncore.dispatcher, object): - """ - Setting self.requests = [somerequest] prevents more requests from being - received until the out buffers have been flushed. - - Setting self.requests = [] allows more requests to be received. - """ - - task_class = WSGITask - error_task_class = ErrorTask - parser_class = HTTPRequestParser - - request = None # A request parser instance - last_activity = 0 # Time of last activity - will_close = False # set to True to close the socket. - close_when_flushed = False # set to True to close the socket when flushed - requests = () # currently pending requests - sent_continue = False # used as a latch after sending 100 continue - total_outbufs_len = 0 # total bytes ready to send - current_outbuf_count = 0 # total bytes written to current outbuf - - # - # ASYNCHRONOUS METHODS (including __init__) - # - - def __init__( - self, server, sock, addr, adj, map=None, - ): - self.server = server - self.adj = adj - self.outbufs = [OverflowableBuffer(adj.outbuf_overflow)] - self.creation_time = self.last_activity = time.time() - self.sendbuf_len = sock.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF) - - # task_lock used to push/pop requests - self.task_lock = threading.Lock() - # outbuf_lock used to access any outbuf (expected to use an RLock) - self.outbuf_lock = threading.Condition() - - wasyncore.dispatcher.__init__(self, sock, map=map) - - # Don't let wasyncore.dispatcher throttle self.addr on us. - self.addr = addr - - def writable(self): - # if there's data in the out buffer or we've been instructed to close - # the channel (possibly by our server maintenance logic), run - # handle_write - return self.total_outbufs_len or self.will_close or self.close_when_flushed - - def handle_write(self): - # Precondition: there's data in the out buffer to be sent, or - # there's a pending will_close request - if not self.connected: - # we dont want to close the channel twice - return - - # try to flush any pending output - if not self.requests: - # 1. There are no running tasks, so we don't need to try to lock - # the outbuf before sending - # 2. The data in the out buffer should be sent as soon as possible - # because it's either data left over from task output - # or a 100 Continue line sent within "received". - flush = self._flush_some - elif self.total_outbufs_len >= self.adj.send_bytes: - # 1. There's a running task, so we need to try to lock - # the outbuf before sending - # 2. Only try to send if the data in the out buffer is larger - # than self.adj_bytes to avoid TCP fragmentation - flush = self._flush_some_if_lockable - else: - # 1. There's not enough data in the out buffer to bother to send - # right now. - flush = None - - if flush: - try: - flush() - except socket.error: - if self.adj.log_socket_errors: - self.logger.exception("Socket error") - self.will_close = True - except Exception: - self.logger.exception("Unexpected exception when flushing") - self.will_close = True - - if self.close_when_flushed and not self.total_outbufs_len: - self.close_when_flushed = False - self.will_close = True - - if self.will_close: - self.handle_close() - - def readable(self): - # We might want to create a new task. We can only do this if: - # 1. We're not already about to close the connection. - # 2. There's no already currently running task(s). - # 3. There's no data in the output buffer that needs to be sent - # before we potentially create a new task. - return not (self.will_close or self.requests or self.total_outbufs_len) - - def handle_read(self): - try: - data = self.recv(self.adj.recv_bytes) - except socket.error: - if self.adj.log_socket_errors: - self.logger.exception("Socket error") - self.handle_close() - return - if data: - self.last_activity = time.time() - self.received(data) - - def received(self, data): - """ - Receives input asynchronously and assigns one or more requests to the - channel. - """ - # Preconditions: there's no task(s) already running - request = self.request - requests = [] - - if not data: - return False - - while data: - if request is None: - request = self.parser_class(self.adj) - n = request.received(data) - if request.expect_continue and request.headers_finished: - # guaranteed by parser to be a 1.1 request - request.expect_continue = False - if not self.sent_continue: - # there's no current task, so we don't need to try to - # lock the outbuf to append to it. - outbuf_payload = b"HTTP/1.1 100 Continue\r\n\r\n" - self.outbufs[-1].append(outbuf_payload) - self.current_outbuf_count += len(outbuf_payload) - self.total_outbufs_len += len(outbuf_payload) - self.sent_continue = True - self._flush_some() - request.completed = False - if request.completed: - # The request (with the body) is ready to use. - self.request = None - if not request.empty: - requests.append(request) - request = None - else: - self.request = request - if n >= len(data): - break - data = data[n:] - - if requests: - self.requests = requests - self.server.add_task(self) - - return True - - def _flush_some_if_lockable(self): - # Since our task may be appending to the outbuf, we try to acquire - # the lock, but we don't block if we can't. - if self.outbuf_lock.acquire(False): - try: - self._flush_some() - - if self.total_outbufs_len < self.adj.outbuf_high_watermark: - self.outbuf_lock.notify() - finally: - self.outbuf_lock.release() - - def _flush_some(self): - # Send as much data as possible to our client - - sent = 0 - dobreak = False - - while True: - outbuf = self.outbufs[0] - # use outbuf.__len__ rather than len(outbuf) FBO of not getting - # OverflowError on 32-bit Python - outbuflen = outbuf.__len__() - while outbuflen > 0: - chunk = outbuf.get(self.sendbuf_len) - num_sent = self.send(chunk) - if num_sent: - outbuf.skip(num_sent, True) - outbuflen -= num_sent - sent += num_sent - self.total_outbufs_len -= num_sent - else: - # failed to write anything, break out entirely - dobreak = True - break - else: - # self.outbufs[-1] must always be a writable outbuf - if len(self.outbufs) > 1: - toclose = self.outbufs.pop(0) - try: - toclose.close() - except Exception: - self.logger.exception("Unexpected error when closing an outbuf") - else: - # caught up, done flushing for now - dobreak = True - - if dobreak: - break - - if sent: - self.last_activity = time.time() - return True - - return False - - def handle_close(self): - with self.outbuf_lock: - for outbuf in self.outbufs: - try: - outbuf.close() - except Exception: - self.logger.exception( - "Unknown exception while trying to close outbuf" - ) - self.total_outbufs_len = 0 - self.connected = False - self.outbuf_lock.notify() - wasyncore.dispatcher.close(self) - - def add_channel(self, map=None): - """See wasyncore.dispatcher - - This hook keeps track of opened channels. - """ - wasyncore.dispatcher.add_channel(self, map) - self.server.active_channels[self._fileno] = self - - def del_channel(self, map=None): - """See wasyncore.dispatcher - - This hook keeps track of closed channels. - """ - fd = self._fileno # next line sets this to None - wasyncore.dispatcher.del_channel(self, map) - ac = self.server.active_channels - if fd in ac: - del ac[fd] - - # - # SYNCHRONOUS METHODS - # - - def write_soon(self, data): - if not self.connected: - # if the socket is closed then interrupt the task so that it - # can cleanup possibly before the app_iter is exhausted - raise ClientDisconnected - if data: - # the async mainloop might be popping data off outbuf; we can - # block here waiting for it because we're in a task thread - with self.outbuf_lock: - self._flush_outbufs_below_high_watermark() - if not self.connected: - raise ClientDisconnected - num_bytes = len(data) - if data.__class__ is ReadOnlyFileBasedBuffer: - # they used wsgi.file_wrapper - self.outbufs.append(data) - nextbuf = OverflowableBuffer(self.adj.outbuf_overflow) - self.outbufs.append(nextbuf) - self.current_outbuf_count = 0 - else: - if self.current_outbuf_count > self.adj.outbuf_high_watermark: - # rotate to a new buffer if the current buffer has hit - # the watermark to avoid it growing unbounded - nextbuf = OverflowableBuffer(self.adj.outbuf_overflow) - self.outbufs.append(nextbuf) - self.current_outbuf_count = 0 - self.outbufs[-1].append(data) - self.current_outbuf_count += num_bytes - self.total_outbufs_len += num_bytes - if self.total_outbufs_len >= self.adj.send_bytes: - self.server.pull_trigger() - return num_bytes - return 0 - - def _flush_outbufs_below_high_watermark(self): - # check first to avoid locking if possible - if self.total_outbufs_len > self.adj.outbuf_high_watermark: - with self.outbuf_lock: - while ( - self.connected - and self.total_outbufs_len > self.adj.outbuf_high_watermark - ): - self.server.pull_trigger() - self.outbuf_lock.wait() - - def service(self): - """Execute all pending requests """ - with self.task_lock: - while self.requests: - request = self.requests[0] - if request.error: - task = self.error_task_class(self, request) - else: - task = self.task_class(self, request) - try: - task.service() - except ClientDisconnected: - self.logger.info( - "Client disconnected while serving %s" % task.request.path - ) - task.close_on_finish = True - except Exception: - self.logger.exception( - "Exception while serving %s" % task.request.path - ) - if not task.wrote_header: - if self.adj.expose_tracebacks: - body = traceback.format_exc() - else: - body = ( - "The server encountered an unexpected " - "internal server error" - ) - req_version = request.version - req_headers = request.headers - request = self.parser_class(self.adj) - request.error = InternalServerError(body) - # copy some original request attributes to fulfill - # HTTP 1.1 requirements - request.version = req_version - try: - request.headers["CONNECTION"] = req_headers["CONNECTION"] - except KeyError: - pass - task = self.error_task_class(self, request) - try: - task.service() # must not fail - except ClientDisconnected: - task.close_on_finish = True - else: - task.close_on_finish = True - # we cannot allow self.requests to drop to empty til - # here; otherwise the mainloop gets confused - if task.close_on_finish: - self.close_when_flushed = True - for request in self.requests: - request.close() - self.requests = [] - else: - # before processing a new request, ensure there is not too - # much data in the outbufs waiting to be flushed - # NB: currently readable() returns False while we are - # flushing data so we know no new requests will come in - # that we need to account for, otherwise it'd be better - # to do this check at the start of the request instead of - # at the end to account for consecutive service() calls - if len(self.requests) > 1: - self._flush_outbufs_below_high_watermark() - request = self.requests.pop(0) - request.close() - - if self.connected: - self.server.pull_trigger() - self.last_activity = time.time() - - def cancel(self): - """ Cancels all pending / active requests """ - self.will_close = True - self.connected = False - self.last_activity = time.time() - self.requests = [] diff --git a/waitress/compat.py b/waitress/compat.py deleted file mode 100644 index fe72a76..0000000 --- a/waitress/compat.py +++ /dev/null @@ -1,179 +0,0 @@ -import os -import sys -import types -import platform -import warnings - -try: - import urlparse -except ImportError: # pragma: no cover - from urllib import parse as urlparse - -try: - import fcntl -except ImportError: # pragma: no cover - fcntl = None # windows - -# True if we are running on Python 3. -PY2 = sys.version_info[0] == 2 -PY3 = sys.version_info[0] == 3 - -# True if we are running on Windows -WIN = platform.system() == "Windows" - -if PY3: # pragma: no cover - string_types = (str,) - integer_types = (int,) - class_types = (type,) - text_type = str - binary_type = bytes - long = int -else: - string_types = (basestring,) - integer_types = (int, long) - class_types = (type, types.ClassType) - text_type = unicode - binary_type = str - long = long - -if PY3: # pragma: no cover - from urllib.parse import unquote_to_bytes - - def unquote_bytes_to_wsgi(bytestring): - return unquote_to_bytes(bytestring).decode("latin-1") - - -else: - from urlparse import unquote as unquote_to_bytes - - def unquote_bytes_to_wsgi(bytestring): - return unquote_to_bytes(bytestring) - - -def text_(s, encoding="latin-1", errors="strict"): - """ If ``s`` is an instance of ``binary_type``, return - ``s.decode(encoding, errors)``, otherwise return ``s``""" - if isinstance(s, binary_type): - return s.decode(encoding, errors) - return s # pragma: no cover - - -if PY3: # pragma: no cover - - def tostr(s): - if isinstance(s, text_type): - s = s.encode("latin-1") - return str(s, "latin-1", "strict") - - def tobytes(s): - return bytes(s, "latin-1") - - -else: - tostr = str - - def tobytes(s): - return s - - -if PY3: # pragma: no cover - import builtins - - exec_ = getattr(builtins, "exec") - - def reraise(tp, value, tb=None): - if value is None: - value = tp - if value.__traceback__ is not tb: - raise value.with_traceback(tb) - raise value - - del builtins - -else: # pragma: no cover - - def exec_(code, globs=None, locs=None): - """Execute code in a namespace.""" - if globs is None: - frame = sys._getframe(1) - globs = frame.f_globals - if locs is None: - locs = frame.f_locals - del frame - elif locs is None: - locs = globs - exec("""exec code in globs, locs""") - - exec_( - """def reraise(tp, value, tb=None): - raise tp, value, tb -""" - ) - -try: - from StringIO import StringIO as NativeIO -except ImportError: # pragma: no cover - from io import StringIO as NativeIO - -try: - import httplib -except ImportError: # pragma: no cover - from http import client as httplib - -try: - MAXINT = sys.maxint -except AttributeError: # pragma: no cover - MAXINT = sys.maxsize - - -# Fix for issue reported in https://github.com/Pylons/waitress/issues/138, -# Python on Windows may not define IPPROTO_IPV6 in socket. -import socket - -HAS_IPV6 = socket.has_ipv6 - -if hasattr(socket, "IPPROTO_IPV6") and hasattr(socket, "IPV6_V6ONLY"): - IPPROTO_IPV6 = socket.IPPROTO_IPV6 - IPV6_V6ONLY = socket.IPV6_V6ONLY -else: # pragma: no cover - if WIN: - IPPROTO_IPV6 = 41 - IPV6_V6ONLY = 27 - else: - warnings.warn( - "OS does not support required IPv6 socket flags. This is requirement " - "for Waitress. Please open an issue at https://github.com/Pylons/waitress. " - "IPv6 support has been disabled.", - RuntimeWarning, - ) - HAS_IPV6 = False - - -def set_nonblocking(fd): # pragma: no cover - if PY3 and sys.version_info[1] >= 5: - os.set_blocking(fd, False) - elif fcntl is None: - raise RuntimeError("no fcntl module present") - else: - flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0) - flags = flags | os.O_NONBLOCK - fcntl.fcntl(fd, fcntl.F_SETFL, flags) - - -if PY3: - ResourceWarning = ResourceWarning -else: - ResourceWarning = UserWarning - - -def qualname(cls): - if PY3: - return cls.__qualname__ - return cls.__name__ - - -try: - import thread -except ImportError: - # py3 - import _thread as thread diff --git a/waitress/parser.py b/waitress/parser.py deleted file mode 100644 index 53072b5..0000000 --- a/waitress/parser.py +++ /dev/null @@ -1,413 +0,0 @@ -############################################################################## -# -# Copyright (c) 2001, 2002 Zope Foundation and Contributors. -# All Rights Reserved. -# -# This software is subject to the provisions of the Zope Public License, -# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. -# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED -# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS -# FOR A PARTICULAR PURPOSE. -# -############################################################################## -"""HTTP Request Parser - -This server uses asyncore to accept connections and do initial -processing but threads to do work. -""" -import re -from io import BytesIO - -from waitress.buffers import OverflowableBuffer -from waitress.compat import tostr, unquote_bytes_to_wsgi, urlparse -from waitress.receiver import ChunkedReceiver, FixedStreamReceiver -from waitress.utilities import ( - BadRequest, - RequestEntityTooLarge, - RequestHeaderFieldsTooLarge, - ServerNotImplemented, - find_double_newline, -) -from .rfc7230 import HEADER_FIELD - - -class ParsingError(Exception): - pass - - -class TransferEncodingNotImplemented(Exception): - pass - -class HTTPRequestParser(object): - """A structure that collects the HTTP request. - - Once the stream is completed, the instance is passed to - a server task constructor. - """ - - completed = False # Set once request is completed. - empty = False # Set if no request was made. - expect_continue = False # client sent "Expect: 100-continue" header - headers_finished = False # True when headers have been read - header_plus = b"" - chunked = False - content_length = 0 - header_bytes_received = 0 - body_bytes_received = 0 - body_rcv = None - version = "1.0" - error = None - connection_close = False - - # Other attributes: first_line, header, headers, command, uri, version, - # path, query, fragment - - def __init__(self, adj): - """ - adj is an Adjustments object. - """ - # headers is a mapping containing keys translated to uppercase - # with dashes turned into underscores. - self.headers = {} - self.adj = adj - - def received(self, data): - """ - Receives the HTTP stream for one request. Returns the number of - bytes consumed. Sets the completed flag once both the header and the - body have been received. - """ - if self.completed: - return 0 # Can't consume any more. - - datalen = len(data) - br = self.body_rcv - if br is None: - # In header. - max_header = self.adj.max_request_header_size - - s = self.header_plus + data - index = find_double_newline(s) - consumed = 0 - - if index >= 0: - # If the headers have ended, and we also have part of the body - # message in data we still want to validate we aren't going - # over our limit for received headers. - self.header_bytes_received += index - consumed = datalen - (len(s) - index) - else: - self.header_bytes_received += datalen - consumed = datalen - - # If the first line + headers is over the max length, we return a - # RequestHeaderFieldsTooLarge error rather than continuing to - # attempt to parse the headers. - if self.header_bytes_received >= max_header: - self.parse_header(b"GET / HTTP/1.0\r\n") - self.error = RequestHeaderFieldsTooLarge( - "exceeds max_header of %s" % max_header - ) - self.completed = True - return consumed - - if index >= 0: - # Header finished. - header_plus = s[:index] - - # Remove preceeding blank lines. This is suggested by - # https://tools.ietf.org/html/rfc7230#section-3.5 to support - # clients sending an extra CR LF after another request when - # using HTTP pipelining - header_plus = header_plus.lstrip() - - if not header_plus: - self.empty = True - self.completed = True - else: - try: - self.parse_header(header_plus) - except ParsingError as e: - self.error = BadRequest(e.args[0]) - self.completed = True - except TransferEncodingNotImplemented as e: - self.error = ServerNotImplemented(e.args[0]) - self.completed = True - else: - if self.body_rcv is None: - # no content-length header and not a t-e: chunked - # request - self.completed = True - - if self.content_length > 0: - max_body = self.adj.max_request_body_size - # we won't accept this request if the content-length - # is too large - - if self.content_length >= max_body: - self.error = RequestEntityTooLarge( - "exceeds max_body of %s" % max_body - ) - self.completed = True - self.headers_finished = True - - return consumed - - # Header not finished yet. - self.header_plus = s - - return datalen - else: - # In body. - consumed = br.received(data) - self.body_bytes_received += consumed - max_body = self.adj.max_request_body_size - - if self.body_bytes_received >= max_body: - # this will only be raised during t-e: chunked requests - self.error = RequestEntityTooLarge("exceeds max_body of %s" % max_body) - self.completed = True - elif br.error: - # garbage in chunked encoding input probably - self.error = br.error - self.completed = True - elif br.completed: - # The request (with the body) is ready to use. - self.completed = True - - if self.chunked: - # We've converted the chunked transfer encoding request - # body into a normal request body, so we know its content - # length; set the header here. We already popped the - # TRANSFER_ENCODING header in parse_header, so this will - # appear to the client to be an entirely non-chunked HTTP - # request with a valid content-length. - self.headers["CONTENT_LENGTH"] = str(br.__len__()) - - return consumed - - def parse_header(self, header_plus): - """ - Parses the header_plus block of text (the headers plus the - first line of the request). - """ - index = header_plus.find(b"\r\n") - if index >= 0: - first_line = header_plus[:index].rstrip() - header = header_plus[index + 2 :] - else: - raise ParsingError("HTTP message header invalid") - - if b"\r" in first_line or b"\n" in first_line: - raise ParsingError("Bare CR or LF found in HTTP message") - - self.first_line = first_line # for testing - - lines = get_header_lines(header) - - headers = self.headers - for line in lines: - header = HEADER_FIELD.match(line) - - if not header: - raise ParsingError("Invalid header") - - key, value = header.group("name", "value") - - if b"_" in key: - # TODO(xistence): Should we drop this request instead? - continue - - # Only strip off whitespace that is considered valid whitespace by - # RFC7230, don't strip the rest - value = value.strip(b" \t") - key1 = tostr(key.upper().replace(b"-", b"_")) - # If a header already exists, we append subsequent values - # separated by a comma. Applications already need to handle - # the comma separated values, as HTTP front ends might do - # the concatenation for you (behavior specified in RFC2616). - try: - headers[key1] += tostr(b", " + value) - except KeyError: - headers[key1] = tostr(value) - - # command, uri, version will be bytes - command, uri, version = crack_first_line(first_line) - version = tostr(version) - command = tostr(command) - self.command = command - self.version = version - ( - self.proxy_scheme, - self.proxy_netloc, - self.path, - self.query, - self.fragment, - ) = split_uri(uri) - self.url_scheme = self.adj.url_scheme - connection = headers.get("CONNECTION", "") - - if version == "1.0": - if connection.lower() != "keep-alive": - self.connection_close = True - - if version == "1.1": - # since the server buffers data from chunked transfers and clients - # never need to deal with chunked requests, downstream clients - # should not see the HTTP_TRANSFER_ENCODING header; we pop it - # here - te = headers.pop("TRANSFER_ENCODING", "") - - # NB: We can not just call bare strip() here because it will also - # remove other non-printable characters that we explicitly do not - # want removed so that if someone attempts to smuggle a request - # with these characters we don't fall prey to it. - # - # For example \x85 is stripped by default, but it is not considered - # valid whitespace to be stripped by RFC7230. - encodings = [ - encoding.strip(" \t").lower() for encoding in te.split(",") if encoding - ] - - for encoding in encodings: - # Out of the transfer-codings listed in - # https://tools.ietf.org/html/rfc7230#section-4 we only support - # chunked at this time. - - # Note: the identity transfer-coding was removed in RFC7230: - # https://tools.ietf.org/html/rfc7230#appendix-A.2 and is thus - # not supported - if encoding not in {"chunked"}: - raise TransferEncodingNotImplemented( - "Transfer-Encoding requested is not supported." - ) - - if encodings and encodings[-1] == "chunked": - self.chunked = True - buf = OverflowableBuffer(self.adj.inbuf_overflow) - self.body_rcv = ChunkedReceiver(buf) - elif encodings: # pragma: nocover - raise TransferEncodingNotImplemented( - "Transfer-Encoding requested is not supported." - ) - - expect = headers.get("EXPECT", "").lower() - self.expect_continue = expect == "100-continue" - if connection.lower() == "close": - self.connection_close = True - - if not self.chunked: - try: - cl = int(headers.get("CONTENT_LENGTH", 0)) - except ValueError: - raise ParsingError("Content-Length is invalid") - - self.content_length = cl - if cl > 0: - buf = OverflowableBuffer(self.adj.inbuf_overflow) - self.body_rcv = FixedStreamReceiver(cl, buf) - - def get_body_stream(self): - body_rcv = self.body_rcv - if body_rcv is not None: - return body_rcv.getfile() - else: - return BytesIO() - - def close(self): - body_rcv = self.body_rcv - if body_rcv is not None: - body_rcv.getbuf().close() - - -def split_uri(uri): - # urlsplit handles byte input by returning bytes on py3, so - # scheme, netloc, path, query, and fragment are bytes - - scheme = netloc = path = query = fragment = b"" - - # urlsplit below will treat this as a scheme-less netloc, thereby losing - # the original intent of the request. Here we shamelessly stole 4 lines of - # code from the CPython stdlib to parse out the fragment and query but - # leave the path alone. See - # https://github.com/python/cpython/blob/8c9e9b0cd5b24dfbf1424d1f253d02de80e8f5ef/Lib/urllib/parse.py#L465-L468 - # and https://github.com/Pylons/waitress/issues/260 - - if uri[:2] == b"//": - path = uri - - if b"#" in path: - path, fragment = path.split(b"#", 1) - - if b"?" in path: - path, query = path.split(b"?", 1) - else: - try: - scheme, netloc, path, query, fragment = urlparse.urlsplit(uri) - except UnicodeError: - raise ParsingError("Bad URI") - - return ( - tostr(scheme), - tostr(netloc), - unquote_bytes_to_wsgi(path), - tostr(query), - tostr(fragment), - ) - - -def get_header_lines(header): - """ - Splits the header into lines, putting multi-line headers together. - """ - r = [] - lines = header.split(b"\r\n") - for line in lines: - if not line: - continue - - if b"\r" in line or b"\n" in line: - raise ParsingError('Bare CR or LF found in header line "%s"' % tostr(line)) - - if line.startswith((b" ", b"\t")): - if not r: - # https://corte.si/posts/code/pathod/pythonservers/index.html - raise ParsingError('Malformed header line "%s"' % tostr(line)) - r[-1] += line - else: - r.append(line) - return r - - -first_line_re = re.compile( - b"([^ ]+) " - b"((?:[^ :?#]+://[^ ?#/]*(?:[0-9]{1,5})?)?[^ ]+)" - b"(( HTTP/([0-9.]+))$|$)" -) - - -def crack_first_line(line): - m = first_line_re.match(line) - if m is not None and m.end() == len(line): - if m.group(3): - version = m.group(5) - else: - version = b"" - method = m.group(1) - - # the request methods that are currently defined are all uppercase: - # https://www.iana.org/assignments/http-methods/http-methods.xhtml and - # the request method is case sensitive according to - # https://tools.ietf.org/html/rfc7231#section-4.1 - - # By disallowing anything but uppercase methods we save poor - # unsuspecting souls from sending lowercase HTTP methods to waitress - # and having the request complete, while servers like nginx drop the - # request onto the floor. - if method != method.upper(): - raise ParsingError('Malformed HTTP method "%s"' % tostr(method)) - uri = m.group(2) - return method, uri, version - else: - return b"", b"", b"" diff --git a/waitress/proxy_headers.py b/waitress/proxy_headers.py deleted file mode 100644 index 1df8b8e..0000000 --- a/waitress/proxy_headers.py +++ /dev/null @@ -1,333 +0,0 @@ -from collections import namedtuple - -from .utilities import logger, undquote, BadRequest - - -PROXY_HEADERS = frozenset( - { - "X_FORWARDED_FOR", - "X_FORWARDED_HOST", - "X_FORWARDED_PROTO", - "X_FORWARDED_PORT", - "X_FORWARDED_BY", - "FORWARDED", - } -) - -Forwarded = namedtuple("Forwarded", ["by", "for_", "host", "proto"]) - - -class MalformedProxyHeader(Exception): - def __init__(self, header, reason, value): - self.header = header - self.reason = reason - self.value = value - super(MalformedProxyHeader, self).__init__(header, reason, value) - - -def proxy_headers_middleware( - app, - trusted_proxy=None, - trusted_proxy_count=1, - trusted_proxy_headers=None, - clear_untrusted=True, - log_untrusted=False, - logger=logger, -): - def translate_proxy_headers(environ, start_response): - untrusted_headers = PROXY_HEADERS - remote_peer = environ["REMOTE_ADDR"] - if trusted_proxy == "*" or remote_peer == trusted_proxy: - try: - untrusted_headers = parse_proxy_headers( - environ, - trusted_proxy_count=trusted_proxy_count, - trusted_proxy_headers=trusted_proxy_headers, - logger=logger, - ) - except MalformedProxyHeader as ex: - logger.warning( - 'Malformed proxy header "%s" from "%s": %s value: %s', - ex.header, - remote_peer, - ex.reason, - ex.value, - ) - error = BadRequest('Header "{0}" malformed.'.format(ex.header)) - return error.wsgi_response(environ, start_response) - - # Clear out the untrusted proxy headers - if clear_untrusted: - clear_untrusted_headers( - environ, untrusted_headers, log_warning=log_untrusted, logger=logger, - ) - - return app(environ, start_response) - - return translate_proxy_headers - - -def parse_proxy_headers( - environ, trusted_proxy_count, trusted_proxy_headers, logger=logger, -): - if trusted_proxy_headers is None: - trusted_proxy_headers = set() - - forwarded_for = [] - forwarded_host = forwarded_proto = forwarded_port = forwarded = "" - client_addr = None - untrusted_headers = set(PROXY_HEADERS) - - def raise_for_multiple_values(): - raise ValueError("Unspecified behavior for multiple values found in header",) - - if "x-forwarded-for" in trusted_proxy_headers and "HTTP_X_FORWARDED_FOR" in environ: - try: - forwarded_for = [] - - for forward_hop in environ["HTTP_X_FORWARDED_FOR"].split(","): - forward_hop = forward_hop.strip() - forward_hop = undquote(forward_hop) - - # Make sure that all IPv6 addresses are surrounded by brackets, - # this is assuming that the IPv6 representation here does not - # include a port number. - - if "." not in forward_hop and ( - ":" in forward_hop and forward_hop[-1] != "]" - ): - forwarded_for.append("[{}]".format(forward_hop)) - else: - forwarded_for.append(forward_hop) - - forwarded_for = forwarded_for[-trusted_proxy_count:] - client_addr = forwarded_for[0] - - untrusted_headers.remove("X_FORWARDED_FOR") - except Exception as ex: - raise MalformedProxyHeader( - "X-Forwarded-For", str(ex), environ["HTTP_X_FORWARDED_FOR"], - ) - - if ( - "x-forwarded-host" in trusted_proxy_headers - and "HTTP_X_FORWARDED_HOST" in environ - ): - try: - forwarded_host_multiple = [] - - for forward_host in environ["HTTP_X_FORWARDED_HOST"].split(","): - forward_host = forward_host.strip() - forward_host = undquote(forward_host) - forwarded_host_multiple.append(forward_host) - - forwarded_host_multiple = forwarded_host_multiple[-trusted_proxy_count:] - forwarded_host = forwarded_host_multiple[0] - - untrusted_headers.remove("X_FORWARDED_HOST") - except Exception as ex: - raise MalformedProxyHeader( - "X-Forwarded-Host", str(ex), environ["HTTP_X_FORWARDED_HOST"], - ) - - if "x-forwarded-proto" in trusted_proxy_headers: - try: - forwarded_proto = undquote(environ.get("HTTP_X_FORWARDED_PROTO", "")) - if "," in forwarded_proto: - raise_for_multiple_values() - untrusted_headers.remove("X_FORWARDED_PROTO") - except Exception as ex: - raise MalformedProxyHeader( - "X-Forwarded-Proto", str(ex), environ["HTTP_X_FORWARDED_PROTO"], - ) - - if "x-forwarded-port" in trusted_proxy_headers: - try: - forwarded_port = undquote(environ.get("HTTP_X_FORWARDED_PORT", "")) - if "," in forwarded_port: - raise_for_multiple_values() - untrusted_headers.remove("X_FORWARDED_PORT") - except Exception as ex: - raise MalformedProxyHeader( - "X-Forwarded-Port", str(ex), environ["HTTP_X_FORWARDED_PORT"], - ) - - if "x-forwarded-by" in trusted_proxy_headers: - # Waitress itself does not use X-Forwarded-By, but we can not - # remove it so it can get set in the environ - untrusted_headers.remove("X_FORWARDED_BY") - - if "forwarded" in trusted_proxy_headers: - forwarded = environ.get("HTTP_FORWARDED", None) - untrusted_headers = PROXY_HEADERS - {"FORWARDED"} - - # If the Forwarded header exists, it gets priority - if forwarded: - proxies = [] - try: - for forwarded_element in forwarded.split(","): - # Remove whitespace that may have been introduced when - # appending a new entry - forwarded_element = forwarded_element.strip() - - forwarded_for = forwarded_host = forwarded_proto = "" - forwarded_port = forwarded_by = "" - - for pair in forwarded_element.split(";"): - pair = pair.lower() - - if not pair: - continue - - token, equals, value = pair.partition("=") - - if equals != "=": - raise ValueError('Invalid forwarded-pair missing "="') - - if token.strip() != token: - raise ValueError("Token may not be surrounded by whitespace") - - if value.strip() != value: - raise ValueError("Value may not be surrounded by whitespace") - - if token == "by": - forwarded_by = undquote(value) - - elif token == "for": - forwarded_for = undquote(value) - - elif token == "host": - forwarded_host = undquote(value) - - elif token == "proto": - forwarded_proto = undquote(value) - - else: - logger.warning("Unknown Forwarded token: %s" % token) - - proxies.append( - Forwarded( - forwarded_by, forwarded_for, forwarded_host, forwarded_proto - ) - ) - except Exception as ex: - raise MalformedProxyHeader( - "Forwarded", str(ex), environ["HTTP_FORWARDED"], - ) - - proxies = proxies[-trusted_proxy_count:] - - # Iterate backwards and fill in some values, the oldest entry that - # contains the information we expect is the one we use. We expect - # that intermediate proxies may re-write the host header or proto, - # but the oldest entry is the one that contains the information the - # client expects when generating URL's - # - # Forwarded: for="[2001:db8::1]";host="example.com:8443";proto="https" - # Forwarded: for=192.0.2.1;host="example.internal:8080" - # - # (After HTTPS header folding) should mean that we use as values: - # - # Host: example.com - # Protocol: https - # Port: 8443 - - for proxy in proxies[::-1]: - client_addr = proxy.for_ or client_addr - forwarded_host = proxy.host or forwarded_host - forwarded_proto = proxy.proto or forwarded_proto - - if forwarded_proto: - forwarded_proto = forwarded_proto.lower() - - if forwarded_proto not in {"http", "https"}: - raise MalformedProxyHeader( - "Forwarded Proto=" if forwarded else "X-Forwarded-Proto", - "unsupported proto value", - forwarded_proto, - ) - - # Set the URL scheme to the proxy provided proto - environ["wsgi.url_scheme"] = forwarded_proto - - if not forwarded_port: - if forwarded_proto == "http": - forwarded_port = "80" - - if forwarded_proto == "https": - forwarded_port = "443" - - if forwarded_host: - if ":" in forwarded_host and forwarded_host[-1] != "]": - host, port = forwarded_host.rsplit(":", 1) - host, port = host.strip(), str(port) - - # We trust the port in the Forwarded Host/X-Forwarded-Host over - # X-Forwarded-Port, or whatever we got from Forwarded - # Proto/X-Forwarded-Proto. - - if forwarded_port != port: - forwarded_port = port - - # We trust the proxy server's forwarded Host - environ["SERVER_NAME"] = host - environ["HTTP_HOST"] = forwarded_host - else: - # We trust the proxy server's forwarded Host - environ["SERVER_NAME"] = forwarded_host - environ["HTTP_HOST"] = forwarded_host - - if forwarded_port: - if forwarded_port not in {"443", "80"}: - environ["HTTP_HOST"] = "{}:{}".format( - forwarded_host, forwarded_port - ) - elif forwarded_port == "80" and environ["wsgi.url_scheme"] != "http": - environ["HTTP_HOST"] = "{}:{}".format( - forwarded_host, forwarded_port - ) - elif forwarded_port == "443" and environ["wsgi.url_scheme"] != "https": - environ["HTTP_HOST"] = "{}:{}".format( - forwarded_host, forwarded_port - ) - - if forwarded_port: - environ["SERVER_PORT"] = str(forwarded_port) - - if client_addr: - if ":" in client_addr and client_addr[-1] != "]": - addr, port = client_addr.rsplit(":", 1) - environ["REMOTE_ADDR"] = strip_brackets(addr.strip()) - environ["REMOTE_PORT"] = port.strip() - else: - environ["REMOTE_ADDR"] = strip_brackets(client_addr.strip()) - environ["REMOTE_HOST"] = environ["REMOTE_ADDR"] - - return untrusted_headers - - -def strip_brackets(addr): - if addr[0] == "[" and addr[-1] == "]": - return addr[1:-1] - return addr - - -def clear_untrusted_headers( - environ, untrusted_headers, log_warning=False, logger=logger -): - untrusted_headers_removed = [ - header - for header in untrusted_headers - if environ.pop("HTTP_" + header, False) is not False - ] - - if log_warning and untrusted_headers_removed: - untrusted_headers_removed = [ - "-".join(x.capitalize() for x in header.split("_")) - for header in untrusted_headers_removed - ] - logger.warning( - "Removed untrusted headers (%s). Waitress recommends these be " - "removed upstream.", - ", ".join(untrusted_headers_removed), - ) diff --git a/waitress/receiver.py b/waitress/receiver.py deleted file mode 100644 index 5d1568d..0000000 --- a/waitress/receiver.py +++ /dev/null @@ -1,186 +0,0 @@ -############################################################################## -# -# Copyright (c) 2001, 2002 Zope Foundation and Contributors. -# All Rights Reserved. -# -# This software is subject to the provisions of the Zope Public License, -# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. -# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED -# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS -# FOR A PARTICULAR PURPOSE. -# -############################################################################## -"""Data Chunk Receiver -""" - -from waitress.utilities import BadRequest, find_double_newline - - -class FixedStreamReceiver(object): - - # See IStreamConsumer - completed = False - error = None - - def __init__(self, cl, buf): - self.remain = cl - self.buf = buf - - def __len__(self): - return self.buf.__len__() - - def received(self, data): - "See IStreamConsumer" - rm = self.remain - - if rm < 1: - self.completed = True # Avoid any chance of spinning - - return 0 - datalen = len(data) - - if rm <= datalen: - self.buf.append(data[:rm]) - self.remain = 0 - self.completed = True - - return rm - else: - self.buf.append(data) - self.remain -= datalen - - return datalen - - def getfile(self): - return self.buf.getfile() - - def getbuf(self): - return self.buf - - -class ChunkedReceiver(object): - - chunk_remainder = 0 - validate_chunk_end = False - control_line = b"" - chunk_end = b"" - all_chunks_received = False - trailer = b"" - completed = False - error = None - - # max_control_line = 1024 - # max_trailer = 65536 - - def __init__(self, buf): - self.buf = buf - - def __len__(self): - return self.buf.__len__() - - def received(self, s): - # Returns the number of bytes consumed. - - if self.completed: - return 0 - orig_size = len(s) - - while s: - rm = self.chunk_remainder - - if rm > 0: - # Receive the remainder of a chunk. - to_write = s[:rm] - self.buf.append(to_write) - written = len(to_write) - s = s[written:] - - self.chunk_remainder -= written - - if self.chunk_remainder == 0: - self.validate_chunk_end = True - elif self.validate_chunk_end: - s = self.chunk_end + s - - pos = s.find(b"\r\n") - - if pos < 0 and len(s) < 2: - self.chunk_end = s - s = b"" - else: - self.chunk_end = b"" - if pos == 0: - # Chop off the terminating CR LF from the chunk - s = s[2:] - else: - self.error = BadRequest("Chunk not properly terminated") - self.all_chunks_received = True - - # Always exit this loop - self.validate_chunk_end = False - elif not self.all_chunks_received: - # Receive a control line. - s = self.control_line + s - pos = s.find(b"\r\n") - - if pos < 0: - # Control line not finished. - self.control_line = s - s = b"" - else: - # Control line finished. - line = s[:pos] - s = s[pos + 2 :] - self.control_line = b"" - line = line.strip() - - if line: - # Begin a new chunk. - semi = line.find(b";") - - if semi >= 0: - # discard extension info. - line = line[:semi] - try: - sz = int(line.strip(), 16) # hexadecimal - except ValueError: # garbage in input - self.error = BadRequest("garbage in chunked encoding input") - sz = 0 - - if sz > 0: - # Start a new chunk. - self.chunk_remainder = sz - else: - # Finished chunks. - self.all_chunks_received = True - # else expect a control line. - else: - # Receive the trailer. - trailer = self.trailer + s - - if trailer.startswith(b"\r\n"): - # No trailer. - self.completed = True - - return orig_size - (len(trailer) - 2) - pos = find_double_newline(trailer) - - if pos < 0: - # Trailer not finished. - self.trailer = trailer - s = b"" - else: - # Finished the trailer. - self.completed = True - self.trailer = trailer[:pos] - - return orig_size - (len(trailer) - pos) - - return orig_size - - def getfile(self): - return self.buf.getfile() - - def getbuf(self): - return self.buf diff --git a/waitress/rfc7230.py b/waitress/rfc7230.py deleted file mode 100644 index cd33c90..0000000 --- a/waitress/rfc7230.py +++ /dev/null @@ -1,52 +0,0 @@ -""" -This contains a bunch of RFC7230 definitions and regular expressions that are -needed to properly parse HTTP messages. -""" - -import re - -from .compat import tobytes - -WS = "[ \t]" -OWS = WS + "{0,}?" -RWS = WS + "{1,}?" -BWS = OWS - -# RFC 7230 Section 3.2.6 "Field Value Components": -# tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" -# / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~" -# / DIGIT / ALPHA -# obs-text = %x80-FF -TCHAR = r"[!#$%&'*+\-.^_`|~0-9A-Za-z]" -OBS_TEXT = r"\x80-\xff" - -TOKEN = TCHAR + "{1,}" - -# RFC 5234 Appendix B.1 "Core Rules": -# VCHAR = %x21-7E -# ; visible (printing) characters -VCHAR = r"\x21-\x7e" - -# header-field = field-name ":" OWS field-value OWS -# field-name = token -# field-value = *( field-content / obs-fold ) -# field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ] -# field-vchar = VCHAR / obs-text - -# Errata from: https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189 -# changes field-content to: -# -# field-content = field-vchar [ 1*( SP / HTAB / field-vchar ) -# field-vchar ] - -FIELD_VCHAR = "[" + VCHAR + OBS_TEXT + "]" -# Field content is more greedy than the ABNF, in that it will match the whole value -FIELD_CONTENT = FIELD_VCHAR + "+(?:[ \t]+" + FIELD_VCHAR + "+)*" -# Which allows the field value here to just see if there is even a value in the first place -FIELD_VALUE = "(?:" + FIELD_CONTENT + ")?" - -HEADER_FIELD = re.compile( - tobytes( - "^(?P" + TOKEN + "):" + OWS + "(?P" + FIELD_VALUE + ")" + OWS + "$" - ) -) diff --git a/waitress/runner.py b/waitress/runner.py deleted file mode 100644 index 2495084..0000000 --- a/waitress/runner.py +++ /dev/null @@ -1,286 +0,0 @@ -############################################################################## -# -# Copyright (c) 2013 Zope Foundation and Contributors. -# All Rights Reserved. -# -# This software is subject to the provisions of the Zope Public License, -# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. -# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED -# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS -# FOR A PARTICULAR PURPOSE. -# -############################################################################## -"""Command line runner. -""" - -from __future__ import print_function, unicode_literals - -import getopt -import os -import os.path -import re -import sys - -from waitress import serve -from waitress.adjustments import Adjustments - -HELP = """\ -Usage: - - {0} [OPTS] MODULE:OBJECT - -Standard options: - - --help - Show this information. - - --call - Call the given object to get the WSGI application. - - --host=ADDR - Hostname or IP address on which to listen, default is '0.0.0.0', - which means "all IP addresses on this host". - - Note: May not be used together with --listen - - --port=PORT - TCP port on which to listen, default is '8080' - - Note: May not be used together with --listen - - --listen=ip:port - Tell waitress to listen on an ip port combination. - - Example: - - --listen=127.0.0.1:8080 - --listen=[::1]:8080 - --listen=*:8080 - - This option may be used multiple times to listen on multiple sockets. - A wildcard for the hostname is also supported and will bind to both - IPv4/IPv6 depending on whether they are enabled or disabled. - - --[no-]ipv4 - Toggle on/off IPv4 support. - - Example: - - --no-ipv4 - - This will disable IPv4 socket support. This affects wildcard matching - when generating the list of sockets. - - --[no-]ipv6 - Toggle on/off IPv6 support. - - Example: - - --no-ipv6 - - This will turn on IPv6 socket support. This affects wildcard matching - when generating a list of sockets. - - --unix-socket=PATH - Path of Unix socket. If a socket path is specified, a Unix domain - socket is made instead of the usual inet domain socket. - - Not available on Windows. - - --unix-socket-perms=PERMS - Octal permissions to use for the Unix domain socket, default is - '600'. - - --url-scheme=STR - Default wsgi.url_scheme value, default is 'http'. - - --url-prefix=STR - The ``SCRIPT_NAME`` WSGI environment value. Setting this to anything - except the empty string will cause the WSGI ``SCRIPT_NAME`` value to be - the value passed minus any trailing slashes you add, and it will cause - the ``PATH_INFO`` of any request which is prefixed with this value to - be stripped of the prefix. Default is the empty string. - - --ident=STR - Server identity used in the 'Server' header in responses. Default - is 'waitress'. - -Tuning options: - - --threads=INT - Number of threads used to process application logic, default is 4. - - --backlog=INT - Connection backlog for the server. Default is 1024. - - --recv-bytes=INT - Number of bytes to request when calling socket.recv(). Default is - 8192. - - --send-bytes=INT - Number of bytes to send to socket.send(). Default is 18000. - Multiples of 9000 should avoid partly-filled TCP packets. - - --outbuf-overflow=INT - A temporary file should be created if the pending output is larger - than this. Default is 1048576 (1MB). - - --outbuf-high-watermark=INT - The app_iter will pause when pending output is larger than this value - and will resume once enough data is written to the socket to fall below - this threshold. Default is 16777216 (16MB). - - --inbuf-overflow=INT - A temporary file should be created if the pending input is larger - than this. Default is 524288 (512KB). - - --connection-limit=INT - Stop creating new channels if too many are already active. - Default is 100. - - --cleanup-interval=INT - Minimum seconds between cleaning up inactive channels. Default - is 30. See '--channel-timeout'. - - --channel-timeout=INT - Maximum number of seconds to leave inactive connections open. - Default is 120. 'Inactive' is defined as 'has received no data - from the client and has sent no data to the client'. - - --[no-]log-socket-errors - Toggle whether premature client disconnect tracebacks ought to be - logged. On by default. - - --max-request-header-size=INT - Maximum size of all request headers combined. Default is 262144 - (256KB). - - --max-request-body-size=INT - Maximum size of request body. Default is 1073741824 (1GB). - - --[no-]expose-tracebacks - Toggle whether to expose tracebacks of unhandled exceptions to the - client. Off by default. - - --asyncore-loop-timeout=INT - The timeout value in seconds passed to asyncore.loop(). Default is 1. - - --asyncore-use-poll - The use_poll argument passed to ``asyncore.loop()``. Helps overcome - open file descriptors limit. Default is False. - -""" - -RUNNER_PATTERN = re.compile( - r""" - ^ - (?P - [a-z_][a-z0-9_]*(?:\.[a-z_][a-z0-9_]*)* - ) - : - (?P - [a-z_][a-z0-9_]*(?:\.[a-z_][a-z0-9_]*)* - ) - $ - """, - re.I | re.X, -) - - -def match(obj_name): - matches = RUNNER_PATTERN.match(obj_name) - if not matches: - raise ValueError("Malformed application '{0}'".format(obj_name)) - return matches.group("module"), matches.group("object") - - -def resolve(module_name, object_name): - """Resolve a named object in a module.""" - # We cast each segments due to an issue that has been found to manifest - # in Python 2.6.6, but not 2.6.8, and may affect other revisions of Python - # 2.6 and 2.7, whereby ``__import__`` chokes if the list passed in the - # ``fromlist`` argument are unicode strings rather than 8-bit strings. - # The error triggered is "TypeError: Item in ``fromlist '' not a string". - # My guess is that this was fixed by checking against ``basestring`` - # rather than ``str`` sometime between the release of 2.6.6 and 2.6.8, - # but I've yet to go over the commits. I know, however, that the NEWS - # file makes no mention of such a change to the behaviour of - # ``__import__``. - segments = [str(segment) for segment in object_name.split(".")] - obj = __import__(module_name, fromlist=segments[:1]) - for segment in segments: - obj = getattr(obj, segment) - return obj - - -def show_help(stream, name, error=None): # pragma: no cover - if error is not None: - print("Error: {0}\n".format(error), file=stream) - print(HELP.format(name), file=stream) - - -def show_exception(stream): - exc_type, exc_value = sys.exc_info()[:2] - args = getattr(exc_value, "args", None) - print( - ("There was an exception ({0}) importing your module.\n").format( - exc_type.__name__, - ), - file=stream, - ) - if args: - print("It had these arguments: ", file=stream) - for idx, arg in enumerate(args, start=1): - print("{0}. {1}\n".format(idx, arg), file=stream) - else: - print("It had no arguments.", file=stream) - - -def run(argv=sys.argv, _serve=serve): - """Command line runner.""" - name = os.path.basename(argv[0]) - - try: - kw, args = Adjustments.parse_args(argv[1:]) - except getopt.GetoptError as exc: - show_help(sys.stderr, name, str(exc)) - return 1 - - if kw["help"]: - show_help(sys.stdout, name) - return 0 - - if len(args) != 1: - show_help(sys.stderr, name, "Specify one application only") - return 1 - - try: - module, obj_name = match(args[0]) - except ValueError as exc: - show_help(sys.stderr, name, str(exc)) - show_exception(sys.stderr) - return 1 - - # Add the current directory onto sys.path - sys.path.append(os.getcwd()) - - # Get the WSGI function. - try: - app = resolve(module, obj_name) - except ImportError: - show_help(sys.stderr, name, "Bad module '{0}'".format(module)) - show_exception(sys.stderr) - return 1 - except AttributeError: - show_help(sys.stderr, name, "Bad object name '{0}'".format(obj_name)) - show_exception(sys.stderr) - return 1 - if kw["call"]: - app = app() - - # These arguments are specific to the runner, not waitress itself. - del kw["call"], kw["help"] - - _serve(app, **kw) - return 0 diff --git a/waitress/server.py b/waitress/server.py deleted file mode 100644 index ae56699..0000000 --- a/waitress/server.py +++ /dev/null @@ -1,436 +0,0 @@ -############################################################################## -# -# Copyright (c) 2001, 2002 Zope Foundation and Contributors. -# All Rights Reserved. -# -# This software is subject to the provisions of the Zope Public License, -# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. -# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED -# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS -# FOR A PARTICULAR PURPOSE. -# -############################################################################## - -import os -import os.path -import socket -import time - -from waitress import trigger -from waitress.adjustments import Adjustments -from waitress.channel import HTTPChannel -from waitress.task import ThreadedTaskDispatcher -from waitress.utilities import cleanup_unix_socket - -from waitress.compat import ( - IPPROTO_IPV6, - IPV6_V6ONLY, -) -from . import wasyncore -from .proxy_headers import proxy_headers_middleware - - -def create_server( - application, - map=None, - _start=True, # test shim - _sock=None, # test shim - _dispatcher=None, # test shim - **kw # adjustments -): - """ - if __name__ == '__main__': - server = create_server(app) - server.run() - """ - if application is None: - raise ValueError( - 'The "app" passed to ``create_server`` was ``None``. You forgot ' - "to return a WSGI app within your application." - ) - adj = Adjustments(**kw) - - if map is None: # pragma: nocover - map = {} - - dispatcher = _dispatcher - if dispatcher is None: - dispatcher = ThreadedTaskDispatcher() - dispatcher.set_thread_count(adj.threads) - - if adj.unix_socket and hasattr(socket, "AF_UNIX"): - sockinfo = (socket.AF_UNIX, socket.SOCK_STREAM, None, None) - return UnixWSGIServer( - application, - map, - _start, - _sock, - dispatcher=dispatcher, - adj=adj, - sockinfo=sockinfo, - ) - - effective_listen = [] - last_serv = None - if not adj.sockets: - for sockinfo in adj.listen: - # When TcpWSGIServer is called, it registers itself in the map. This - # side-effect is all we need it for, so we don't store a reference to - # or return it to the user. - last_serv = TcpWSGIServer( - application, - map, - _start, - _sock, - dispatcher=dispatcher, - adj=adj, - sockinfo=sockinfo, - ) - effective_listen.append( - (last_serv.effective_host, last_serv.effective_port) - ) - - for sock in adj.sockets: - sockinfo = (sock.family, sock.type, sock.proto, sock.getsockname()) - if sock.family == socket.AF_INET or sock.family == socket.AF_INET6: - last_serv = TcpWSGIServer( - application, - map, - _start, - sock, - dispatcher=dispatcher, - adj=adj, - bind_socket=False, - sockinfo=sockinfo, - ) - effective_listen.append( - (last_serv.effective_host, last_serv.effective_port) - ) - elif hasattr(socket, "AF_UNIX") and sock.family == socket.AF_UNIX: - last_serv = UnixWSGIServer( - application, - map, - _start, - sock, - dispatcher=dispatcher, - adj=adj, - bind_socket=False, - sockinfo=sockinfo, - ) - effective_listen.append( - (last_serv.effective_host, last_serv.effective_port) - ) - - # We are running a single server, so we can just return the last server, - # saves us from having to create one more object - if len(effective_listen) == 1: - # In this case we have no need to use a MultiSocketServer - return last_serv - - # Return a class that has a utility function to print out the sockets it's - # listening on, and has a .run() function. All of the TcpWSGIServers - # registered themselves in the map above. - return MultiSocketServer(map, adj, effective_listen, dispatcher) - - -# This class is only ever used if we have multiple listen sockets. It allows -# the serve() API to call .run() which starts the wasyncore loop, and catches -# SystemExit/KeyboardInterrupt so that it can atempt to cleanly shut down. -class MultiSocketServer(object): - asyncore = wasyncore # test shim - - def __init__( - self, map=None, adj=None, effective_listen=None, dispatcher=None, - ): - self.adj = adj - self.map = map - self.effective_listen = effective_listen - self.task_dispatcher = dispatcher - - def print_listen(self, format_str): # pragma: nocover - for l in self.effective_listen: - l = list(l) - - if ":" in l[0]: - l[0] = "[{}]".format(l[0]) - - print(format_str.format(*l)) - - def run(self): - try: - self.asyncore.loop( - timeout=self.adj.asyncore_loop_timeout, - map=self.map, - use_poll=self.adj.asyncore_use_poll, - ) - except (SystemExit, KeyboardInterrupt): - self.close() - - def close(self): - self.task_dispatcher.shutdown() - wasyncore.close_all(self.map) - - -class BaseWSGIServer(wasyncore.dispatcher, object): - - channel_class = HTTPChannel - next_channel_cleanup = 0 - socketmod = socket # test shim - asyncore = wasyncore # test shim - - def __init__( - self, - application, - map=None, - _start=True, # test shim - _sock=None, # test shim - dispatcher=None, # dispatcher - adj=None, # adjustments - sockinfo=None, # opaque object - bind_socket=True, - **kw - ): - if adj is None: - adj = Adjustments(**kw) - - if adj.trusted_proxy or adj.clear_untrusted_proxy_headers: - # wrap the application to deal with proxy headers - # we wrap it here because webtest subclasses the TcpWSGIServer - # directly and thus doesn't run any code that's in create_server - application = proxy_headers_middleware( - application, - trusted_proxy=adj.trusted_proxy, - trusted_proxy_count=adj.trusted_proxy_count, - trusted_proxy_headers=adj.trusted_proxy_headers, - clear_untrusted=adj.clear_untrusted_proxy_headers, - log_untrusted=adj.log_untrusted_proxy_headers, - logger=self.logger, - ) - - if map is None: - # use a nonglobal socket map by default to hopefully prevent - # conflicts with apps and libs that use the wasyncore global socket - # map ala https://github.com/Pylons/waitress/issues/63 - map = {} - if sockinfo is None: - sockinfo = adj.listen[0] - - self.sockinfo = sockinfo - self.family = sockinfo[0] - self.socktype = sockinfo[1] - self.application = application - self.adj = adj - self.trigger = trigger.trigger(map) - if dispatcher is None: - dispatcher = ThreadedTaskDispatcher() - dispatcher.set_thread_count(self.adj.threads) - - self.task_dispatcher = dispatcher - self.asyncore.dispatcher.__init__(self, _sock, map=map) - if _sock is None: - self.create_socket(self.family, self.socktype) - if self.family == socket.AF_INET6: # pragma: nocover - self.socket.setsockopt(IPPROTO_IPV6, IPV6_V6ONLY, 1) - - self.set_reuse_addr() - - if bind_socket: - self.bind_server_socket() - - self.effective_host, self.effective_port = self.getsockname() - self.server_name = self.get_server_name(self.effective_host) - self.active_channels = {} - if _start: - self.accept_connections() - - def bind_server_socket(self): - raise NotImplementedError # pragma: no cover - - def get_server_name(self, ip): - """Given an IP or hostname, try to determine the server name.""" - - if not ip: - raise ValueError("Requires an IP to get the server name") - - server_name = str(ip) - - # If we are bound to all IP's, just return the current hostname, only - # fall-back to "localhost" if we fail to get the hostname - if server_name == "0.0.0.0" or server_name == "::": - try: - return str(self.socketmod.gethostname()) - except (socket.error, UnicodeDecodeError): # pragma: no cover - # We also deal with UnicodeDecodeError in case of Windows with - # non-ascii hostname - return "localhost" - - # Now let's try and convert the IP address to a proper hostname - try: - server_name = self.socketmod.gethostbyaddr(server_name)[0] - except (socket.error, UnicodeDecodeError): # pragma: no cover - # We also deal with UnicodeDecodeError in case of Windows with - # non-ascii hostname - pass - - # If it contains an IPv6 literal, make sure to surround it with - # brackets - if ":" in server_name and "[" not in server_name: - server_name = "[{}]".format(server_name) - - return server_name - - def getsockname(self): - raise NotImplementedError # pragma: no cover - - def accept_connections(self): - self.accepting = True - self.socket.listen(self.adj.backlog) # Get around asyncore NT limit - - def add_task(self, task): - self.task_dispatcher.add_task(task) - - def readable(self): - now = time.time() - if now >= self.next_channel_cleanup: - self.next_channel_cleanup = now + self.adj.cleanup_interval - self.maintenance(now) - return self.accepting and len(self._map) < self.adj.connection_limit - - def writable(self): - return False - - def handle_read(self): - pass - - def handle_connect(self): - pass - - def handle_accept(self): - try: - v = self.accept() - if v is None: - return - conn, addr = v - except socket.error: - # Linux: On rare occasions we get a bogus socket back from - # accept. socketmodule.c:makesockaddr complains that the - # address family is unknown. We don't want the whole server - # to shut down because of this. - if self.adj.log_socket_errors: - self.logger.warning("server accept() threw an exception", exc_info=True) - return - self.set_socket_options(conn) - addr = self.fix_addr(addr) - self.channel_class(self, conn, addr, self.adj, map=self._map) - - def run(self): - try: - self.asyncore.loop( - timeout=self.adj.asyncore_loop_timeout, - map=self._map, - use_poll=self.adj.asyncore_use_poll, - ) - except (SystemExit, KeyboardInterrupt): - self.task_dispatcher.shutdown() - - def pull_trigger(self): - self.trigger.pull_trigger() - - def set_socket_options(self, conn): - pass - - def fix_addr(self, addr): - return addr - - def maintenance(self, now): - """ - Closes channels that have not had any activity in a while. - - The timeout is configured through adj.channel_timeout (seconds). - """ - cutoff = now - self.adj.channel_timeout - for channel in self.active_channels.values(): - if (not channel.requests) and channel.last_activity < cutoff: - channel.will_close = True - - def print_listen(self, format_str): # pragma: nocover - print(format_str.format(self.effective_host, self.effective_port)) - - def close(self): - self.trigger.close() - return wasyncore.dispatcher.close(self) - - -class TcpWSGIServer(BaseWSGIServer): - def bind_server_socket(self): - (_, _, _, sockaddr) = self.sockinfo - self.bind(sockaddr) - - def getsockname(self): - try: - return self.socketmod.getnameinfo( - self.socket.getsockname(), self.socketmod.NI_NUMERICSERV - ) - except: # pragma: no cover - # This only happens on Linux because a DNS issue is considered a - # temporary failure that will raise (even when NI_NAMEREQD is not - # set). Instead we try again, but this time we just ask for the - # numerichost and the numericserv (port) and return those. It is - # better than nothing. - return self.socketmod.getnameinfo( - self.socket.getsockname(), - self.socketmod.NI_NUMERICHOST | self.socketmod.NI_NUMERICSERV, - ) - - def set_socket_options(self, conn): - for (level, optname, value) in self.adj.socket_options: - conn.setsockopt(level, optname, value) - - -if hasattr(socket, "AF_UNIX"): - - class UnixWSGIServer(BaseWSGIServer): - def __init__( - self, - application, - map=None, - _start=True, # test shim - _sock=None, # test shim - dispatcher=None, # dispatcher - adj=None, # adjustments - sockinfo=None, # opaque object - **kw - ): - if sockinfo is None: - sockinfo = (socket.AF_UNIX, socket.SOCK_STREAM, None, None) - - super(UnixWSGIServer, self).__init__( - application, - map=map, - _start=_start, - _sock=_sock, - dispatcher=dispatcher, - adj=adj, - sockinfo=sockinfo, - **kw - ) - - def bind_server_socket(self): - cleanup_unix_socket(self.adj.unix_socket) - self.bind(self.adj.unix_socket) - if os.path.exists(self.adj.unix_socket): - os.chmod(self.adj.unix_socket, self.adj.unix_socket_perms) - - def getsockname(self): - return ("unix", self.socket.getsockname()) - - def fix_addr(self, addr): - return ("localhost", None) - - def get_server_name(self, ip): - return "localhost" - - -# Compatibility alias. -WSGIServer = TcpWSGIServer diff --git a/waitress/task.py b/waitress/task.py deleted file mode 100644 index 8e7ab18..0000000 --- a/waitress/task.py +++ /dev/null @@ -1,570 +0,0 @@ -############################################################################## -# -# Copyright (c) 2001, 2002 Zope Foundation and Contributors. -# All Rights Reserved. -# -# This software is subject to the provisions of the Zope Public License, -# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. -# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED -# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS -# FOR A PARTICULAR PURPOSE. -# -############################################################################## - -import socket -import sys -import threading -import time -from collections import deque - -from .buffers import ReadOnlyFileBasedBuffer -from .compat import reraise, tobytes -from .utilities import build_http_date, logger, queue_logger - -rename_headers = { # or keep them without the HTTP_ prefix added - "CONTENT_LENGTH": "CONTENT_LENGTH", - "CONTENT_TYPE": "CONTENT_TYPE", -} - -hop_by_hop = frozenset( - ( - "connection", - "keep-alive", - "proxy-authenticate", - "proxy-authorization", - "te", - "trailers", - "transfer-encoding", - "upgrade", - ) -) - - -class ThreadedTaskDispatcher(object): - """A Task Dispatcher that creates a thread for each task. - """ - - stop_count = 0 # Number of threads that will stop soon. - active_count = 0 # Number of currently active threads - logger = logger - queue_logger = queue_logger - - def __init__(self): - self.threads = set() - self.queue = deque() - self.lock = threading.Lock() - self.queue_cv = threading.Condition(self.lock) - self.thread_exit_cv = threading.Condition(self.lock) - - def start_new_thread(self, target, args): - t = threading.Thread(target=target, name="waitress", args=args) - t.daemon = True - t.start() - - def handler_thread(self, thread_no): - while True: - with self.lock: - while not self.queue and self.stop_count == 0: - # Mark ourselves as idle before waiting to be - # woken up, then we will once again be active - self.active_count -= 1 - self.queue_cv.wait() - self.active_count += 1 - - if self.stop_count > 0: - self.active_count -= 1 - self.stop_count -= 1 - self.threads.discard(thread_no) - self.thread_exit_cv.notify() - break - - task = self.queue.popleft() - try: - task.service() - except BaseException: - self.logger.exception("Exception when servicing %r", task) - - def set_thread_count(self, count): - with self.lock: - threads = self.threads - thread_no = 0 - running = len(threads) - self.stop_count - while running < count: - # Start threads. - while thread_no in threads: - thread_no = thread_no + 1 - threads.add(thread_no) - running += 1 - self.start_new_thread(self.handler_thread, (thread_no,)) - self.active_count += 1 - thread_no = thread_no + 1 - if running > count: - # Stop threads. - self.stop_count += running - count - self.queue_cv.notify_all() - - def add_task(self, task): - with self.lock: - self.queue.append(task) - self.queue_cv.notify() - queue_size = len(self.queue) - idle_threads = len(self.threads) - self.stop_count - self.active_count - if queue_size > idle_threads: - self.queue_logger.warning( - "Task queue depth is %d", queue_size - idle_threads - ) - - def shutdown(self, cancel_pending=True, timeout=5): - self.set_thread_count(0) - # Ensure the threads shut down. - threads = self.threads - expiration = time.time() + timeout - with self.lock: - while threads: - if time.time() >= expiration: - self.logger.warning("%d thread(s) still running", len(threads)) - break - self.thread_exit_cv.wait(0.1) - if cancel_pending: - # Cancel remaining tasks. - queue = self.queue - if len(queue) > 0: - self.logger.warning("Canceling %d pending task(s)", len(queue)) - while queue: - task = queue.popleft() - task.cancel() - self.queue_cv.notify_all() - return True - return False - - -class Task(object): - close_on_finish = False - status = "200 OK" - wrote_header = False - start_time = 0 - content_length = None - content_bytes_written = 0 - logged_write_excess = False - logged_write_no_body = False - complete = False - chunked_response = False - logger = logger - - def __init__(self, channel, request): - self.channel = channel - self.request = request - self.response_headers = [] - version = request.version - if version not in ("1.0", "1.1"): - # fall back to a version we support. - version = "1.0" - self.version = version - - def service(self): - try: - try: - self.start() - self.execute() - self.finish() - except socket.error: - self.close_on_finish = True - if self.channel.adj.log_socket_errors: - raise - finally: - pass - - @property - def has_body(self): - return not ( - self.status.startswith("1") - or self.status.startswith("204") - or self.status.startswith("304") - ) - - def build_response_header(self): - version = self.version - # Figure out whether the connection should be closed. - connection = self.request.headers.get("CONNECTION", "").lower() - response_headers = [] - content_length_header = None - date_header = None - server_header = None - connection_close_header = None - - for (headername, headerval) in self.response_headers: - headername = "-".join([x.capitalize() for x in headername.split("-")]) - - if headername == "Content-Length": - if self.has_body: - content_length_header = headerval - else: - continue # pragma: no cover - - if headername == "Date": - date_header = headerval - - if headername == "Server": - server_header = headerval - - if headername == "Connection": - connection_close_header = headerval.lower() - # replace with properly capitalized version - response_headers.append((headername, headerval)) - - if ( - content_length_header is None - and self.content_length is not None - and self.has_body - ): - content_length_header = str(self.content_length) - response_headers.append(("Content-Length", content_length_header)) - - def close_on_finish(): - if connection_close_header is None: - response_headers.append(("Connection", "close")) - self.close_on_finish = True - - if version == "1.0": - if connection == "keep-alive": - if not content_length_header: - close_on_finish() - else: - response_headers.append(("Connection", "Keep-Alive")) - else: - close_on_finish() - - elif version == "1.1": - if connection == "close": - close_on_finish() - - if not content_length_header: - # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length - # for any response with a status code of 1xx, 204 or 304. - - if self.has_body: - response_headers.append(("Transfer-Encoding", "chunked")) - self.chunked_response = True - - if not self.close_on_finish: - close_on_finish() - - # under HTTP 1.1 keep-alive is default, no need to set the header - else: - raise AssertionError("neither HTTP/1.0 or HTTP/1.1") - - # Set the Server and Date field, if not yet specified. This is needed - # if the server is used as a proxy. - ident = self.channel.server.adj.ident - - if not server_header: - if ident: - response_headers.append(("Server", ident)) - else: - response_headers.append(("Via", ident or "waitress")) - - if not date_header: - response_headers.append(("Date", build_http_date(self.start_time))) - - self.response_headers = response_headers - - first_line = "HTTP/%s %s" % (self.version, self.status) - # NB: sorting headers needs to preserve same-named-header order - # as per RFC 2616 section 4.2; thus the key=lambda x: x[0] here; - # rely on stable sort to keep relative position of same-named headers - next_lines = [ - "%s: %s" % hv for hv in sorted(self.response_headers, key=lambda x: x[0]) - ] - lines = [first_line] + next_lines - res = "%s\r\n\r\n" % "\r\n".join(lines) - - return tobytes(res) - - def remove_content_length_header(self): - response_headers = [] - - for header_name, header_value in self.response_headers: - if header_name.lower() == "content-length": - continue # pragma: nocover - response_headers.append((header_name, header_value)) - - self.response_headers = response_headers - - def start(self): - self.start_time = time.time() - - def finish(self): - if not self.wrote_header: - self.write(b"") - if self.chunked_response: - # not self.write, it will chunk it! - self.channel.write_soon(b"0\r\n\r\n") - - def write(self, data): - if not self.complete: - raise RuntimeError("start_response was not called before body written") - channel = self.channel - if not self.wrote_header: - rh = self.build_response_header() - channel.write_soon(rh) - self.wrote_header = True - - if data and self.has_body: - towrite = data - cl = self.content_length - if self.chunked_response: - # use chunked encoding response - towrite = tobytes(hex(len(data))[2:].upper()) + b"\r\n" - towrite += data + b"\r\n" - elif cl is not None: - towrite = data[: cl - self.content_bytes_written] - self.content_bytes_written += len(towrite) - if towrite != data and not self.logged_write_excess: - self.logger.warning( - "application-written content exceeded the number of " - "bytes specified by Content-Length header (%s)" % cl - ) - self.logged_write_excess = True - if towrite: - channel.write_soon(towrite) - elif data: - # Cheat, and tell the application we have written all of the bytes, - # even though the response shouldn't have a body and we are - # ignoring it entirely. - self.content_bytes_written += len(data) - - if not self.logged_write_no_body: - self.logger.warning( - "application-written content was ignored due to HTTP " - "response that may not contain a message-body: (%s)" % self.status - ) - self.logged_write_no_body = True - - -class ErrorTask(Task): - """ An error task produces an error response - """ - - complete = True - - def execute(self): - e = self.request.error - status, headers, body = e.to_response() - self.status = status - self.response_headers.extend(headers) - # We need to explicitly tell the remote client we are closing the - # connection, because self.close_on_finish is set, and we are going to - # slam the door in the clients face. - self.response_headers.append(("Connection", "close")) - self.close_on_finish = True - self.content_length = len(body) - self.write(tobytes(body)) - - -class WSGITask(Task): - """A WSGI task produces a response from a WSGI application. - """ - - environ = None - - def execute(self): - environ = self.get_environment() - - def start_response(status, headers, exc_info=None): - if self.complete and not exc_info: - raise AssertionError( - "start_response called a second time without providing exc_info." - ) - if exc_info: - try: - if self.wrote_header: - # higher levels will catch and handle raised exception: - # 1. "service" method in task.py - # 2. "service" method in channel.py - # 3. "handler_thread" method in task.py - reraise(exc_info[0], exc_info[1], exc_info[2]) - else: - # As per WSGI spec existing headers must be cleared - self.response_headers = [] - finally: - exc_info = None - - self.complete = True - - if not status.__class__ is str: - raise AssertionError("status %s is not a string" % status) - if "\n" in status or "\r" in status: - raise ValueError( - "carriage return/line feed character present in status" - ) - - self.status = status - - # Prepare the headers for output - for k, v in headers: - if not k.__class__ is str: - raise AssertionError( - "Header name %r is not a string in %r" % (k, (k, v)) - ) - if not v.__class__ is str: - raise AssertionError( - "Header value %r is not a string in %r" % (v, (k, v)) - ) - - if "\n" in v or "\r" in v: - raise ValueError( - "carriage return/line feed character present in header value" - ) - if "\n" in k or "\r" in k: - raise ValueError( - "carriage return/line feed character present in header name" - ) - - kl = k.lower() - if kl == "content-length": - self.content_length = int(v) - elif kl in hop_by_hop: - raise AssertionError( - '%s is a "hop-by-hop" header; it cannot be used by ' - "a WSGI application (see PEP 3333)" % k - ) - - self.response_headers.extend(headers) - - # Return a method used to write the response data. - return self.write - - # Call the application to handle the request and write a response - app_iter = self.channel.server.application(environ, start_response) - - can_close_app_iter = True - try: - if app_iter.__class__ is ReadOnlyFileBasedBuffer: - cl = self.content_length - size = app_iter.prepare(cl) - if size: - if cl != size: - if cl is not None: - self.remove_content_length_header() - self.content_length = size - self.write(b"") # generate headers - # if the write_soon below succeeds then the channel will - # take over closing the underlying file via the channel's - # _flush_some or handle_close so we intentionally avoid - # calling close in the finally block - self.channel.write_soon(app_iter) - can_close_app_iter = False - return - - first_chunk_len = None - for chunk in app_iter: - if first_chunk_len is None: - first_chunk_len = len(chunk) - # Set a Content-Length header if one is not supplied. - # start_response may not have been called until first - # iteration as per PEP, so we must reinterrogate - # self.content_length here - if self.content_length is None: - app_iter_len = None - if hasattr(app_iter, "__len__"): - app_iter_len = len(app_iter) - if app_iter_len == 1: - self.content_length = first_chunk_len - # transmit headers only after first iteration of the iterable - # that returns a non-empty bytestring (PEP 3333) - if chunk: - self.write(chunk) - - cl = self.content_length - if cl is not None: - if self.content_bytes_written != cl: - # close the connection so the client isn't sitting around - # waiting for more data when there are too few bytes - # to service content-length - self.close_on_finish = True - if self.request.command != "HEAD": - self.logger.warning( - "application returned too few bytes (%s) " - "for specified Content-Length (%s) via app_iter" - % (self.content_bytes_written, cl), - ) - finally: - if can_close_app_iter and hasattr(app_iter, "close"): - app_iter.close() - - def get_environment(self): - """Returns a WSGI environment.""" - environ = self.environ - if environ is not None: - # Return the cached copy. - return environ - - request = self.request - path = request.path - channel = self.channel - server = channel.server - url_prefix = server.adj.url_prefix - - if path.startswith("/"): - # strip extra slashes at the beginning of a path that starts - # with any number of slashes - path = "/" + path.lstrip("/") - - if url_prefix: - # NB: url_prefix is guaranteed by the configuration machinery to - # be either the empty string or a string that starts with a single - # slash and ends without any slashes - if path == url_prefix: - # if the path is the same as the url prefix, the SCRIPT_NAME - # should be the url_prefix and PATH_INFO should be empty - path = "" - else: - # if the path starts with the url prefix plus a slash, - # the SCRIPT_NAME should be the url_prefix and PATH_INFO should - # the value of path from the slash until its end - url_prefix_with_trailing_slash = url_prefix + "/" - if path.startswith(url_prefix_with_trailing_slash): - path = path[len(url_prefix) :] - - environ = { - "REMOTE_ADDR": channel.addr[0], - # Nah, we aren't actually going to look up the reverse DNS for - # REMOTE_ADDR, but we will happily set this environment variable - # for the WSGI application. Spec says we can just set this to - # REMOTE_ADDR, so we do. - "REMOTE_HOST": channel.addr[0], - # try and set the REMOTE_PORT to something useful, but maybe None - "REMOTE_PORT": str(channel.addr[1]), - "REQUEST_METHOD": request.command.upper(), - "SERVER_PORT": str(server.effective_port), - "SERVER_NAME": server.server_name, - "SERVER_SOFTWARE": server.adj.ident, - "SERVER_PROTOCOL": "HTTP/%s" % self.version, - "SCRIPT_NAME": url_prefix, - "PATH_INFO": path, - "QUERY_STRING": request.query, - "wsgi.url_scheme": request.url_scheme, - # the following environment variables are required by the WSGI spec - "wsgi.version": (1, 0), - # apps should use the logging module - "wsgi.errors": sys.stderr, - "wsgi.multithread": True, - "wsgi.multiprocess": False, - "wsgi.run_once": False, - "wsgi.input": request.get_body_stream(), - "wsgi.file_wrapper": ReadOnlyFileBasedBuffer, - "wsgi.input_terminated": True, # wsgi.input is EOF terminated - } - - for key, value in dict(request.headers).items(): - value = value.strip() - mykey = rename_headers.get(key, None) - if mykey is None: - mykey = "HTTP_" + key - if mykey not in environ: - environ[mykey] = value - - # cache the environ for this request - self.environ = environ - return environ diff --git a/waitress/trigger.py b/waitress/trigger.py deleted file mode 100644 index 6a57c12..0000000 --- a/waitress/trigger.py +++ /dev/null @@ -1,203 +0,0 @@ -############################################################################## -# -# Copyright (c) 2001-2005 Zope Foundation and Contributors. -# All Rights Reserved. -# -# This software is subject to the provisions of the Zope Public License, -# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. -# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED -# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS -# FOR A PARTICULAR PURPOSE -# -############################################################################## - -import os -import socket -import errno -import threading - -from . import wasyncore - -# Wake up a call to select() running in the main thread. -# -# This is useful in a context where you are using Medusa's I/O -# subsystem to deliver data, but the data is generated by another -# thread. Normally, if Medusa is in the middle of a call to -# select(), new output data generated by another thread will have -# to sit until the call to select() either times out or returns. -# If the trigger is 'pulled' by another thread, it should immediately -# generate a READ event on the trigger object, which will force the -# select() invocation to return. -# -# A common use for this facility: letting Medusa manage I/O for a -# large number of connections; but routing each request through a -# thread chosen from a fixed-size thread pool. When a thread is -# acquired, a transaction is performed, but output data is -# accumulated into buffers that will be emptied more efficiently -# by Medusa. [picture a server that can process database queries -# rapidly, but doesn't want to tie up threads waiting to send data -# to low-bandwidth connections] -# -# The other major feature provided by this class is the ability to -# move work back into the main thread: if you call pull_trigger() -# with a thunk argument, when select() wakes up and receives the -# event it will call your thunk from within that thread. The main -# purpose of this is to remove the need to wrap thread locks around -# Medusa's data structures, which normally do not need them. [To see -# why this is true, imagine this scenario: A thread tries to push some -# new data onto a channel's outgoing data queue at the same time that -# the main thread is trying to remove some] - - -class _triggerbase(object): - """OS-independent base class for OS-dependent trigger class.""" - - kind = None # subclass must set to "pipe" or "loopback"; used by repr - - def __init__(self): - self._closed = False - - # `lock` protects the `thunks` list from being traversed and - # appended to simultaneously. - self.lock = threading.Lock() - - # List of no-argument callbacks to invoke when the trigger is - # pulled. These run in the thread running the wasyncore mainloop, - # regardless of which thread pulls the trigger. - self.thunks = [] - - def readable(self): - return True - - def writable(self): - return False - - def handle_connect(self): - pass - - def handle_close(self): - self.close() - - # Override the wasyncore close() method, because it doesn't know about - # (so can't close) all the gimmicks we have open. Subclass must - # supply a _close() method to do platform-specific closing work. _close() - # will be called iff we're not already closed. - def close(self): - if not self._closed: - self._closed = True - self.del_channel() - self._close() # subclass does OS-specific stuff - - def pull_trigger(self, thunk=None): - if thunk: - with self.lock: - self.thunks.append(thunk) - self._physical_pull() - - def handle_read(self): - try: - self.recv(8192) - except (OSError, socket.error): - return - with self.lock: - for thunk in self.thunks: - try: - thunk() - except: - nil, t, v, tbinfo = wasyncore.compact_traceback() - self.log_info( - "exception in trigger thunk: (%s:%s %s)" % (t, v, tbinfo) - ) - self.thunks = [] - - -if os.name == "posix": - - class trigger(_triggerbase, wasyncore.file_dispatcher): - kind = "pipe" - - def __init__(self, map): - _triggerbase.__init__(self) - r, self.trigger = self._fds = os.pipe() - wasyncore.file_dispatcher.__init__(self, r, map=map) - - def _close(self): - for fd in self._fds: - os.close(fd) - self._fds = [] - wasyncore.file_dispatcher.close(self) - - def _physical_pull(self): - os.write(self.trigger, b"x") - - -else: # pragma: no cover - # Windows version; uses just sockets, because a pipe isn't select'able - # on Windows. - - class trigger(_triggerbase, wasyncore.dispatcher): - kind = "loopback" - - def __init__(self, map): - _triggerbase.__init__(self) - - # Get a pair of connected sockets. The trigger is the 'w' - # end of the pair, which is connected to 'r'. 'r' is put - # in the wasyncore socket map. "pulling the trigger" then - # means writing something on w, which will wake up r. - - w = socket.socket() - # Disable buffering -- pulling the trigger sends 1 byte, - # and we want that sent immediately, to wake up wasyncore's - # select() ASAP. - w.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - - count = 0 - while True: - count += 1 - # Bind to a local port; for efficiency, let the OS pick - # a free port for us. - # Unfortunately, stress tests showed that we may not - # be able to connect to that port ("Address already in - # use") despite that the OS picked it. This appears - # to be a race bug in the Windows socket implementation. - # So we loop until a connect() succeeds (almost always - # on the first try). See the long thread at - # http://mail.zope.org/pipermail/zope/2005-July/160433.html - # for hideous details. - a = socket.socket() - a.bind(("127.0.0.1", 0)) - connect_address = a.getsockname() # assigned (host, port) pair - a.listen(1) - try: - w.connect(connect_address) - break # success - except socket.error as detail: - if detail[0] != errno.WSAEADDRINUSE: - # "Address already in use" is the only error - # I've seen on two WinXP Pro SP2 boxes, under - # Pythons 2.3.5 and 2.4.1. - raise - # (10048, 'Address already in use') - # assert count <= 2 # never triggered in Tim's tests - if count >= 10: # I've never seen it go above 2 - a.close() - w.close() - raise RuntimeError("Cannot bind trigger!") - # Close `a` and try again. Note: I originally put a short - # sleep() here, but it didn't appear to help or hurt. - a.close() - - r, addr = a.accept() # r becomes wasyncore's (self.)socket - a.close() - self.trigger = w - wasyncore.dispatcher.__init__(self, r, map=map) - - def _close(self): - # self.socket is r, and self.trigger is w, from __init__ - self.socket.close() - self.trigger.close() - - def _physical_pull(self): - self.trigger.send(b"x") diff --git a/waitress/utilities.py b/waitress/utilities.py deleted file mode 100644 index 556bed2..0000000 --- a/waitress/utilities.py +++ /dev/null @@ -1,320 +0,0 @@ -############################################################################## -# -# Copyright (c) 2004 Zope Foundation and Contributors. -# All Rights Reserved. -# -# This software is subject to the provisions of the Zope Public License, -# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. -# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED -# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS -# FOR A PARTICULAR PURPOSE. -# -############################################################################## -"""Utility functions -""" - -import calendar -import errno -import logging -import os -import re -import stat -import time - -from .rfc7230 import OBS_TEXT, VCHAR - -logger = logging.getLogger("waitress") -queue_logger = logging.getLogger("waitress.queue") - - -def find_double_newline(s): - """Returns the position just after a double newline in the given string.""" - pos = s.find(b"\r\n\r\n") - - if pos >= 0: - pos += 4 - - return pos - - -def concat(*args): - return "".join(args) - - -def join(seq, field=" "): - return field.join(seq) - - -def group(s): - return "(" + s + ")" - - -short_days = ["sun", "mon", "tue", "wed", "thu", "fri", "sat"] -long_days = [ - "sunday", - "monday", - "tuesday", - "wednesday", - "thursday", - "friday", - "saturday", -] - -short_day_reg = group(join(short_days, "|")) -long_day_reg = group(join(long_days, "|")) - -daymap = {} - -for i in range(7): - daymap[short_days[i]] = i - daymap[long_days[i]] = i - -hms_reg = join(3 * [group("[0-9][0-9]")], ":") - -months = [ - "jan", - "feb", - "mar", - "apr", - "may", - "jun", - "jul", - "aug", - "sep", - "oct", - "nov", - "dec", -] - -monmap = {} - -for i in range(12): - monmap[months[i]] = i + 1 - -months_reg = group(join(months, "|")) - -# From draft-ietf-http-v11-spec-07.txt/3.3.1 -# Sun, 06 Nov 1994 08:49:37 GMT ; RFC 822, updated by RFC 1123 -# Sunday, 06-Nov-94 08:49:37 GMT ; RFC 850, obsoleted by RFC 1036 -# Sun Nov 6 08:49:37 1994 ; ANSI C's asctime() format - -# rfc822 format -rfc822_date = join( - [ - concat(short_day_reg, ","), # day - group("[0-9][0-9]?"), # date - months_reg, # month - group("[0-9]+"), # year - hms_reg, # hour minute second - "gmt", - ], - " ", -) - -rfc822_reg = re.compile(rfc822_date) - - -def unpack_rfc822(m): - g = m.group - - return ( - int(g(4)), # year - monmap[g(3)], # month - int(g(2)), # day - int(g(5)), # hour - int(g(6)), # minute - int(g(7)), # second - 0, - 0, - 0, - ) - - -# rfc850 format -rfc850_date = join( - [ - concat(long_day_reg, ","), - join([group("[0-9][0-9]?"), months_reg, group("[0-9]+")], "-"), - hms_reg, - "gmt", - ], - " ", -) - -rfc850_reg = re.compile(rfc850_date) -# they actually unpack the same way -def unpack_rfc850(m): - g = m.group - yr = g(4) - - if len(yr) == 2: - yr = "19" + yr - - return ( - int(yr), # year - monmap[g(3)], # month - int(g(2)), # day - int(g(5)), # hour - int(g(6)), # minute - int(g(7)), # second - 0, - 0, - 0, - ) - - -# parsdate.parsedate - ~700/sec. -# parse_http_date - ~1333/sec. - -weekdayname = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"] -monthname = [ - None, - "Jan", - "Feb", - "Mar", - "Apr", - "May", - "Jun", - "Jul", - "Aug", - "Sep", - "Oct", - "Nov", - "Dec", -] - - -def build_http_date(when): - year, month, day, hh, mm, ss, wd, y, z = time.gmtime(when) - - return "%s, %02d %3s %4d %02d:%02d:%02d GMT" % ( - weekdayname[wd], - day, - monthname[month], - year, - hh, - mm, - ss, - ) - - -def parse_http_date(d): - d = d.lower() - m = rfc850_reg.match(d) - - if m and m.end() == len(d): - retval = int(calendar.timegm(unpack_rfc850(m))) - else: - m = rfc822_reg.match(d) - - if m and m.end() == len(d): - retval = int(calendar.timegm(unpack_rfc822(m))) - else: - return 0 - - return retval - - -# RFC 5234 Appendix B.1 "Core Rules": -# VCHAR = %x21-7E -# ; visible (printing) characters -vchar_re = VCHAR - -# RFC 7230 Section 3.2.6 "Field Value Components": -# quoted-string = DQUOTE *( qdtext / quoted-pair ) DQUOTE -# qdtext = HTAB / SP /%x21 / %x23-5B / %x5D-7E / obs-text -# obs-text = %x80-FF -# quoted-pair = "\" ( HTAB / SP / VCHAR / obs-text ) -obs_text_re = OBS_TEXT - -# The '\\' between \x5b and \x5d is needed to escape \x5d (']') -qdtext_re = "[\t \x21\x23-\x5b\\\x5d-\x7e" + obs_text_re + "]" - -quoted_pair_re = r"\\" + "([\t " + vchar_re + obs_text_re + "])" -quoted_string_re = '"(?:(?:' + qdtext_re + ")|(?:" + quoted_pair_re + '))*"' - -quoted_string = re.compile(quoted_string_re) -quoted_pair = re.compile(quoted_pair_re) - - -def undquote(value): - if value.startswith('"') and value.endswith('"'): - # So it claims to be DQUOTE'ed, let's validate that - matches = quoted_string.match(value) - - if matches and matches.end() == len(value): - # Remove the DQUOTE's from the value - value = value[1:-1] - - # Remove all backslashes that are followed by a valid vchar or - # obs-text - value = quoted_pair.sub(r"\1", value) - - return value - elif not value.startswith('"') and not value.endswith('"'): - return value - - raise ValueError("Invalid quoting in value") - - -def cleanup_unix_socket(path): - try: - st = os.stat(path) - except OSError as exc: - if exc.errno != errno.ENOENT: - raise # pragma: no cover - else: - if stat.S_ISSOCK(st.st_mode): - try: - os.remove(path) - except OSError: # pragma: no cover - # avoid race condition error during tests - pass - - -class Error(object): - code = 500 - reason = "Internal Server Error" - - def __init__(self, body): - self.body = body - - def to_response(self): - status = "%s %s" % (self.code, self.reason) - body = "%s\r\n\r\n%s" % (self.reason, self.body) - tag = "\r\n\r\n(generated by waitress)" - body = body + tag - headers = [("Content-Type", "text/plain")] - - return status, headers, body - - def wsgi_response(self, environ, start_response): - status, headers, body = self.to_response() - start_response(status, headers) - yield body - - -class BadRequest(Error): - code = 400 - reason = "Bad Request" - - -class RequestHeaderFieldsTooLarge(BadRequest): - code = 431 - reason = "Request Header Fields Too Large" - - -class RequestEntityTooLarge(BadRequest): - code = 413 - reason = "Request Entity Too Large" - - -class InternalServerError(Error): - code = 500 - reason = "Internal Server Error" - - -class ServerNotImplemented(Error): - code = 501 - reason = "Not Implemented" diff --git a/waitress/wasyncore.py b/waitress/wasyncore.py deleted file mode 100644 index 09bcafa..0000000 --- a/waitress/wasyncore.py +++ /dev/null @@ -1,693 +0,0 @@ -# -*- Mode: Python -*- -# Id: asyncore.py,v 2.51 2000/09/07 22:29:26 rushing Exp -# Author: Sam Rushing - -# ====================================================================== -# Copyright 1996 by Sam Rushing -# -# All Rights Reserved -# -# Permission to use, copy, modify, and distribute this software and -# its documentation for any purpose and without fee is hereby -# granted, provided that the above copyright notice appear in all -# copies and that both that copyright notice and this permission -# notice appear in supporting documentation, and that the name of Sam -# Rushing not be used in advertising or publicity pertaining to -# distribution of the software without specific, written prior -# permission. -# -# SAM RUSHING DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, -# INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN -# NO EVENT SHALL SAM RUSHING BE LIABLE FOR ANY SPECIAL, INDIRECT OR -# CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS -# OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, -# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN -# CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -# ====================================================================== - -"""Basic infrastructure for asynchronous socket service clients and servers. - -There are only two ways to have a program on a single processor do "more -than one thing at a time". Multi-threaded programming is the simplest and -most popular way to do it, but there is another very different technique, -that lets you have nearly all the advantages of multi-threading, without -actually using multiple threads. it's really only practical if your program -is largely I/O bound. If your program is CPU bound, then pre-emptive -scheduled threads are probably what you really need. Network servers are -rarely CPU-bound, however. - -If your operating system supports the select() system call in its I/O -library (and nearly all do), then you can use it to juggle multiple -communication channels at once; doing other work while your I/O is taking -place in the "background." Although this strategy can seem strange and -complex, especially at first, it is in many ways easier to understand and -control than multi-threaded programming. The module documented here solves -many of the difficult problems for you, making the task of building -sophisticated high-performance network servers and clients a snap. - -NB: this is a fork of asyncore from the stdlib that we've (the waitress -developers) named 'wasyncore' to ensure forward compatibility, as asyncore -in the stdlib will be dropped soon. It is neither a copy of the 2.7 asyncore -nor the 3.X asyncore; it is a version compatible with either 2.7 or 3.X. -""" - -from . import compat -from . import utilities - -import logging -import select -import socket -import sys -import time -import warnings - -import os -from errno import ( - EALREADY, - EINPROGRESS, - EWOULDBLOCK, - ECONNRESET, - EINVAL, - ENOTCONN, - ESHUTDOWN, - EISCONN, - EBADF, - ECONNABORTED, - EPIPE, - EAGAIN, - EINTR, - errorcode, -) - -_DISCONNECTED = frozenset({ECONNRESET, ENOTCONN, ESHUTDOWN, ECONNABORTED, EPIPE, EBADF}) - -try: - socket_map -except NameError: - socket_map = {} - - -def _strerror(err): - try: - return os.strerror(err) - except (TypeError, ValueError, OverflowError, NameError): - return "Unknown error %s" % err - - -class ExitNow(Exception): - pass - - -_reraised_exceptions = (ExitNow, KeyboardInterrupt, SystemExit) - - -def read(obj): - try: - obj.handle_read_event() - except _reraised_exceptions: - raise - except: - obj.handle_error() - - -def write(obj): - try: - obj.handle_write_event() - except _reraised_exceptions: - raise - except: - obj.handle_error() - - -def _exception(obj): - try: - obj.handle_expt_event() - except _reraised_exceptions: - raise - except: - obj.handle_error() - - -def readwrite(obj, flags): - try: - if flags & select.POLLIN: - obj.handle_read_event() - if flags & select.POLLOUT: - obj.handle_write_event() - if flags & select.POLLPRI: - obj.handle_expt_event() - if flags & (select.POLLHUP | select.POLLERR | select.POLLNVAL): - obj.handle_close() - except socket.error as e: - if e.args[0] not in _DISCONNECTED: - obj.handle_error() - else: - obj.handle_close() - except _reraised_exceptions: - raise - except: - obj.handle_error() - - -def poll(timeout=0.0, map=None): - if map is None: # pragma: no cover - map = socket_map - if map: - r = [] - w = [] - e = [] - for fd, obj in list(map.items()): # list() call FBO py3 - is_r = obj.readable() - is_w = obj.writable() - if is_r: - r.append(fd) - # accepting sockets should not be writable - if is_w and not obj.accepting: - w.append(fd) - if is_r or is_w: - e.append(fd) - if [] == r == w == e: - time.sleep(timeout) - return - - try: - r, w, e = select.select(r, w, e, timeout) - except select.error as err: - if err.args[0] != EINTR: - raise - else: - return - - for fd in r: - obj = map.get(fd) - if obj is None: # pragma: no cover - continue - read(obj) - - for fd in w: - obj = map.get(fd) - if obj is None: # pragma: no cover - continue - write(obj) - - for fd in e: - obj = map.get(fd) - if obj is None: # pragma: no cover - continue - _exception(obj) - - -def poll2(timeout=0.0, map=None): - # Use the poll() support added to the select module in Python 2.0 - if map is None: # pragma: no cover - map = socket_map - if timeout is not None: - # timeout is in milliseconds - timeout = int(timeout * 1000) - pollster = select.poll() - if map: - for fd, obj in list(map.items()): - flags = 0 - if obj.readable(): - flags |= select.POLLIN | select.POLLPRI - # accepting sockets should not be writable - if obj.writable() and not obj.accepting: - flags |= select.POLLOUT - if flags: - pollster.register(fd, flags) - - try: - r = pollster.poll(timeout) - except select.error as err: - if err.args[0] != EINTR: - raise - r = [] - - for fd, flags in r: - obj = map.get(fd) - if obj is None: # pragma: no cover - continue - readwrite(obj, flags) - - -poll3 = poll2 # Alias for backward compatibility - - -def loop(timeout=30.0, use_poll=False, map=None, count=None): - if map is None: # pragma: no cover - map = socket_map - - if use_poll and hasattr(select, "poll"): - poll_fun = poll2 - else: - poll_fun = poll - - if count is None: # pragma: no cover - while map: - poll_fun(timeout, map) - - else: - while map and count > 0: - poll_fun(timeout, map) - count = count - 1 - - -def compact_traceback(): - t, v, tb = sys.exc_info() - tbinfo = [] - if not tb: # pragma: no cover - raise AssertionError("traceback does not exist") - while tb: - tbinfo.append( - ( - tb.tb_frame.f_code.co_filename, - tb.tb_frame.f_code.co_name, - str(tb.tb_lineno), - ) - ) - tb = tb.tb_next - - # just to be safe - del tb - - file, function, line = tbinfo[-1] - info = " ".join(["[%s|%s|%s]" % x for x in tbinfo]) - return (file, function, line), t, v, info - - -class dispatcher: - - debug = False - connected = False - accepting = False - connecting = False - closing = False - addr = None - ignore_log_types = frozenset({"warning"}) - logger = utilities.logger - compact_traceback = staticmethod(compact_traceback) # for testing - - def __init__(self, sock=None, map=None): - if map is None: # pragma: no cover - self._map = socket_map - else: - self._map = map - - self._fileno = None - - if sock: - # Set to nonblocking just to make sure for cases where we - # get a socket from a blocking source. - sock.setblocking(0) - self.set_socket(sock, map) - self.connected = True - # The constructor no longer requires that the socket - # passed be connected. - try: - self.addr = sock.getpeername() - except socket.error as err: - if err.args[0] in (ENOTCONN, EINVAL): - # To handle the case where we got an unconnected - # socket. - self.connected = False - else: - # The socket is broken in some unknown way, alert - # the user and remove it from the map (to prevent - # polling of broken sockets). - self.del_channel(map) - raise - else: - self.socket = None - - def __repr__(self): - status = [self.__class__.__module__ + "." + compat.qualname(self.__class__)] - if self.accepting and self.addr: - status.append("listening") - elif self.connected: - status.append("connected") - if self.addr is not None: - try: - status.append("%s:%d" % self.addr) - except TypeError: # pragma: no cover - status.append(repr(self.addr)) - return "<%s at %#x>" % (" ".join(status), id(self)) - - __str__ = __repr__ - - def add_channel(self, map=None): - # self.log_info('adding channel %s' % self) - if map is None: - map = self._map - map[self._fileno] = self - - def del_channel(self, map=None): - fd = self._fileno - if map is None: - map = self._map - if fd in map: - # self.log_info('closing channel %d:%s' % (fd, self)) - del map[fd] - self._fileno = None - - def create_socket(self, family=socket.AF_INET, type=socket.SOCK_STREAM): - self.family_and_type = family, type - sock = socket.socket(family, type) - sock.setblocking(0) - self.set_socket(sock) - - def set_socket(self, sock, map=None): - self.socket = sock - self._fileno = sock.fileno() - self.add_channel(map) - - def set_reuse_addr(self): - # try to re-use a server port if possible - try: - self.socket.setsockopt( - socket.SOL_SOCKET, - socket.SO_REUSEADDR, - self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) | 1, - ) - except socket.error: - pass - - # ================================================== - # predicates for select() - # these are used as filters for the lists of sockets - # to pass to select(). - # ================================================== - - def readable(self): - return True - - def writable(self): - return True - - # ================================================== - # socket object methods. - # ================================================== - - def listen(self, num): - self.accepting = True - if os.name == "nt" and num > 5: # pragma: no cover - num = 5 - return self.socket.listen(num) - - def bind(self, addr): - self.addr = addr - return self.socket.bind(addr) - - def connect(self, address): - self.connected = False - self.connecting = True - err = self.socket.connect_ex(address) - if ( - err in (EINPROGRESS, EALREADY, EWOULDBLOCK) - or err == EINVAL - and os.name == "nt" - ): # pragma: no cover - self.addr = address - return - if err in (0, EISCONN): - self.addr = address - self.handle_connect_event() - else: - raise socket.error(err, errorcode[err]) - - def accept(self): - # XXX can return either an address pair or None - try: - conn, addr = self.socket.accept() - except TypeError: - return None - except socket.error as why: - if why.args[0] in (EWOULDBLOCK, ECONNABORTED, EAGAIN): - return None - else: - raise - else: - return conn, addr - - def send(self, data): - try: - result = self.socket.send(data) - return result - except socket.error as why: - if why.args[0] == EWOULDBLOCK: - return 0 - elif why.args[0] in _DISCONNECTED: - self.handle_close() - return 0 - else: - raise - - def recv(self, buffer_size): - try: - data = self.socket.recv(buffer_size) - if not data: - # a closed connection is indicated by signaling - # a read condition, and having recv() return 0. - self.handle_close() - return b"" - else: - return data - except socket.error as why: - # winsock sometimes raises ENOTCONN - if why.args[0] in _DISCONNECTED: - self.handle_close() - return b"" - else: - raise - - def close(self): - self.connected = False - self.accepting = False - self.connecting = False - self.del_channel() - if self.socket is not None: - try: - self.socket.close() - except socket.error as why: - if why.args[0] not in (ENOTCONN, EBADF): - raise - - # log and log_info may be overridden to provide more sophisticated - # logging and warning methods. In general, log is for 'hit' logging - # and 'log_info' is for informational, warning and error logging. - - def log(self, message): - self.logger.log(logging.DEBUG, message) - - def log_info(self, message, type="info"): - severity = { - "info": logging.INFO, - "warning": logging.WARN, - "error": logging.ERROR, - } - self.logger.log(severity.get(type, logging.INFO), message) - - def handle_read_event(self): - if self.accepting: - # accepting sockets are never connected, they "spawn" new - # sockets that are connected - self.handle_accept() - elif not self.connected: - if self.connecting: - self.handle_connect_event() - self.handle_read() - else: - self.handle_read() - - def handle_connect_event(self): - err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) - if err != 0: - raise socket.error(err, _strerror(err)) - self.handle_connect() - self.connected = True - self.connecting = False - - def handle_write_event(self): - if self.accepting: - # Accepting sockets shouldn't get a write event. - # We will pretend it didn't happen. - return - - if not self.connected: - if self.connecting: - self.handle_connect_event() - self.handle_write() - - def handle_expt_event(self): - # handle_expt_event() is called if there might be an error on the - # socket, or if there is OOB data - # check for the error condition first - err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) - if err != 0: - # we can get here when select.select() says that there is an - # exceptional condition on the socket - # since there is an error, we'll go ahead and close the socket - # like we would in a subclassed handle_read() that received no - # data - self.handle_close() - else: - self.handle_expt() - - def handle_error(self): - nil, t, v, tbinfo = self.compact_traceback() - - # sometimes a user repr method will crash. - try: - self_repr = repr(self) - except: # pragma: no cover - self_repr = "<__repr__(self) failed for object at %0x>" % id(self) - - self.log_info( - "uncaptured python exception, closing channel %s (%s:%s %s)" - % (self_repr, t, v, tbinfo), - "error", - ) - self.handle_close() - - def handle_expt(self): - self.log_info("unhandled incoming priority event", "warning") - - def handle_read(self): - self.log_info("unhandled read event", "warning") - - def handle_write(self): - self.log_info("unhandled write event", "warning") - - def handle_connect(self): - self.log_info("unhandled connect event", "warning") - - def handle_accept(self): - pair = self.accept() - if pair is not None: - self.handle_accepted(*pair) - - def handle_accepted(self, sock, addr): - sock.close() - self.log_info("unhandled accepted event", "warning") - - def handle_close(self): - self.log_info("unhandled close event", "warning") - self.close() - - -# --------------------------------------------------------------------------- -# adds simple buffered output capability, useful for simple clients. -# [for more sophisticated usage use asynchat.async_chat] -# --------------------------------------------------------------------------- - - -class dispatcher_with_send(dispatcher): - def __init__(self, sock=None, map=None): - dispatcher.__init__(self, sock, map) - self.out_buffer = b"" - - def initiate_send(self): - num_sent = 0 - num_sent = dispatcher.send(self, self.out_buffer[:65536]) - self.out_buffer = self.out_buffer[num_sent:] - - handle_write = initiate_send - - def writable(self): - return (not self.connected) or len(self.out_buffer) - - def send(self, data): - if self.debug: # pragma: no cover - self.log_info("sending %s" % repr(data)) - self.out_buffer = self.out_buffer + data - self.initiate_send() - - -def close_all(map=None, ignore_all=False): - if map is None: # pragma: no cover - map = socket_map - for x in list(map.values()): # list() FBO py3 - try: - x.close() - except socket.error as x: - if x.args[0] == EBADF: - pass - elif not ignore_all: - raise - except _reraised_exceptions: - raise - except: - if not ignore_all: - raise - map.clear() - - -# Asynchronous File I/O: -# -# After a little research (reading man pages on various unixen, and -# digging through the linux kernel), I've determined that select() -# isn't meant for doing asynchronous file i/o. -# Heartening, though - reading linux/mm/filemap.c shows that linux -# supports asynchronous read-ahead. So _MOST_ of the time, the data -# will be sitting in memory for us already when we go to read it. -# -# What other OS's (besides NT) support async file i/o? [VMS?] -# -# Regardless, this is useful for pipes, and stdin/stdout... - -if os.name == "posix": - - class file_wrapper: - # Here we override just enough to make a file - # look like a socket for the purposes of asyncore. - # The passed fd is automatically os.dup()'d - - def __init__(self, fd): - self.fd = os.dup(fd) - - def __del__(self): - if self.fd >= 0: - warnings.warn("unclosed file %r" % self, compat.ResourceWarning) - self.close() - - def recv(self, *args): - return os.read(self.fd, *args) - - def send(self, *args): - return os.write(self.fd, *args) - - def getsockopt(self, level, optname, buflen=None): # pragma: no cover - if level == socket.SOL_SOCKET and optname == socket.SO_ERROR and not buflen: - return 0 - raise NotImplementedError( - "Only asyncore specific behaviour " "implemented." - ) - - read = recv - write = send - - def close(self): - if self.fd < 0: - return - fd = self.fd - self.fd = -1 - os.close(fd) - - def fileno(self): - return self.fd - - class file_dispatcher(dispatcher): - def __init__(self, fd, map=None): - dispatcher.__init__(self, None, map) - self.connected = True - try: - fd = fd.fileno() - except AttributeError: - pass - self.set_file(fd) - # set it to non-blocking mode - compat.set_nonblocking(fd) - - def set_file(self, fd): - self.socket = file_wrapper(fd) - self._fileno = self.socket.fileno() - self.add_channel() -- cgit v1.2.1 From 1b1dac068b7f6bd74aa0002e800e95acb0fbe3c8 Mon Sep 17 00:00:00 2001 From: Bert JW Regeer Date: Sun, 2 Feb 2020 22:53:30 -0800 Subject: Add pyproject.toml to project --- pyproject.toml | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 pyproject.toml diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..7f50ece --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,12 @@ +[build-system] +requires = ["setuptools >= 41"] +build-backend = "setuptools.build_meta" + +[tool.black] +py36 = false +exclude = ''' +/( + \.git + | .tox +)/ +''' -- cgit v1.2.1 From e2facc893269d8a7492ae23ed047fb50b71ccd6b Mon Sep 17 00:00:00 2001 From: Bert JW Regeer Date: Sun, 2 Feb 2020 22:54:07 -0800 Subject: Move from setup.py to setup.cfg --- setup.cfg | 61 ++++++++++++++++++++++++++++++++++++++++++++---- setup.py | 80 ++------------------------------------------------------------- 2 files changed, 58 insertions(+), 83 deletions(-) diff --git a/setup.cfg b/setup.cfg index 81cfbb1..ac1ceb8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,58 @@ -[easy_install] -zip_ok = false +[metadata] +name = waitress +version = 1.4.4a0 +description = Waitress WSGI server +long_description = file: README.rst, CHANGES.txt +long_description_content_type = text/x-rst +keywords = waitress wsgi server http +license = ZPL 2.1 +classifiers = + Development Status :: 6 - Mature + Environment :: Web Environment + Intended Audience :: Developers + License :: OSI Approved :: Zope Public License + Programming Language :: Python + Programming Language :: Python :: 2 + Programming Language :: Python :: 2.7 + Programming Language :: Python :: 3 + Programming Language :: Python :: 3.5 + Programming Language :: Python :: 3.6 + Programming Language :: Python :: 3.7 + Programming Language :: Python :: Implementation :: CPython + Programming Language :: Python :: Implementation :: PyPy + Operating System :: OS Independent + Topic :: Internet :: WWW/HTTP + Topic :: Internet :: WWW/HTTP :: WSGI +url = https://github.com/Pylons/waitress +author = Zope Foundation and Contributors +author_email = zope-dev@zope.org +maintainer = Pylons Project +maintainer_email = pylons-discuss@googlegroups.com + +[options] +package_dir= + =src +packages=find: +python_requires = >=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.* + +[options.entry_points] +paste.server_runner = + main = waitress:serve_paste +console_scripts = + waitress-serve = waitress.runner:run + +[options.packages.find] +where=src + +[options.extras_require] +testing = + nose + coverage>=5.0 + +docs = + Sphinx>=1.8.1 + docutils + pylons-sphinx-themes>=1.0.9 [nosetests] match=^test @@ -10,6 +63,4 @@ cover-erase=1 [bdist_wheel] universal = 1 -[aliases] -dev = develop easy_install waitress[testing] -docs = develop easy_install waitress[docs] + diff --git a/setup.py b/setup.py index 0468a82..6068493 100644 --- a/setup.py +++ b/setup.py @@ -1,79 +1,3 @@ -############################################################################## -# -# Copyright (c) 2006 Zope Foundation and Contributors. -# All Rights Reserved. -# -# This software is subject to the provisions of the Zope Public License, -# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. -# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED -# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS -# FOR A PARTICULAR PURPOSE. -# -############################################################################## -import os +from setuptools import setup -from setuptools import find_packages, setup - -here = os.path.abspath(os.path.dirname(__file__)) -try: - README = open(os.path.join(here, "README.rst")).read() - CHANGES = open(os.path.join(here, "CHANGES.txt")).read() -except IOError: - README = CHANGES = "" - -docs_extras = [ - "Sphinx>=1.8.1", - "docutils", - "pylons-sphinx-themes>=1.0.9", -] - -testing_extras = [ - "nose", - "coverage>=5.0", -] - -setup( - name="waitress", - version="1.4.3", - author="Zope Foundation and Contributors", - author_email="zope-dev@zope.org", - maintainer="Pylons Project", - maintainer_email="pylons-discuss@googlegroups.com", - description="Waitress WSGI server", - long_description=README + "\n\n" + CHANGES, - license="ZPL 2.1", - keywords="waitress wsgi server http", - classifiers=[ - "Development Status :: 5 - Production/Stable", - "Environment :: Web Environment", - "Intended Audience :: Developers", - "License :: OSI Approved :: Zope Public License", - "Programming Language :: Python", - "Programming Language :: Python :: 2", - "Programming Language :: Python :: 2.7", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.5", - "Programming Language :: Python :: 3.6", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: Implementation :: CPython", - "Programming Language :: Python :: Implementation :: PyPy", - "Natural Language :: English", - "Operating System :: OS Independent", - "Topic :: Internet :: WWW/HTTP", - "Topic :: Internet :: WWW/HTTP :: WSGI", - ], - url="https://github.com/Pylons/waitress", - packages=find_packages(), - python_requires=">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*", - extras_require={"testing": testing_extras, "docs": docs_extras}, - include_package_data=True, - test_suite="waitress", - zip_safe=False, - entry_points=""" - [paste.server_runner] - main = waitress:serve_paste - [console_scripts] - waitress-serve = waitress.runner:run - """, -) +setup() -- cgit v1.2.1 From cb42892839e2bc8f86cfeb68bf9334f508578cbd Mon Sep 17 00:00:00 2001 From: Bert JW Regeer Date: Sun, 2 Feb 2020 22:54:26 -0800 Subject: Update tox.ini to new world order --- tox.ini | 101 ++++++++++++++++++++++++++++++++++++++++++---------------------- 1 file changed, 67 insertions(+), 34 deletions(-) diff --git a/tox.ini b/tox.ini index b9f7fa2..04ca1d5 100644 --- a/tox.ini +++ b/tox.ini @@ -1,56 +1,89 @@ [tox] envlist = - py27,py35,py36,py37,py38,pypy,pypy3, + lint, + py27,pypy, + py35,py36,py37,py38,pypy3, docs, - {py27,py35}-cover,coverage + coverage +isolated_build = True [testenv] commands = python --version - nosetests --with-xunit --xunit-file=nosetests-{envname}.xml {posargs:waitress} - + nosetests {posargs:} extras = testing +setenv = + COVERAGE_FILE=.coverage.{envname} -[testenv:docs] -basepython = python3.5 -whitelist_externals = make +[testenv:coverage] +basepython = python3.8 commands = - make -C docs clean html epub BUILDDIR={envdir} "SPHINXOPTS=-W -E" -extras = - docs + coverage combine + coverage xml + coverage report --show-missing +deps = + coverage +setenv = + COVERAGE_FILE=.coverage +depends = py27, py35 -[py-cover] +[testenv:lint] +skip_install = True +basepython = python3.8 commands = - coverage run {envbindir}/nosetests waitress - coverage combine - coverage xml -o {envname}.xml + black --check --diff . + check-manifest + # build sdist/wheel + python -m pep517.build . + twine check dist/* +deps = + black + readme_renderer + check-manifest + pep517 + twine +[testenv:docs] +whitelist_externals = + make +commands = + make -C docs html BUILDDIR={envdir} SPHINXOPTS="-W -E" extras = - testing + docs -[testenv:py27-cover] +[testenv:run-flake8] +skip_install = True +basepython = python3.8 commands = - {[py-cover]commands} - -setenv = - COVERAGE_FILE=.coverage.py2 + flake8 src/waitress/ tests +deps = + flake8 + flake8-bugbear -[testenv:py35-cover] +[testenv:run-black] +skip_install = True +basepython = python3.8 commands = - {[py-cover]commands} - -setenv = - COVERAGE_FILE=.coverage.py3 + black . +deps = + black -[testenv:coverage] -basepython = python3.5 +[testenv:build] +skip_install = true +basepython = python3.8 commands = - coverage combine - coverage xml - coverage report --show-missing --fail-under=100 -deps = - coverage -setenv = - COVERAGE_FILE=.coverage + # clean up build/ and dist/ folders + python -c 'import shutil; shutil.rmtree("build", ignore_errors=True)' + # Make sure we aren't forgetting anything + check-manifest + # build sdist/wheel + python -m pep517.build . + # Verify all is well + twine check dist/* +deps = + readme_renderer + check-manifest + pep517 + twine -- cgit v1.2.1 From b9b917405d1b183726ced687cea481f03bb73a9f Mon Sep 17 00:00:00 2001 From: Bert JW Regeer Date: Sun, 2 Feb 2020 23:23:17 -0800 Subject: Blacken files --- src/waitress/parser.py | 1 + tests/test_functional.py | 17 ++++++++++++----- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/waitress/parser.py b/src/waitress/parser.py index 53072b5..765fe59 100644 --- a/src/waitress/parser.py +++ b/src/waitress/parser.py @@ -39,6 +39,7 @@ class ParsingError(Exception): class TransferEncodingNotImplemented(Exception): pass + class HTTPRequestParser(object): """A structure that collects the HTTP request. diff --git a/tests/test_functional.py b/tests/test_functional.py index 4b60676..e894497 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -38,6 +38,7 @@ def try_register_coverage(): # pragma: no cover # atexit handler by always registering a signal handler if "COVERAGE_PROCESS_START" in os.environ: + def sigterm(*args): sys.exit(0) @@ -336,7 +337,8 @@ class EchoTests(object): cl = int(headers["content-length"]) self.assertEqual(cl, len(response_body)) self.assertEqual( - sorted(headers.keys()), ["connection", "content-length", "content-type", "date", "server"] + sorted(headers.keys()), + ["connection", "content-length", "content-type", "date", "server"], ) self.assertEqual(headers["content-type"], "text/plain") # connection has been closed @@ -361,7 +363,8 @@ class EchoTests(object): self.assertEqual(cl, len(response_body)) self.assertTrue(b"Chunk not properly terminated" in response_body) self.assertEqual( - sorted(headers.keys()), ["connection", "content-length", "content-type", "date", "server"] + sorted(headers.keys()), + ["connection", "content-length", "content-type", "date", "server"], ) self.assertEqual(headers["content-type"], "text/plain") # connection has been closed @@ -897,7 +900,9 @@ class TooLargeTests(object): def test_request_body_too_large_with_wrong_cl_http10_keepalive(self): body = "a" * self.toobig - to_send = "GET / HTTP/1.0\r\nContent-Length: 5\r\nConnection: Keep-Alive\r\n\r\n" + to_send = ( + "GET / HTTP/1.0\r\nContent-Length: 5\r\nConnection: Keep-Alive\r\n\r\n" + ) to_send += body to_send = tobytes(to_send) self.connect() @@ -1094,7 +1099,8 @@ class InternalServerErrorTests(object): self.assertEqual(cl, len(response_body)) self.assertTrue(response_body.startswith(b"Internal Server Error")) self.assertEqual( - sorted(headers.keys()), ["connection", "content-length", "content-type", "date", "server"] + sorted(headers.keys()), + ["connection", "content-length", "content-type", "date", "server"], ) # connection has been closed self.send_check_error(to_send) @@ -1153,7 +1159,8 @@ class InternalServerErrorTests(object): self.assertEqual(cl, len(response_body)) self.assertTrue(response_body.startswith(b"Internal Server Error")) self.assertEqual( - sorted(headers.keys()), ["connection", "content-length", "content-type", "date", "server"] + sorted(headers.keys()), + ["connection", "content-length", "content-type", "date", "server"], ) # connection has been closed self.send_check_error(to_send) -- cgit v1.2.1 From 3de5fdee7c04c1f4a799e3e13b4fb8d646d32a94 Mon Sep 17 00:00:00 2001 From: Bert JW Regeer Date: Sun, 2 Feb 2020 23:23:59 -0800 Subject: paths are equal in coverage --- .coveragerc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.coveragerc b/.coveragerc index cbacd63..b869097 100644 --- a/.coveragerc +++ b/.coveragerc @@ -5,10 +5,15 @@ concurrency = multiprocessing source = waitress - omit = waitress/tests/fixtureapps/getline.py +[paths] +source = + src/waitress + */src/waitress + */site-packages/waitress + [report] show_missing = true precision = 2 -- cgit v1.2.1 From 1de48d6677cfde87a11fdfc14465cd98a37fc0cf Mon Sep 17 00:00:00 2001 From: Bert JW Regeer Date: Sun, 2 Feb 2020 23:24:28 -0800 Subject: Switch to pytest from nosetests --- setup.cfg | 17 +++++++++-------- tox.ini | 2 +- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/setup.cfg b/setup.cfg index ac1ceb8..77afa15 100644 --- a/setup.cfg +++ b/setup.cfg @@ -46,7 +46,8 @@ where=src [options.extras_require] testing = - nose + pytest + pytest-cover coverage>=5.0 docs = @@ -54,13 +55,13 @@ docs = docutils pylons-sphinx-themes>=1.0.9 -[nosetests] -match=^test -nocapture=1 -cover-package=waitress -cover-erase=1 - [bdist_wheel] universal = 1 - +[tool:pytest] +python_files = test_*.py +# For the benefit of test_wasyncore.py +python_classes = Test_* +testpaths = + tests +addopts = -W always --cov --cov-report=term-missing diff --git a/tox.ini b/tox.ini index 04ca1d5..8db0f88 100644 --- a/tox.ini +++ b/tox.ini @@ -10,7 +10,7 @@ isolated_build = True [testenv] commands = python --version - nosetests {posargs:} + pytest {posargs:} extras = testing setenv = -- cgit v1.2.1 From 85c100ff2f580b4eecb1509bed4e6753827ff29e Mon Sep 17 00:00:00 2001 From: Bert JW Regeer Date: Sun, 2 Feb 2020 23:25:21 -0800 Subject: Update gitignore --- .gitignore | 6 ------ 1 file changed, 6 deletions(-) diff --git a/.gitignore b/.gitignore index 3a33b6c..76d521b 100644 --- a/.gitignore +++ b/.gitignore @@ -3,13 +3,7 @@ env*/ .coverage .coverage.* -.idea/ .tox/ -nosetests.xml -waitress/coverage.xml dist/ -keep/ build/ coverage.xml -nosetests*.xml -py*-cover.xml -- cgit v1.2.1 From a48e996f0cdbd4530bfbe9e410f06bb7f4a1ba28 Mon Sep 17 00:00:00 2001 From: Bert JW Regeer Date: Sun, 2 Feb 2020 23:28:40 -0800 Subject: Update CI to add lint checks --- .github/workflows/ci-linux.yml | 18 +++++++++++++++--- .travis.yml | 10 ++++++++-- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci-linux.yml b/.github/workflows/ci-linux.yml index cc5639a..1a8897f 100644 --- a/.github/workflows/ci-linux.yml +++ b/.github/workflows/ci-linux.yml @@ -40,10 +40,10 @@ jobs: - name: Setup python uses: actions/setup-python@v1 with: - python-version: 3.5 + python-version: 3.8 architecture: x64 - run: pip install tox - - run: tox -e py35-cover,py27-cover,coverage + - run: tox -e py38,py27,coverage docs: runs-on: ubuntu-latest name: Build the documentation @@ -52,7 +52,19 @@ jobs: - name: Setup python uses: actions/setup-python@v1 with: - python-version: 3.5 + python-version: 3.8 architecture: x64 - run: pip install tox - run: tox -e docs + lint: + runs-on: ubuntu-latest + name: Lint the package + steps: + - uses: actions/checkout@master + - name: Setup python + uses: actions/setup-python@v1 + with: + python-version: 3.8 + architecture: x64 + - run: pip install tox + - run: tox -e lint diff --git a/.travis.yml b/.travis.yml index 4bb2567..464fd86 100644 --- a/.travis.yml +++ b/.travis.yml @@ -22,8 +22,14 @@ matrix: env: TOXENV=pypy - python: pypy3 env: TOXENV=pypy3 - - python: 3.5 - env: TOXENV=py27-cover,py35-cover,coverage + - python: 3.8 + env: TOXENV=py27,py38,coverage + dist: xenial + sudo: true + - python: 3.8 + env: TOXENV=lint + dist: xenial + sudo: true - python: 3.5 env: TOXENV=docs -- cgit v1.2.1 From 22556f034a812216ca91a32a74db583e765959ad Mon Sep 17 00:00:00 2001 From: Bert JW Regeer Date: Sun, 2 Feb 2020 23:32:53 -0800 Subject: Update .gitignore --- .gitignore | 2 ++ docs/.gitignore | 4 ---- 2 files changed, 2 insertions(+), 4 deletions(-) delete mode 100644 docs/.gitignore diff --git a/.gitignore b/.gitignore index 76d521b..146736f 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,5 @@ env*/ dist/ build/ coverage.xml +docs/_themes +docs/_build diff --git a/docs/.gitignore b/docs/.gitignore deleted file mode 100644 index da7abd0..0000000 --- a/docs/.gitignore +++ /dev/null @@ -1,4 +0,0 @@ -_themes -_build - - -- cgit v1.2.1 From 4580fd20b2580b7c6c340fe7b4777e0604799ddd Mon Sep 17 00:00:00 2001 From: Bert JW Regeer Date: Sun, 2 Feb 2020 23:36:08 -0800 Subject: Add MANIFEST.in --- MANIFEST.in | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 MANIFEST.in diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..0332267 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,20 @@ +graft src/waitress +graft tests +graft docs + +include README.rst +include CHANGES.txt +include HISTORY.txt +include RELEASING.txt +include LICENSE.txt +include contributing.md +include CONTRIBUTORS.txt +include COPYRIGHT.txt + +include pyproject.toml setup.cfg +include .coveragerc +include tox.ini .travis.yml rtd.txt appveyor.yml + +exclude TODO.txt + +recursive-exclude * __pycache__ *.py[cod] -- cgit v1.2.1 From 975d9ca3ed7805480358ce184d6ced065fe4dff8 Mon Sep 17 00:00:00 2001 From: Bert JW Regeer Date: Mon, 3 Feb 2020 00:05:19 -0800 Subject: Make sure to include github workflows --- MANIFEST.in | 1 + 1 file changed, 1 insertion(+) diff --git a/MANIFEST.in b/MANIFEST.in index 0332267..2763691 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,6 +1,7 @@ graft src/waitress graft tests graft docs +graft .github include README.rst include CHANGES.txt -- cgit v1.2.1 From c141ba2b938633548a79bb7174f91572d710f38c Mon Sep 17 00:00:00 2001 From: Bert JW Regeer Date: Thu, 16 Apr 2020 22:53:50 -0700 Subject: Add flake8 configuration --- .flake8 | 36 ++++++++++++++++++++++++++++++++++++ MANIFEST.in | 2 +- tox.ini | 3 +++ 3 files changed, 40 insertions(+), 1 deletion(-) create mode 100644 .flake8 diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..a5f3a73 --- /dev/null +++ b/.flake8 @@ -0,0 +1,36 @@ +# Recommended flake8 settings while editing, we use Black for the final linting/say in how code is formatted +# +# pip install flake8 flake8-bugbear +# +# This will warn/error on things that black does not fix, on purpose. +# +# Run: +# +# tox -e run-flake8 +# +# To have it automatically create and install the appropriate tools, and run +# flake8 across the source code/tests + +[flake8] +# max line length is set to 88 in black, here it is set to 80 and we enable bugbear's B950 warning, which is: +# +# B950: Line too long. This is a pragmatic equivalent of pycodestyle’s E501: it +# considers “max-line-length” but only triggers when the value has been +# exceeded by more than 10%. You will no longer be forced to reformat code due +# to the closing parenthesis being one character too far to satisfy the linter. +# At the same time, if you do significantly violate the line length, you will +# receive a message that states what the actual limit is. This is inspired by +# Raymond Hettinger’s “Beyond PEP 8” talk and highway patrol not stopping you +# if you drive < 5mph too fast. Disable E501 to avoid duplicate warnings. +max-line-length = 80 +max-complexity = 12 +select = E,F,W,C,B,B9 +ignore = + # E123 closing bracket does not match indentation of opening bracket’s line + E123 + # E203 whitespace before ‘:’ (Not PEP8 compliant, Python Black) + E203 + # E501 line too long (82 > 79 characters) (replaced by B950 from flake8-bugbear, https://github.com/PyCQA/flake8-bugbear) + E501 + # W503 line break before binary operator (Not PEP8 compliant, Python Black) + W503 diff --git a/MANIFEST.in b/MANIFEST.in index 2763691..b52891d 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -13,7 +13,7 @@ include CONTRIBUTORS.txt include COPYRIGHT.txt include pyproject.toml setup.cfg -include .coveragerc +include .coveragerc .flake8 include tox.ini .travis.yml rtd.txt appveyor.yml exclude TODO.txt diff --git a/tox.ini b/tox.ini index 8db0f88..08baf49 100644 --- a/tox.ini +++ b/tox.ini @@ -34,6 +34,7 @@ basepython = python3.8 commands = black --check --diff . check-manifest + # flake8 src/waitress/ tests # build sdist/wheel python -m pep517.build . twine check dist/* @@ -43,6 +44,8 @@ deps = check-manifest pep517 twine + flake8 + flake8-bugbear [testenv:docs] whitelist_externals = -- cgit v1.2.1 From a2962aed85e5dbe1688a93b2f7f76202cd395eb3 Mon Sep 17 00:00:00 2001 From: Bert JW Regeer Date: Thu, 16 Apr 2020 22:56:26 -0700 Subject: Remove travis and appveyor from this repo --- .travis.yml | 47 ----------------------------------------------- MANIFEST.in | 2 +- appveyor.yml | 27 --------------------------- 3 files changed, 1 insertion(+), 75 deletions(-) delete mode 100644 .travis.yml delete mode 100644 appveyor.yml diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 464fd86..0000000 --- a/.travis.yml +++ /dev/null @@ -1,47 +0,0 @@ -# Wire up travis -language: python -sudo: false - -matrix: - include: - - python: 2.7 - env: TOXENV=py27 - - python: 3.5 - env: TOXENV=py35 - - python: 3.6 - env: TOXENV=py36 - - python: 3.7 - env: TOXENV=py37 - dist: xenial - sudo: true - - python: 3.8 - env: TOXENV=py38 - dist: xenial - sudo: true - - python: pypy - env: TOXENV=pypy - - python: pypy3 - env: TOXENV=pypy3 - - python: 3.8 - env: TOXENV=py27,py38,coverage - dist: xenial - sudo: true - - python: 3.8 - env: TOXENV=lint - dist: xenial - sudo: true - - python: 3.5 - env: TOXENV=docs - -install: - - travis_retry pip install tox - -script: - - travis_retry tox - -notifications: - email: - - pyramid-checkins@lists.repoze.org - irc: - channels: - - "chat.freenode.net#pyramid" diff --git a/MANIFEST.in b/MANIFEST.in index b52891d..7540038 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -14,7 +14,7 @@ include COPYRIGHT.txt include pyproject.toml setup.cfg include .coveragerc .flake8 -include tox.ini .travis.yml rtd.txt appveyor.yml +include tox.ini rtd.txt exclude TODO.txt diff --git a/appveyor.yml b/appveyor.yml deleted file mode 100644 index 48cb759..0000000 --- a/appveyor.yml +++ /dev/null @@ -1,27 +0,0 @@ -environment: - matrix: - - PYTHON: "C:\\Python27" - TOXENV: "py27" - - PYTHON: "C:\\Python27-x64" - TOXENV: "py27" - - PYTHON: "C:\\Python36" - TOXENV: "py36" - - PYTHON: "C:\\Python36-x64" - TOXENV: "py36" - - PYTHON: "C:\\Python37" - TOXENV: "py37" - - PYTHON: "C:\\Python37-x64" - TOXENV: "py37" - -cache: - - '%LOCALAPPDATA%\pip\Cache' - -version: '{branch}.{build}' - -install: - - "%PYTHON%\\python.exe -m pip install tox" - -build: off - -test_script: - - "%PYTHON%\\Scripts\\tox.exe" -- cgit v1.2.1