summaryrefslogtreecommitdiff
path: root/tests/test_channel.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_channel.py')
-rw-r--r--tests/test_channel.py158
1 files changed, 146 insertions, 12 deletions
diff --git a/tests/test_channel.py b/tests/test_channel.py
index df3d450..d86dbbe 100644
--- a/tests/test_channel.py
+++ b/tests/test_channel.py
@@ -1,6 +1,8 @@
import io
import unittest
+import pytest
+
class TestHTTPChannel(unittest.TestCase):
def _makeOne(self, sock, addr, adj, map=None):
@@ -173,7 +175,7 @@ class TestHTTPChannel(unittest.TestCase):
def test_readable_with_requests(self):
inst, sock, map = self._makeOneWithMap()
- inst.requests = True
+ inst.requests = [True]
self.assertEqual(inst.readable(), False)
def test_handle_read_no_error(self):
@@ -189,8 +191,6 @@ class TestHTTPChannel(unittest.TestCase):
self.assertEqual(L, [b"abc"])
def test_handle_read_error(self):
- import socket
-
inst, sock, map = self._makeOneWithMap()
inst.will_close = False
@@ -439,7 +439,7 @@ class TestHTTPChannel(unittest.TestCase):
preq.completed = False
preq.empty = True
inst.received(b"GET / HTTP/1.1\r\n\r\n")
- self.assertEqual(inst.requests, ())
+ self.assertEqual(inst.requests, [])
self.assertEqual(inst.server.tasks, [])
def test_received_preq_completed_empty(self):
@@ -525,14 +525,6 @@ class TestHTTPChannel(unittest.TestCase):
self.assertEqual(inst.sent_continue, True)
self.assertEqual(preq.completed, False)
- def test_service_no_requests(self):
- inst, sock, map = self._makeOneWithMap()
- inst.requests = []
- inst.service()
- self.assertEqual(inst.requests, [])
- self.assertTrue(inst.server.trigger_pulled)
- self.assertTrue(inst.last_activity)
-
def test_service_with_one_request(self):
inst, sock, map = self._makeOneWithMap()
request = DummyRequest()
@@ -561,6 +553,7 @@ class TestHTTPChannel(unittest.TestCase):
inst.task_class = DummyTaskClass()
inst.requests = [request1, request2]
inst.service()
+ inst.service()
self.assertEqual(inst.requests, [])
self.assertTrue(request1.serviced)
self.assertTrue(request2.serviced)
@@ -705,6 +698,137 @@ class TestHTTPChannel(unittest.TestCase):
self.assertEqual(inst.requests, [])
+class TestHTTPChannelLookahead(TestHTTPChannel):
+ def app_check_disconnect(self, environ, start_response):
+ """
+ Application that checks for client disconnection every
+ second for up to two seconds.
+ """
+ import time
+
+ if hasattr(self, "app_started"):
+ self.app_started.set()
+
+ try:
+ request_body_size = int(environ.get("CONTENT_LENGTH", 0))
+ except ValueError:
+ request_body_size = 0
+ self.request_body = environ["wsgi.input"].read(request_body_size)
+
+ self.disconnect_detected = False
+ check = environ["waitress.client_disconnected"]
+ if environ["PATH_INFO"] == "/sleep":
+ for i in range(3):
+ if i != 0:
+ time.sleep(1)
+ if check():
+ self.disconnect_detected = True
+ break
+
+ body = b"finished"
+ cl = str(len(body))
+ start_response(
+ "200 OK", [("Content-Length", cl), ("Content-Type", "text/plain")]
+ )
+ return [body]
+
+ def _make_app_with_lookahead(self):
+ """
+ Setup a channel with lookahead and store it and the socket in self
+ """
+ adj = DummyAdjustments()
+ adj.channel_request_lookahead = 5
+ channel, sock, map = self._makeOneWithMap(adj=adj)
+ channel.server.application = self.app_check_disconnect
+
+ self.channel = channel
+ self.sock = sock
+
+ def _send(self, *lines):
+ """
+ Send lines through the socket with correct line endings
+ """
+ self.sock.send("".join(line + "\r\n" for line in lines).encode("ascii"))
+
+ def test_client_disconnect(self, close_before_start=False):
+ """Disconnect the socket after starting the task."""
+ import threading
+
+ self._make_app_with_lookahead()
+ self._send(
+ "GET /sleep HTTP/1.1",
+ "Host: localhost:8080",
+ "",
+ )
+ self.assertTrue(self.channel.readable())
+ self.channel.handle_read()
+ self.assertEqual(len(self.channel.server.tasks), 1)
+ self.app_started = threading.Event()
+ self.disconnect_detected = False
+ thread = threading.Thread(target=self.channel.server.tasks[0].service)
+
+ if not close_before_start:
+ thread.start()
+ self.assertTrue(self.app_started.wait(timeout=5))
+
+ # Close the socket, check that the channel is still readable due to the
+ # lookahead and read it, which marks the channel as closed.
+ self.sock.close()
+ self.assertTrue(self.channel.readable())
+ self.channel.handle_read()
+
+ if close_before_start:
+ thread.start()
+
+ thread.join()
+
+ if close_before_start:
+ self.assertFalse(self.app_started.is_set())
+ else:
+ self.assertTrue(self.disconnect_detected)
+
+ def test_client_disconnect_immediate(self):
+ """
+ The same test, but this time we close the socket even before processing
+ started. The app should not be executed.
+ """
+ self.test_client_disconnect(close_before_start=True)
+
+ def test_lookahead_continue(self):
+ """
+ Send two requests to a channel with lookahead and use an
+ expect-continue on the second one, making sure the responses still come
+ in the right order.
+ """
+ self._make_app_with_lookahead()
+ self._send(
+ "POST / HTTP/1.1",
+ "Host: localhost:8080",
+ "Content-Length: 1",
+ "",
+ "x",
+ "POST / HTTP/1.1",
+ "Host: localhost:8080",
+ "Content-Length: 1",
+ "Expect: 100-continue",
+ "",
+ )
+ self.channel.handle_read()
+ self.assertEqual(len(self.channel.requests), 1)
+ self.channel.server.tasks[0].service()
+ data = self.sock.recv(256).decode("ascii")
+ self.assertTrue(data.endswith("HTTP/1.1 100 Continue\r\n\r\n"))
+
+ self.sock.send(b"x")
+ self.channel.handle_read()
+ self.assertEqual(len(self.channel.requests), 1)
+ self.channel.server.tasks[0].service()
+ self.channel._flush_some()
+ data = self.sock.recv(256).decode("ascii")
+ self.assertEqual(data.split("\r\n")[-1], "finished")
+ self.assertEqual(self.request_body, b"x")
+
+
class DummySock:
blocking = False
closed = False
@@ -731,6 +855,11 @@ class DummySock:
self.sent += data
return len(data)
+ def recv(self, buffer_size):
+ result = self.sent[:buffer_size]
+ self.sent = self.sent[buffer_size:]
+ return result
+
class DummyLock:
notified = False
@@ -796,11 +925,16 @@ class DummyAdjustments:
expose_tracebacks = True
ident = "waitress"
max_request_header_size = 10000
+ url_prefix = ""
+ channel_request_lookahead = 0
+ max_request_body_size = 1048576
class DummyServer:
trigger_pulled = False
adj = DummyAdjustments()
+ effective_port = 8080
+ server_name = ""
def __init__(self):
self.tasks = []