summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorViktor Dick <vd@perfact.de>2020-08-11 21:10:34 +0200
committerViktor Dick <vd@perfact.de>2020-10-25 09:36:34 +0100
commita514a47e55f0a663477831d083f608a7d2db035b (patch)
treee5df4d9cf7db0cf4e700fe53b959f5b9d8df485d
parent5570c7c9931d32f7558e08143d32125a745ec56b (diff)
downloadwaitress-a514a47e55f0a663477831d083f608a7d2db035b.tar.gz
Allow tasks to notice if client disconnected
This inserts a callable `waitress.client_disconnected` into the environment that allows the task to check if the client disconnected while waiting for the response at strategic points in the execution, allowing to cancel the operation. It requires setting the new adjustment `channel_request_lookahead` to a value larger than 0, which continues to read requests from a channel even if a request is already being processed on that channel, up to the given count, since a client disconnect is detected by reading from a readable socket and receiving an empty result.
-rw-r--r--CHANGES.txt14
-rw-r--r--src/waitress/adjustments.py8
-rw-r--r--src/waitress/channel.py277
-rw-r--r--src/waitress/runner.py6
-rw-r--r--src/waitress/task.py5
-rw-r--r--tests/test_channel.py158
-rw-r--r--tests/test_task.py5
7 files changed, 340 insertions, 133 deletions
diff --git a/CHANGES.txt b/CHANGES.txt
index f4d1acc..894ff94 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,5 +1,19 @@
2.0.0 (unreleased)
------------------
+- Allow tasks to notice if the client disconnected.
+
+ This inserts a callable `waitress.client_disconnected` into the environment
+ that allows the task to check if the client disconnected while waiting for
+ the response at strategic points in the execution and to cancel the
+ operation.
+
+ It requires setting the new adjustment `channel_request_lookahead` to a value
+ larger than 0, which continues to read requests from a channel even if a
+ request is already being processed on that channel, up to the given count,
+ since a client disconnect is detected by reading from a readable socket and
+ receiving an empty result.
+
+ See https://github.com/Pylons/waitress/pull/310
- Drop Python 2.7 support
diff --git a/src/waitress/adjustments.py b/src/waitress/adjustments.py
index 45ac41b..42d2bc0 100644
--- a/src/waitress/adjustments.py
+++ b/src/waitress/adjustments.py
@@ -135,6 +135,7 @@ class Adjustments:
("unix_socket", str),
("unix_socket_perms", asoctal),
("sockets", as_socket_list),
+ ("channel_request_lookahead", int),
)
_param_map = dict(_params)
@@ -280,6 +281,13 @@ class Adjustments:
# be used for e.g. socket activation
sockets = []
+ # By setting this to a value larger than zero, each channel stays readable
+ # and continues to read requests from the client even if a request is still
+ # running, until the number of buffered requests exceeds this value.
+ # This allows detecting if a client closed the connection while its request
+ # is being processed.
+ channel_request_lookahead = 0
+
def __init__(self, **kw):
if "listen" in kw and ("host" in kw or "port" in kw):
diff --git a/src/waitress/channel.py b/src/waitress/channel.py
index 65bc87f..296a16a 100644
--- a/src/waitress/channel.py
+++ b/src/waitress/channel.py
@@ -40,11 +40,11 @@ class HTTPChannel(wasyncore.dispatcher):
error_task_class = ErrorTask
parser_class = HTTPRequestParser
- request = None # A request parser instance
+ # A request that has not been received yet completely is stored here
+ request = None
last_activity = 0 # Time of last activity
will_close = False # set to True to close the socket.
close_when_flushed = False # set to True to close the socket when flushed
- requests = () # currently pending requests
sent_continue = False # used as a latch after sending 100 continue
total_outbufs_len = 0 # total bytes ready to send
current_outbuf_count = 0 # total bytes written to current outbuf
@@ -60,8 +60,9 @@ class HTTPChannel(wasyncore.dispatcher):
self.creation_time = self.last_activity = time.time()
self.sendbuf_len = sock.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF)
- # task_lock used to push/pop requests
- self.task_lock = threading.Lock()
+ # requests_lock used to push/pop requests and modify the request that is
+ # currently being created
+ self.requests_lock = threading.Lock()
# outbuf_lock used to access any outbuf (expected to use an RLock)
self.outbuf_lock = threading.Condition()
@@ -69,6 +70,15 @@ class HTTPChannel(wasyncore.dispatcher):
# Don't let wasyncore.dispatcher throttle self.addr on us.
self.addr = addr
+ self.requests = []
+
+ def check_client_disconnected(self):
+ """
+ This method is inserted into the environment of any created task so it
+ may occasionally check if the client has disconnected and interrupt
+ execution.
+ """
+ return not self.connected
def writable(self):
# if there's data in the out buffer or we've been instructed to close
@@ -125,18 +135,18 @@ class HTTPChannel(wasyncore.dispatcher):
self.handle_close()
def readable(self):
- # We might want to create a new task. We can only do this if:
+ # We might want to read more requests. We can only do this if:
# 1. We're not already about to close the connection.
# 2. We're not waiting to flush remaining data before closing the
# connection
- # 3. There's no already currently running task(s).
+ # 3. There are not too many tasks already queued
# 4. There's no data in the output buffer that needs to be sent
# before we potentially create a new task.
return not (
self.will_close
or self.close_when_flushed
- or self.requests
+ or len(self.requests) > self.adj.channel_request_lookahead
or self.total_outbufs_len
)
@@ -153,57 +163,69 @@ class HTTPChannel(wasyncore.dispatcher):
if data:
self.last_activity = time.time()
self.received(data)
+ else:
+ # Client disconnected.
+ self.connected = False
+
+ def send_continue(self):
+ """
+ Send a 100-Continue header to the client. This is either called from
+ receive (if no requests are running and the client expects it) or at
+ the end of service (if no more requests are queued and a request has
+ been read partially that expects it).
+ """
+ self.request.expect_continue = False
+ outbuf_payload = b"HTTP/1.1 100 Continue\r\n\r\n"
+ num_bytes = len(outbuf_payload)
+ with self.outbuf_lock:
+ self.outbufs[-1].append(outbuf_payload)
+ self.current_outbuf_count += num_bytes
+ self.total_outbufs_len += num_bytes
+ self.sent_continue = True
+ self._flush_some()
+ self.request.completed = False
def received(self, data):
"""
Receives input asynchronously and assigns one or more requests to the
channel.
"""
- # Preconditions: there's no task(s) already running
- request = self.request
- requests = []
-
if not data:
return False
- while data:
- if request is None:
- request = self.parser_class(self.adj)
- n = request.received(data)
-
- if request.expect_continue and request.headers_finished:
- # guaranteed by parser to be a 1.1 request
- request.expect_continue = False
-
- if not self.sent_continue:
- # there's no current task, so we don't need to try to
- # lock the outbuf to append to it.
- outbuf_payload = b"HTTP/1.1 100 Continue\r\n\r\n"
- num_bytes = len(outbuf_payload)
- self.outbufs[-1].append(outbuf_payload)
- self.current_outbuf_count += num_bytes
- self.total_outbufs_len += num_bytes
- self.sent_continue = True
- self._flush_some()
- request.completed = False
-
- if request.completed:
- # The request (with the body) is ready to use.
- self.request = None
-
- if not request.empty:
- requests.append(request)
- request = None
- else:
- self.request = request
-
- if n >= len(data):
- break
- data = data[n:]
-
- if requests:
- self.requests = requests
- self.server.add_task(self)
+ with self.requests_lock:
+ while data:
+ if self.request is None:
+ self.request = self.parser_class(self.adj)
+ n = self.request.received(data)
+
+ # if there are requests queued, we can not send the continue
+ # header yet since the responses need to be kept in order
+ if (
+ self.request.expect_continue
+ and self.request.headers_finished
+ and not self.requests
+ and not self.sent_continue
+ ):
+ self.send_continue()
+
+ if self.request.completed:
+ # The request (with the body) is ready to use.
+ self.sent_continue = False
+
+ if not self.request.empty:
+ self.requests.append(self.request)
+ if len(self.requests) == 1:
+ # self.requests was empty before so the main thread
+ # is in charge of starting the task. Otherwise,
+ # service() will add a new task after each request
+ # has been processed
+ self.server.add_task(self)
+ self.request = None
+
+ if n >= len(data):
+ break
+ data = data[n:]
return True
@@ -360,88 +382,101 @@ class HTTPChannel(wasyncore.dispatcher):
self.outbuf_lock.wait()
def service(self):
- """Execute all pending requests """
- with self.task_lock:
- while self.requests:
- request = self.requests[0]
+ """Execute one request. If there are more, we add another task to the
+ server at the end."""
+
+ request = self.requests[0]
+
+ if request.error:
+ task = self.error_task_class(self, request)
+ else:
+ task = self.task_class(self, request)
- if request.error:
- task = self.error_task_class(self, request)
+ try:
+ if self.connected:
+ task.service()
+ else:
+ task.close_on_finish = True
+ except ClientDisconnected:
+ self.logger.info("Client disconnected while serving %s" % task.request.path)
+ task.close_on_finish = True
+ except Exception:
+ self.logger.exception("Exception while serving %s" % task.request.path)
+
+ if not task.wrote_header:
+ if self.adj.expose_tracebacks:
+ body = traceback.format_exc()
else:
- task = self.task_class(self, request)
+ body = "The server encountered an unexpected internal server error"
+ req_version = request.version
+ req_headers = request.headers
+ err_request = self.parser_class(self.adj)
+ err_request.error = InternalServerError(body)
+ # copy some original request attributes to fulfill
+ # HTTP 1.1 requirements
+ err_request.version = req_version
try:
- task.service()
+ err_request.headers["CONNECTION"] = req_headers["CONNECTION"]
+ except KeyError:
+ pass
+ task = self.error_task_class(self, err_request)
+ try:
+ task.service() # must not fail
except ClientDisconnected:
- self.logger.info(
- "Client disconnected while serving %s" % task.request.path
- )
task.close_on_finish = True
- except Exception:
- self.logger.exception(
- "Exception while serving %s" % task.request.path
- )
+ else:
+ task.close_on_finish = True
- if not task.wrote_header:
- if self.adj.expose_tracebacks:
- body = traceback.format_exc()
- else:
- body = (
- "The server encountered an unexpected "
- "internal server error"
- )
- req_version = request.version
- req_headers = request.headers
- request = self.parser_class(self.adj)
- request.error = InternalServerError(body)
- # copy some original request attributes to fulfill
- # HTTP 1.1 requirements
- request.version = req_version
- try:
- request.headers["CONNECTION"] = req_headers["CONNECTION"]
- except KeyError:
- pass
- task = self.error_task_class(self, request)
- try:
- task.service() # must not fail
- except ClientDisconnected:
- task.close_on_finish = True
- else:
- task.close_on_finish = True
- # we cannot allow self.requests to drop to empty til
- # here; otherwise the mainloop gets confused
-
- if task.close_on_finish:
- self.close_when_flushed = True
-
- for request in self.requests:
- request.close()
- self.requests = []
- else:
- # before processing a new request, ensure there is not too
- # much data in the outbufs waiting to be flushed
- # NB: currently readable() returns False while we are
- # flushing data so we know no new requests will come in
- # that we need to account for, otherwise it'd be better
- # to do this check at the start of the request instead of
- # at the end to account for consecutive service() calls
-
- if len(self.requests) > 1:
- self._flush_outbufs_below_high_watermark()
-
- # this is a little hacky but basically it's forcing the
- # next request to create a new outbuf to avoid sharing
- # outbufs across requests which can cause outbufs to
- # not be deallocated regularly when a connection is open
- # for a long time
-
- if self.current_outbuf_count > 0:
- self.current_outbuf_count = self.adj.outbuf_high_watermark
-
- request = self.requests.pop(0)
+ if task.close_on_finish:
+ with self.requests_lock:
+ self.close_when_flushed = True
+
+ for request in self.requests:
request.close()
+ self.requests = []
+ else:
+ # before processing a new request, ensure there is not too
+ # much data in the outbufs waiting to be flushed
+ # NB: currently readable() returns False while we are
+ # flushing data so we know no new requests will come in
+ # that we need to account for, otherwise it'd be better
+ # to do this check at the start of the request instead of
+ # at the end to account for consecutive service() calls
+
+ if len(self.requests) > 1:
+ self._flush_outbufs_below_high_watermark()
+
+ # this is a little hacky but basically it's forcing the
+ # next request to create a new outbuf to avoid sharing
+ # outbufs across requests which can cause outbufs to
+ # not be deallocated regularly when a connection is open
+ # for a long time
+
+ if self.current_outbuf_count > 0:
+ self.current_outbuf_count = self.adj.outbuf_high_watermark
+
+ request.close()
+
+ # Add new task to process the next request
+ with self.requests_lock:
+ self.requests.pop(0)
+ if self.connected and self.requests:
+ self.server.add_task(self)
+ elif (
+ self.connected
+ and self.request is not None
+ and self.request.expect_continue
+ and self.request.headers_finished
+ and not self.sent_continue
+ ):
+ # A request waits for a signal to continue, but we could
+ # not send it until now because requests were being
+ # processed and the output needs to be kept in order
+ self.send_continue()
if self.connected:
self.server.pull_trigger()
+
self.last_activity = time.time()
def cancel(self):
diff --git a/src/waitress/runner.py b/src/waitress/runner.py
index 4fb3e6b..c23ca0e 100644
--- a/src/waitress/runner.py
+++ b/src/waitress/runner.py
@@ -169,6 +169,12 @@ Tuning options:
The use_poll argument passed to ``asyncore.loop()``. Helps overcome
open file descriptors limit. Default is False.
+ --channel-request-lookahead=INT
+ Allows channels to stay readable and buffer more requests up to the
+ given maximum even if a request is already being processed. This allows
+ detecting if a client closed the connection while its request is being
+ processed. Default is 0.
+
"""
RUNNER_PATTERN = re.compile(
diff --git a/src/waitress/task.py b/src/waitress/task.py
index 3a7cf17..2ac8f4c 100644
--- a/src/waitress/task.py
+++ b/src/waitress/task.py
@@ -560,6 +560,11 @@ class WSGITask(Task):
if mykey not in environ:
environ[mykey] = value
+ # Insert a callable into the environment that allows the application to
+ # check if the client disconnected. Only works with
+ # channel_request_lookahead larger than 0.
+ environ["waitress.client_disconnected"] = self.channel.check_client_disconnected
+
# cache the environ for this request
self.environ = environ
return environ
diff --git a/tests/test_channel.py b/tests/test_channel.py
index df3d450..d86dbbe 100644
--- a/tests/test_channel.py
+++ b/tests/test_channel.py
@@ -1,6 +1,8 @@
import io
import unittest
+import pytest
+
class TestHTTPChannel(unittest.TestCase):
def _makeOne(self, sock, addr, adj, map=None):
@@ -173,7 +175,7 @@ class TestHTTPChannel(unittest.TestCase):
def test_readable_with_requests(self):
inst, sock, map = self._makeOneWithMap()
- inst.requests = True
+ inst.requests = [True]
self.assertEqual(inst.readable(), False)
def test_handle_read_no_error(self):
@@ -189,8 +191,6 @@ class TestHTTPChannel(unittest.TestCase):
self.assertEqual(L, [b"abc"])
def test_handle_read_error(self):
- import socket
-
inst, sock, map = self._makeOneWithMap()
inst.will_close = False
@@ -439,7 +439,7 @@ class TestHTTPChannel(unittest.TestCase):
preq.completed = False
preq.empty = True
inst.received(b"GET / HTTP/1.1\r\n\r\n")
- self.assertEqual(inst.requests, ())
+ self.assertEqual(inst.requests, [])
self.assertEqual(inst.server.tasks, [])
def test_received_preq_completed_empty(self):
@@ -525,14 +525,6 @@ class TestHTTPChannel(unittest.TestCase):
self.assertEqual(inst.sent_continue, True)
self.assertEqual(preq.completed, False)
- def test_service_no_requests(self):
- inst, sock, map = self._makeOneWithMap()
- inst.requests = []
- inst.service()
- self.assertEqual(inst.requests, [])
- self.assertTrue(inst.server.trigger_pulled)
- self.assertTrue(inst.last_activity)
-
def test_service_with_one_request(self):
inst, sock, map = self._makeOneWithMap()
request = DummyRequest()
@@ -561,6 +553,7 @@ class TestHTTPChannel(unittest.TestCase):
inst.task_class = DummyTaskClass()
inst.requests = [request1, request2]
inst.service()
+ inst.service()
self.assertEqual(inst.requests, [])
self.assertTrue(request1.serviced)
self.assertTrue(request2.serviced)
@@ -705,6 +698,137 @@ class TestHTTPChannel(unittest.TestCase):
self.assertEqual(inst.requests, [])
+class TestHTTPChannelLookahead(TestHTTPChannel):
+ def app_check_disconnect(self, environ, start_response):
+ """
+ Application that checks for client disconnection every
+ second for up to two seconds.
+ """
+ import time
+
+ if hasattr(self, "app_started"):
+ self.app_started.set()
+
+ try:
+ request_body_size = int(environ.get("CONTENT_LENGTH", 0))
+ except ValueError:
+ request_body_size = 0
+ self.request_body = environ["wsgi.input"].read(request_body_size)
+
+ self.disconnect_detected = False
+ check = environ["waitress.client_disconnected"]
+ if environ["PATH_INFO"] == "/sleep":
+ for i in range(3):
+ if i != 0:
+ time.sleep(1)
+ if check():
+ self.disconnect_detected = True
+ break
+
+ body = b"finished"
+ cl = str(len(body))
+ start_response(
+ "200 OK", [("Content-Length", cl), ("Content-Type", "text/plain")]
+ )
+ return [body]
+
+ def _make_app_with_lookahead(self):
+ """
+ Setup a channel with lookahead and store it and the socket in self
+ """
+ adj = DummyAdjustments()
+ adj.channel_request_lookahead = 5
+ channel, sock, map = self._makeOneWithMap(adj=adj)
+ channel.server.application = self.app_check_disconnect
+
+ self.channel = channel
+ self.sock = sock
+
+ def _send(self, *lines):
+ """
+ Send lines through the socket with correct line endings
+ """
+ self.sock.send("".join(line + "\r\n" for line in lines).encode("ascii"))
+
+ def test_client_disconnect(self, close_before_start=False):
+ """Disconnect the socket after starting the task."""
+ import threading
+
+ self._make_app_with_lookahead()
+ self._send(
+ "GET /sleep HTTP/1.1",
+ "Host: localhost:8080",
+ "",
+ )
+ self.assertTrue(self.channel.readable())
+ self.channel.handle_read()
+ self.assertEqual(len(self.channel.server.tasks), 1)
+ self.app_started = threading.Event()
+ self.disconnect_detected = False
+ thread = threading.Thread(target=self.channel.server.tasks[0].service)
+
+ if not close_before_start:
+ thread.start()
+ self.assertTrue(self.app_started.wait(timeout=5))
+
+ # Close the socket, check that the channel is still readable due to the
+ # lookahead and read it, which marks the channel as closed.
+ self.sock.close()
+ self.assertTrue(self.channel.readable())
+ self.channel.handle_read()
+
+ if close_before_start:
+ thread.start()
+
+ thread.join()
+
+ if close_before_start:
+ self.assertFalse(self.app_started.is_set())
+ else:
+ self.assertTrue(self.disconnect_detected)
+
+ def test_client_disconnect_immediate(self):
+ """
+ The same test, but this time we close the socket even before processing
+ started. The app should not be executed.
+ """
+ self.test_client_disconnect(close_before_start=True)
+
+ def test_lookahead_continue(self):
+ """
+ Send two requests to a channel with lookahead and use an
+ expect-continue on the second one, making sure the responses still come
+ in the right order.
+ """
+ self._make_app_with_lookahead()
+ self._send(
+ "POST / HTTP/1.1",
+ "Host: localhost:8080",
+ "Content-Length: 1",
+ "",
+ "x",
+ "POST / HTTP/1.1",
+ "Host: localhost:8080",
+ "Content-Length: 1",
+ "Expect: 100-continue",
+ "",
+ )
+ self.channel.handle_read()
+ self.assertEqual(len(self.channel.requests), 1)
+ self.channel.server.tasks[0].service()
+ data = self.sock.recv(256).decode("ascii")
+ self.assertTrue(data.endswith("HTTP/1.1 100 Continue\r\n\r\n"))
+
+ self.sock.send(b"x")
+ self.channel.handle_read()
+ self.assertEqual(len(self.channel.requests), 1)
+ self.channel.server.tasks[0].service()
+ self.channel._flush_some()
+ data = self.sock.recv(256).decode("ascii")
+ self.assertEqual(data.split("\r\n")[-1], "finished")
+ self.assertEqual(self.request_body, b"x")
+
+
class DummySock:
blocking = False
closed = False
@@ -731,6 +855,11 @@ class DummySock:
self.sent += data
return len(data)
+ def recv(self, buffer_size):
+ result = self.sent[:buffer_size]
+ self.sent = self.sent[buffer_size:]
+ return result
+
class DummyLock:
notified = False
@@ -796,11 +925,16 @@ class DummyAdjustments:
expose_tracebacks = True
ident = "waitress"
max_request_header_size = 10000
+ url_prefix = ""
+ channel_request_lookahead = 0
+ max_request_body_size = 1048576
class DummyServer:
trigger_pulled = False
adj = DummyAdjustments()
+ effective_port = 8080
+ server_name = ""
def __init__(self):
self.tasks = []
diff --git a/tests/test_task.py b/tests/test_task.py
index de800fb..ea71e02 100644
--- a/tests/test_task.py
+++ b/tests/test_task.py
@@ -802,6 +802,7 @@ class TestWSGITask(unittest.TestCase):
"SERVER_PORT",
"SERVER_PROTOCOL",
"SERVER_SOFTWARE",
+ "waitress.client_disconnected",
"wsgi.errors",
"wsgi.file_wrapper",
"wsgi.input",
@@ -958,6 +959,10 @@ class DummyChannel:
creation_time = 0
addr = ("127.0.0.1", 39830)
+ def check_client_disconnected(self):
+ # For now, until we have tests handling this feature
+ return False
+
def __init__(self, server=None):
if server is None:
server = DummyServer()