summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorViktor Dick <vd@perfact.de>2020-08-11 21:10:34 +0200
committerViktor Dick <vd@perfact.de>2020-10-25 09:36:34 +0100
commita514a47e55f0a663477831d083f608a7d2db035b (patch)
treee5df4d9cf7db0cf4e700fe53b959f5b9d8df485d /tests
parent5570c7c9931d32f7558e08143d32125a745ec56b (diff)
downloadwaitress-a514a47e55f0a663477831d083f608a7d2db035b.tar.gz
Allow tasks to notice if client disconnected
This inserts a callable `waitress.client_disconnected` into the environment that allows the task to check if the client disconnected while waiting for the response at strategic points in the execution, allowing to cancel the operation. It requires setting the new adjustment `channel_request_lookahead` to a value larger than 0, which continues to read requests from a channel even if a request is already being processed on that channel, up to the given count, since a client disconnect is detected by reading from a readable socket and receiving an empty result.
Diffstat (limited to 'tests')
-rw-r--r--tests/test_channel.py158
-rw-r--r--tests/test_task.py5
2 files changed, 151 insertions, 12 deletions
diff --git a/tests/test_channel.py b/tests/test_channel.py
index df3d450..d86dbbe 100644
--- a/tests/test_channel.py
+++ b/tests/test_channel.py
@@ -1,6 +1,8 @@
import io
import unittest
+import pytest
+
class TestHTTPChannel(unittest.TestCase):
def _makeOne(self, sock, addr, adj, map=None):
@@ -173,7 +175,7 @@ class TestHTTPChannel(unittest.TestCase):
def test_readable_with_requests(self):
inst, sock, map = self._makeOneWithMap()
- inst.requests = True
+ inst.requests = [True]
self.assertEqual(inst.readable(), False)
def test_handle_read_no_error(self):
@@ -189,8 +191,6 @@ class TestHTTPChannel(unittest.TestCase):
self.assertEqual(L, [b"abc"])
def test_handle_read_error(self):
- import socket
-
inst, sock, map = self._makeOneWithMap()
inst.will_close = False
@@ -439,7 +439,7 @@ class TestHTTPChannel(unittest.TestCase):
preq.completed = False
preq.empty = True
inst.received(b"GET / HTTP/1.1\r\n\r\n")
- self.assertEqual(inst.requests, ())
+ self.assertEqual(inst.requests, [])
self.assertEqual(inst.server.tasks, [])
def test_received_preq_completed_empty(self):
@@ -525,14 +525,6 @@ class TestHTTPChannel(unittest.TestCase):
self.assertEqual(inst.sent_continue, True)
self.assertEqual(preq.completed, False)
- def test_service_no_requests(self):
- inst, sock, map = self._makeOneWithMap()
- inst.requests = []
- inst.service()
- self.assertEqual(inst.requests, [])
- self.assertTrue(inst.server.trigger_pulled)
- self.assertTrue(inst.last_activity)
-
def test_service_with_one_request(self):
inst, sock, map = self._makeOneWithMap()
request = DummyRequest()
@@ -561,6 +553,7 @@ class TestHTTPChannel(unittest.TestCase):
inst.task_class = DummyTaskClass()
inst.requests = [request1, request2]
inst.service()
+ inst.service()
self.assertEqual(inst.requests, [])
self.assertTrue(request1.serviced)
self.assertTrue(request2.serviced)
@@ -705,6 +698,137 @@ class TestHTTPChannel(unittest.TestCase):
self.assertEqual(inst.requests, [])
+class TestHTTPChannelLookahead(TestHTTPChannel):
+ def app_check_disconnect(self, environ, start_response):
+ """
+ Application that checks for client disconnection every
+ second for up to two seconds.
+ """
+ import time
+
+ if hasattr(self, "app_started"):
+ self.app_started.set()
+
+ try:
+ request_body_size = int(environ.get("CONTENT_LENGTH", 0))
+ except ValueError:
+ request_body_size = 0
+ self.request_body = environ["wsgi.input"].read(request_body_size)
+
+ self.disconnect_detected = False
+ check = environ["waitress.client_disconnected"]
+ if environ["PATH_INFO"] == "/sleep":
+ for i in range(3):
+ if i != 0:
+ time.sleep(1)
+ if check():
+ self.disconnect_detected = True
+ break
+
+ body = b"finished"
+ cl = str(len(body))
+ start_response(
+ "200 OK", [("Content-Length", cl), ("Content-Type", "text/plain")]
+ )
+ return [body]
+
+ def _make_app_with_lookahead(self):
+ """
+ Setup a channel with lookahead and store it and the socket in self
+ """
+ adj = DummyAdjustments()
+ adj.channel_request_lookahead = 5
+ channel, sock, map = self._makeOneWithMap(adj=adj)
+ channel.server.application = self.app_check_disconnect
+
+ self.channel = channel
+ self.sock = sock
+
+ def _send(self, *lines):
+ """
+ Send lines through the socket with correct line endings
+ """
+ self.sock.send("".join(line + "\r\n" for line in lines).encode("ascii"))
+
+ def test_client_disconnect(self, close_before_start=False):
+ """Disconnect the socket after starting the task."""
+ import threading
+
+ self._make_app_with_lookahead()
+ self._send(
+ "GET /sleep HTTP/1.1",
+ "Host: localhost:8080",
+ "",
+ )
+ self.assertTrue(self.channel.readable())
+ self.channel.handle_read()
+ self.assertEqual(len(self.channel.server.tasks), 1)
+ self.app_started = threading.Event()
+ self.disconnect_detected = False
+ thread = threading.Thread(target=self.channel.server.tasks[0].service)
+
+ if not close_before_start:
+ thread.start()
+ self.assertTrue(self.app_started.wait(timeout=5))
+
+ # Close the socket, check that the channel is still readable due to the
+ # lookahead and read it, which marks the channel as closed.
+ self.sock.close()
+ self.assertTrue(self.channel.readable())
+ self.channel.handle_read()
+
+ if close_before_start:
+ thread.start()
+
+ thread.join()
+
+ if close_before_start:
+ self.assertFalse(self.app_started.is_set())
+ else:
+ self.assertTrue(self.disconnect_detected)
+
+ def test_client_disconnect_immediate(self):
+ """
+ The same test, but this time we close the socket even before processing
+ started. The app should not be executed.
+ """
+ self.test_client_disconnect(close_before_start=True)
+
+ def test_lookahead_continue(self):
+ """
+ Send two requests to a channel with lookahead and use an
+ expect-continue on the second one, making sure the responses still come
+ in the right order.
+ """
+ self._make_app_with_lookahead()
+ self._send(
+ "POST / HTTP/1.1",
+ "Host: localhost:8080",
+ "Content-Length: 1",
+ "",
+ "x",
+ "POST / HTTP/1.1",
+ "Host: localhost:8080",
+ "Content-Length: 1",
+ "Expect: 100-continue",
+ "",
+ )
+ self.channel.handle_read()
+ self.assertEqual(len(self.channel.requests), 1)
+ self.channel.server.tasks[0].service()
+ data = self.sock.recv(256).decode("ascii")
+ self.assertTrue(data.endswith("HTTP/1.1 100 Continue\r\n\r\n"))
+
+ self.sock.send(b"x")
+ self.channel.handle_read()
+ self.assertEqual(len(self.channel.requests), 1)
+ self.channel.server.tasks[0].service()
+ self.channel._flush_some()
+ data = self.sock.recv(256).decode("ascii")
+ self.assertEqual(data.split("\r\n")[-1], "finished")
+ self.assertEqual(self.request_body, b"x")
+
+
class DummySock:
blocking = False
closed = False
@@ -731,6 +855,11 @@ class DummySock:
self.sent += data
return len(data)
+ def recv(self, buffer_size):
+ result = self.sent[:buffer_size]
+ self.sent = self.sent[buffer_size:]
+ return result
+
class DummyLock:
notified = False
@@ -796,11 +925,16 @@ class DummyAdjustments:
expose_tracebacks = True
ident = "waitress"
max_request_header_size = 10000
+ url_prefix = ""
+ channel_request_lookahead = 0
+ max_request_body_size = 1048576
class DummyServer:
trigger_pulled = False
adj = DummyAdjustments()
+ effective_port = 8080
+ server_name = ""
def __init__(self):
self.tasks = []
diff --git a/tests/test_task.py b/tests/test_task.py
index de800fb..ea71e02 100644
--- a/tests/test_task.py
+++ b/tests/test_task.py
@@ -802,6 +802,7 @@ class TestWSGITask(unittest.TestCase):
"SERVER_PORT",
"SERVER_PROTOCOL",
"SERVER_SOFTWARE",
+ "waitress.client_disconnected",
"wsgi.errors",
"wsgi.file_wrapper",
"wsgi.input",
@@ -958,6 +959,10 @@ class DummyChannel:
creation_time = 0
addr = ("127.0.0.1", 39830)
+ def check_client_disconnected(self):
+ # For now, until we have tests handling this feature
+ return False
+
def __init__(self, server=None):
if server is None:
server = DummyServer()