diff options
Diffstat (limited to 'tests/test_channel.py')
-rw-r--r-- | tests/test_channel.py | 158 |
1 files changed, 146 insertions, 12 deletions
diff --git a/tests/test_channel.py b/tests/test_channel.py index df3d450..d86dbbe 100644 --- a/tests/test_channel.py +++ b/tests/test_channel.py @@ -1,6 +1,8 @@ import io import unittest +import pytest + class TestHTTPChannel(unittest.TestCase): def _makeOne(self, sock, addr, adj, map=None): @@ -173,7 +175,7 @@ class TestHTTPChannel(unittest.TestCase): def test_readable_with_requests(self): inst, sock, map = self._makeOneWithMap() - inst.requests = True + inst.requests = [True] self.assertEqual(inst.readable(), False) def test_handle_read_no_error(self): @@ -189,8 +191,6 @@ class TestHTTPChannel(unittest.TestCase): self.assertEqual(L, [b"abc"]) def test_handle_read_error(self): - import socket - inst, sock, map = self._makeOneWithMap() inst.will_close = False @@ -439,7 +439,7 @@ class TestHTTPChannel(unittest.TestCase): preq.completed = False preq.empty = True inst.received(b"GET / HTTP/1.1\r\n\r\n") - self.assertEqual(inst.requests, ()) + self.assertEqual(inst.requests, []) self.assertEqual(inst.server.tasks, []) def test_received_preq_completed_empty(self): @@ -525,14 +525,6 @@ class TestHTTPChannel(unittest.TestCase): self.assertEqual(inst.sent_continue, True) self.assertEqual(preq.completed, False) - def test_service_no_requests(self): - inst, sock, map = self._makeOneWithMap() - inst.requests = [] - inst.service() - self.assertEqual(inst.requests, []) - self.assertTrue(inst.server.trigger_pulled) - self.assertTrue(inst.last_activity) - def test_service_with_one_request(self): inst, sock, map = self._makeOneWithMap() request = DummyRequest() @@ -561,6 +553,7 @@ class TestHTTPChannel(unittest.TestCase): inst.task_class = DummyTaskClass() inst.requests = [request1, request2] inst.service() + inst.service() self.assertEqual(inst.requests, []) self.assertTrue(request1.serviced) self.assertTrue(request2.serviced) @@ -705,6 +698,137 @@ class TestHTTPChannel(unittest.TestCase): self.assertEqual(inst.requests, []) +class TestHTTPChannelLookahead(TestHTTPChannel): + def app_check_disconnect(self, environ, start_response): + """ + Application that checks for client disconnection every + second for up to two seconds. + """ + import time + + if hasattr(self, "app_started"): + self.app_started.set() + + try: + request_body_size = int(environ.get("CONTENT_LENGTH", 0)) + except ValueError: + request_body_size = 0 + self.request_body = environ["wsgi.input"].read(request_body_size) + + self.disconnect_detected = False + check = environ["waitress.client_disconnected"] + if environ["PATH_INFO"] == "/sleep": + for i in range(3): + if i != 0: + time.sleep(1) + if check(): + self.disconnect_detected = True + break + + body = b"finished" + cl = str(len(body)) + start_response( + "200 OK", [("Content-Length", cl), ("Content-Type", "text/plain")] + ) + return [body] + + def _make_app_with_lookahead(self): + """ + Setup a channel with lookahead and store it and the socket in self + """ + adj = DummyAdjustments() + adj.channel_request_lookahead = 5 + channel, sock, map = self._makeOneWithMap(adj=adj) + channel.server.application = self.app_check_disconnect + + self.channel = channel + self.sock = sock + + def _send(self, *lines): + """ + Send lines through the socket with correct line endings + """ + self.sock.send("".join(line + "\r\n" for line in lines).encode("ascii")) + + def test_client_disconnect(self, close_before_start=False): + """Disconnect the socket after starting the task.""" + import threading + + self._make_app_with_lookahead() + self._send( + "GET /sleep HTTP/1.1", + "Host: localhost:8080", + "", + ) + self.assertTrue(self.channel.readable()) + self.channel.handle_read() + self.assertEqual(len(self.channel.server.tasks), 1) + self.app_started = threading.Event() + self.disconnect_detected = False + thread = threading.Thread(target=self.channel.server.tasks[0].service) + + if not close_before_start: + thread.start() + self.assertTrue(self.app_started.wait(timeout=5)) + + # Close the socket, check that the channel is still readable due to the + # lookahead and read it, which marks the channel as closed. + self.sock.close() + self.assertTrue(self.channel.readable()) + self.channel.handle_read() + + if close_before_start: + thread.start() + + thread.join() + + if close_before_start: + self.assertFalse(self.app_started.is_set()) + else: + self.assertTrue(self.disconnect_detected) + + def test_client_disconnect_immediate(self): + """ + The same test, but this time we close the socket even before processing + started. The app should not be executed. + """ + self.test_client_disconnect(close_before_start=True) + + def test_lookahead_continue(self): + """ + Send two requests to a channel with lookahead and use an + expect-continue on the second one, making sure the responses still come + in the right order. + """ + self._make_app_with_lookahead() + self._send( + "POST / HTTP/1.1", + "Host: localhost:8080", + "Content-Length: 1", + "", + "x", + "POST / HTTP/1.1", + "Host: localhost:8080", + "Content-Length: 1", + "Expect: 100-continue", + "", + ) + self.channel.handle_read() + self.assertEqual(len(self.channel.requests), 1) + self.channel.server.tasks[0].service() + data = self.sock.recv(256).decode("ascii") + self.assertTrue(data.endswith("HTTP/1.1 100 Continue\r\n\r\n")) + + self.sock.send(b"x") + self.channel.handle_read() + self.assertEqual(len(self.channel.requests), 1) + self.channel.server.tasks[0].service() + self.channel._flush_some() + data = self.sock.recv(256).decode("ascii") + self.assertEqual(data.split("\r\n")[-1], "finished") + self.assertEqual(self.request_body, b"x") + + class DummySock: blocking = False closed = False @@ -731,6 +855,11 @@ class DummySock: self.sent += data return len(data) + def recv(self, buffer_size): + result = self.sent[:buffer_size] + self.sent = self.sent[buffer_size:] + return result + class DummyLock: notified = False @@ -796,11 +925,16 @@ class DummyAdjustments: expose_tracebacks = True ident = "waitress" max_request_header_size = 10000 + url_prefix = "" + channel_request_lookahead = 0 + max_request_body_size = 1048576 class DummyServer: trigger_pulled = False adj = DummyAdjustments() + effective_port = 8080 + server_name = "" def __init__(self): self.tasks = [] |