summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBert JW Regeer <xistence@0x58.com>2020-10-31 00:03:50 -0700
committerGitHub <noreply@github.com>2020-10-31 00:03:50 -0700
commit31d7498c84cf0041f37beb503fd0ddf78d9d41e2 (patch)
treec2ce19d7fef77c748fb7246c28a951f7829c6a9e
parent5faf6989f4ff019940f41e2ee4855615f41afcc5 (diff)
parenta514a47e55f0a663477831d083f608a7d2db035b (diff)
downloadwaitress-31d7498c84cf0041f37beb503fd0ddf78d9d41e2.tar.gz
Merge pull request #310 from perfact/notify-client-close
Notify client close
-rw-r--r--CHANGES.txt14
-rw-r--r--src/waitress/adjustments.py8
-rw-r--r--src/waitress/channel.py277
-rw-r--r--src/waitress/runner.py6
-rw-r--r--src/waitress/task.py5
-rw-r--r--tests/test_channel.py158
-rw-r--r--tests/test_task.py5
7 files changed, 340 insertions, 133 deletions
diff --git a/CHANGES.txt b/CHANGES.txt
index c62602e..8dd82b5 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,5 +1,19 @@
2.0.0 (unreleased)
------------------
+- Allow tasks to notice if the client disconnected.
+
+ This inserts a callable `waitress.client_disconnected` into the environment
+ that allows the task to check if the client disconnected while waiting for
+ the response at strategic points in the execution and to cancel the
+ operation.
+
+ It requires setting the new adjustment `channel_request_lookahead` to a value
+ larger than 0, which continues to read requests from a channel even if a
+ request is already being processed on that channel, up to the given count,
+ since a client disconnect is detected by reading from a readable socket and
+ receiving an empty result.
+
+ See https://github.com/Pylons/waitress/pull/310
- Drop Python 2.7 support
diff --git a/src/waitress/adjustments.py b/src/waitress/adjustments.py
index 45ac41b..42d2bc0 100644
--- a/src/waitress/adjustments.py
+++ b/src/waitress/adjustments.py
@@ -135,6 +135,7 @@ class Adjustments:
("unix_socket", str),
("unix_socket_perms", asoctal),
("sockets", as_socket_list),
+ ("channel_request_lookahead", int),
)
_param_map = dict(_params)
@@ -280,6 +281,13 @@ class Adjustments:
# be used for e.g. socket activation
sockets = []
+ # By setting this to a value larger than zero, each channel stays readable
+ # and continues to read requests from the client even if a request is still
+ # running, until the number of buffered requests exceeds this value.
+ # This allows detecting if a client closed the connection while its request
+ # is being processed.
+ channel_request_lookahead = 0
+
def __init__(self, **kw):
if "listen" in kw and ("host" in kw or "port" in kw):
diff --git a/src/waitress/channel.py b/src/waitress/channel.py
index 65bc87f..296a16a 100644
--- a/src/waitress/channel.py
+++ b/src/waitress/channel.py
@@ -40,11 +40,11 @@ class HTTPChannel(wasyncore.dispatcher):
error_task_class = ErrorTask
parser_class = HTTPRequestParser
- request = None # A request parser instance
+ # A request that has not been received yet completely is stored here
+ request = None
last_activity = 0 # Time of last activity
will_close = False # set to True to close the socket.
close_when_flushed = False # set to True to close the socket when flushed
- requests = () # currently pending requests
sent_continue = False # used as a latch after sending 100 continue
total_outbufs_len = 0 # total bytes ready to send
current_outbuf_count = 0 # total bytes written to current outbuf
@@ -60,8 +60,9 @@ class HTTPChannel(wasyncore.dispatcher):
self.creation_time = self.last_activity = time.time()
self.sendbuf_len = sock.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF)
- # task_lock used to push/pop requests
- self.task_lock = threading.Lock()
+ # requests_lock used to push/pop requests and modify the request that is
+ # currently being created
+ self.requests_lock = threading.Lock()
# outbuf_lock used to access any outbuf (expected to use an RLock)
self.outbuf_lock = threading.Condition()
@@ -69,6 +70,15 @@ class HTTPChannel(wasyncore.dispatcher):
# Don't let wasyncore.dispatcher throttle self.addr on us.
self.addr = addr
+ self.requests = []
+
+ def check_client_disconnected(self):
+ """
+ This method is inserted into the environment of any created task so it
+ may occasionally check if the client has disconnected and interrupt
+ execution.
+ """
+ return not self.connected
def writable(self):
# if there's data in the out buffer or we've been instructed to close
@@ -125,18 +135,18 @@ class HTTPChannel(wasyncore.dispatcher):
self.handle_close()
def readable(self):
- # We might want to create a new task. We can only do this if:
+ # We might want to read more requests. We can only do this if:
# 1. We're not already about to close the connection.
# 2. We're not waiting to flush remaining data before closing the
# connection
- # 3. There's no already currently running task(s).
+ # 3. There are not too many tasks already queued
# 4. There's no data in the output buffer that needs to be sent
# before we potentially create a new task.
return not (
self.will_close
or self.close_when_flushed
- or self.requests
+ or len(self.requests) > self.adj.channel_request_lookahead
or self.total_outbufs_len
)
@@ -153,57 +163,69 @@ class HTTPChannel(wasyncore.dispatcher):
if data:
self.last_activity = time.time()
self.received(data)
+ else:
+ # Client disconnected.
+ self.connected = False
+
+ def send_continue(self):
+ """
+ Send a 100-Continue header to the client. This is either called from
+ receive (if no requests are running and the client expects it) or at
+ the end of service (if no more requests are queued and a request has
+ been read partially that expects it).
+ """
+ self.request.expect_continue = False
+ outbuf_payload = b"HTTP/1.1 100 Continue\r\n\r\n"
+ num_bytes = len(outbuf_payload)
+ with self.outbuf_lock:
+ self.outbufs[-1].append(outbuf_payload)
+ self.current_outbuf_count += num_bytes
+ self.total_outbufs_len += num_bytes
+ self.sent_continue = True
+ self._flush_some()
+ self.request.completed = False
def received(self, data):
"""
Receives input asynchronously and assigns one or more requests to the
channel.
"""
- # Preconditions: there's no task(s) already running
- request = self.request
- requests = []
-
if not data:
return False
- while data:
- if request is None:
- request = self.parser_class(self.adj)
- n = request.received(data)
-
- if request.expect_continue and request.headers_finished:
- # guaranteed by parser to be a 1.1 request
- request.expect_continue = False
-
- if not self.sent_continue:
- # there's no current task, so we don't need to try to
- # lock the outbuf to append to it.
- outbuf_payload = b"HTTP/1.1 100 Continue\r\n\r\n"
- num_bytes = len(outbuf_payload)
- self.outbufs[-1].append(outbuf_payload)
- self.current_outbuf_count += num_bytes
- self.total_outbufs_len += num_bytes
- self.sent_continue = True
- self._flush_some()
- request.completed = False
-
- if request.completed:
- # The request (with the body) is ready to use.
- self.request = None
-
- if not request.empty:
- requests.append(request)
- request = None
- else:
- self.request = request
-
- if n >= len(data):
- break
- data = data[n:]
-
- if requests:
- self.requests = requests
- self.server.add_task(self)
+ with self.requests_lock:
+ while data:
+ if self.request is None:
+ self.request = self.parser_class(self.adj)
+ n = self.request.received(data)
+
+ # if there are requests queued, we can not send the continue
+ # header yet since the responses need to be kept in order
+ if (
+ self.request.expect_continue
+ and self.request.headers_finished
+ and not self.requests
+ and not self.sent_continue
+ ):
+ self.send_continue()
+
+ if self.request.completed:
+ # The request (with the body) is ready to use.
+ self.sent_continue = False
+
+ if not self.request.empty:
+ self.requests.append(self.request)
+ if len(self.requests) == 1:
+ # self.requests was empty before so the main thread
+ # is in charge of starting the task. Otherwise,
+ # service() will add a new task after each request
+ # has been processed
+ self.server.add_task(self)
+ self.request = None
+
+ if n >= len(data):
+ break
+ data = data[n:]
return True
@@ -360,88 +382,101 @@ class HTTPChannel(wasyncore.dispatcher):
self.outbuf_lock.wait()
def service(self):
- """Execute all pending requests """
- with self.task_lock:
- while self.requests:
- request = self.requests[0]
+ """Execute one request. If there are more, we add another task to the
+ server at the end."""
+
+ request = self.requests[0]
+
+ if request.error:
+ task = self.error_task_class(self, request)
+ else:
+ task = self.task_class(self, request)
- if request.error:
- task = self.error_task_class(self, request)
+ try:
+ if self.connected:
+ task.service()
+ else:
+ task.close_on_finish = True
+ except ClientDisconnected:
+ self.logger.info("Client disconnected while serving %s" % task.request.path)
+ task.close_on_finish = True
+ except Exception:
+ self.logger.exception("Exception while serving %s" % task.request.path)
+
+ if not task.wrote_header:
+ if self.adj.expose_tracebacks:
+ body = traceback.format_exc()
else:
- task = self.task_class(self, request)
+ body = "The server encountered an unexpected internal server error"
+ req_version = request.version
+ req_headers = request.headers
+ err_request = self.parser_class(self.adj)
+ err_request.error = InternalServerError(body)
+ # copy some original request attributes to fulfill
+ # HTTP 1.1 requirements
+ err_request.version = req_version
try:
- task.service()
+ err_request.headers["CONNECTION"] = req_headers["CONNECTION"]
+ except KeyError:
+ pass
+ task = self.error_task_class(self, err_request)
+ try:
+ task.service() # must not fail
except ClientDisconnected:
- self.logger.info(
- "Client disconnected while serving %s" % task.request.path
- )
task.close_on_finish = True
- except Exception:
- self.logger.exception(
- "Exception while serving %s" % task.request.path
- )
+ else:
+ task.close_on_finish = True
- if not task.wrote_header:
- if self.adj.expose_tracebacks:
- body = traceback.format_exc()
- else:
- body = (
- "The server encountered an unexpected "
- "internal server error"
- )
- req_version = request.version
- req_headers = request.headers
- request = self.parser_class(self.adj)
- request.error = InternalServerError(body)
- # copy some original request attributes to fulfill
- # HTTP 1.1 requirements
- request.version = req_version
- try:
- request.headers["CONNECTION"] = req_headers["CONNECTION"]
- except KeyError:
- pass
- task = self.error_task_class(self, request)
- try:
- task.service() # must not fail
- except ClientDisconnected:
- task.close_on_finish = True
- else:
- task.close_on_finish = True
- # we cannot allow self.requests to drop to empty til
- # here; otherwise the mainloop gets confused
-
- if task.close_on_finish:
- self.close_when_flushed = True
-
- for request in self.requests:
- request.close()
- self.requests = []
- else:
- # before processing a new request, ensure there is not too
- # much data in the outbufs waiting to be flushed
- # NB: currently readable() returns False while we are
- # flushing data so we know no new requests will come in
- # that we need to account for, otherwise it'd be better
- # to do this check at the start of the request instead of
- # at the end to account for consecutive service() calls
-
- if len(self.requests) > 1:
- self._flush_outbufs_below_high_watermark()
-
- # this is a little hacky but basically it's forcing the
- # next request to create a new outbuf to avoid sharing
- # outbufs across requests which can cause outbufs to
- # not be deallocated regularly when a connection is open
- # for a long time
-
- if self.current_outbuf_count > 0:
- self.current_outbuf_count = self.adj.outbuf_high_watermark
-
- request = self.requests.pop(0)
+ if task.close_on_finish:
+ with self.requests_lock:
+ self.close_when_flushed = True
+
+ for request in self.requests:
request.close()
+ self.requests = []
+ else:
+ # before processing a new request, ensure there is not too
+ # much data in the outbufs waiting to be flushed
+ # NB: currently readable() returns False while we are
+ # flushing data so we know no new requests will come in
+ # that we need to account for, otherwise it'd be better
+ # to do this check at the start of the request instead of
+ # at the end to account for consecutive service() calls
+
+ if len(self.requests) > 1:
+ self._flush_outbufs_below_high_watermark()
+
+ # this is a little hacky but basically it's forcing the
+ # next request to create a new outbuf to avoid sharing
+ # outbufs across requests which can cause outbufs to
+ # not be deallocated regularly when a connection is open
+ # for a long time
+
+ if self.current_outbuf_count > 0:
+ self.current_outbuf_count = self.adj.outbuf_high_watermark
+
+ request.close()
+
+ # Add new task to process the next request
+ with self.requests_lock:
+ self.requests.pop(0)
+ if self.connected and self.requests:
+ self.server.add_task(self)
+ elif (
+ self.connected
+ and self.request is not None
+ and self.request.expect_continue
+ and self.request.headers_finished
+ and not self.sent_continue
+ ):
+ # A request waits for a signal to continue, but we could
+ # not send it until now because requests were being
+ # processed and the output needs to be kept in order
+ self.send_continue()
if self.connected:
self.server.pull_trigger()
+
self.last_activity = time.time()
def cancel(self):
diff --git a/src/waitress/runner.py b/src/waitress/runner.py
index 4fb3e6b..c23ca0e 100644
--- a/src/waitress/runner.py
+++ b/src/waitress/runner.py
@@ -169,6 +169,12 @@ Tuning options:
The use_poll argument passed to ``asyncore.loop()``. Helps overcome
open file descriptors limit. Default is False.
+ --channel-request-lookahead=INT
+ Allows channels to stay readable and buffer more requests up to the
+ given maximum even if a request is already being processed. This allows
+ detecting if a client closed the connection while its request is being
+ processed. Default is 0.
+
"""
RUNNER_PATTERN = re.compile(
diff --git a/src/waitress/task.py b/src/waitress/task.py
index 3a7cf17..2ac8f4c 100644
--- a/src/waitress/task.py
+++ b/src/waitress/task.py
@@ -560,6 +560,11 @@ class WSGITask(Task):
if mykey not in environ:
environ[mykey] = value
+ # Insert a callable into the environment that allows the application to
+ # check if the client disconnected. Only works with
+ # channel_request_lookahead larger than 0.
+ environ["waitress.client_disconnected"] = self.channel.check_client_disconnected
+
# cache the environ for this request
self.environ = environ
return environ
diff --git a/tests/test_channel.py b/tests/test_channel.py
index df3d450..d86dbbe 100644
--- a/tests/test_channel.py
+++ b/tests/test_channel.py
@@ -1,6 +1,8 @@
import io
import unittest
+import pytest
+
class TestHTTPChannel(unittest.TestCase):
def _makeOne(self, sock, addr, adj, map=None):
@@ -173,7 +175,7 @@ class TestHTTPChannel(unittest.TestCase):
def test_readable_with_requests(self):
inst, sock, map = self._makeOneWithMap()
- inst.requests = True
+ inst.requests = [True]
self.assertEqual(inst.readable(), False)
def test_handle_read_no_error(self):
@@ -189,8 +191,6 @@ class TestHTTPChannel(unittest.TestCase):
self.assertEqual(L, [b"abc"])
def test_handle_read_error(self):
- import socket
-
inst, sock, map = self._makeOneWithMap()
inst.will_close = False
@@ -439,7 +439,7 @@ class TestHTTPChannel(unittest.TestCase):
preq.completed = False
preq.empty = True
inst.received(b"GET / HTTP/1.1\r\n\r\n")
- self.assertEqual(inst.requests, ())
+ self.assertEqual(inst.requests, [])
self.assertEqual(inst.server.tasks, [])
def test_received_preq_completed_empty(self):
@@ -525,14 +525,6 @@ class TestHTTPChannel(unittest.TestCase):
self.assertEqual(inst.sent_continue, True)
self.assertEqual(preq.completed, False)
- def test_service_no_requests(self):
- inst, sock, map = self._makeOneWithMap()
- inst.requests = []
- inst.service()
- self.assertEqual(inst.requests, [])
- self.assertTrue(inst.server.trigger_pulled)
- self.assertTrue(inst.last_activity)
-
def test_service_with_one_request(self):
inst, sock, map = self._makeOneWithMap()
request = DummyRequest()
@@ -561,6 +553,7 @@ class TestHTTPChannel(unittest.TestCase):
inst.task_class = DummyTaskClass()
inst.requests = [request1, request2]
inst.service()
+ inst.service()
self.assertEqual(inst.requests, [])
self.assertTrue(request1.serviced)
self.assertTrue(request2.serviced)
@@ -705,6 +698,137 @@ class TestHTTPChannel(unittest.TestCase):
self.assertEqual(inst.requests, [])
+class TestHTTPChannelLookahead(TestHTTPChannel):
+ def app_check_disconnect(self, environ, start_response):
+ """
+ Application that checks for client disconnection every
+ second for up to two seconds.
+ """
+ import time
+
+ if hasattr(self, "app_started"):
+ self.app_started.set()
+
+ try:
+ request_body_size = int(environ.get("CONTENT_LENGTH", 0))
+ except ValueError:
+ request_body_size = 0
+ self.request_body = environ["wsgi.input"].read(request_body_size)
+
+ self.disconnect_detected = False
+ check = environ["waitress.client_disconnected"]
+ if environ["PATH_INFO"] == "/sleep":
+ for i in range(3):
+ if i != 0:
+ time.sleep(1)
+ if check():
+ self.disconnect_detected = True
+ break
+
+ body = b"finished"
+ cl = str(len(body))
+ start_response(
+ "200 OK", [("Content-Length", cl), ("Content-Type", "text/plain")]
+ )
+ return [body]
+
+ def _make_app_with_lookahead(self):
+ """
+ Setup a channel with lookahead and store it and the socket in self
+ """
+ adj = DummyAdjustments()
+ adj.channel_request_lookahead = 5
+ channel, sock, map = self._makeOneWithMap(adj=adj)
+ channel.server.application = self.app_check_disconnect
+
+ self.channel = channel
+ self.sock = sock
+
+ def _send(self, *lines):
+ """
+ Send lines through the socket with correct line endings
+ """
+ self.sock.send("".join(line + "\r\n" for line in lines).encode("ascii"))
+
+ def test_client_disconnect(self, close_before_start=False):
+ """Disconnect the socket after starting the task."""
+ import threading
+
+ self._make_app_with_lookahead()
+ self._send(
+ "GET /sleep HTTP/1.1",
+ "Host: localhost:8080",
+ "",
+ )
+ self.assertTrue(self.channel.readable())
+ self.channel.handle_read()
+ self.assertEqual(len(self.channel.server.tasks), 1)
+ self.app_started = threading.Event()
+ self.disconnect_detected = False
+ thread = threading.Thread(target=self.channel.server.tasks[0].service)
+
+ if not close_before_start:
+ thread.start()
+ self.assertTrue(self.app_started.wait(timeout=5))
+
+ # Close the socket, check that the channel is still readable due to the
+ # lookahead and read it, which marks the channel as closed.
+ self.sock.close()
+ self.assertTrue(self.channel.readable())
+ self.channel.handle_read()
+
+ if close_before_start:
+ thread.start()
+
+ thread.join()
+
+ if close_before_start:
+ self.assertFalse(self.app_started.is_set())
+ else:
+ self.assertTrue(self.disconnect_detected)
+
+ def test_client_disconnect_immediate(self):
+ """
+ The same test, but this time we close the socket even before processing
+ started. The app should not be executed.
+ """
+ self.test_client_disconnect(close_before_start=True)
+
+ def test_lookahead_continue(self):
+ """
+ Send two requests to a channel with lookahead and use an
+ expect-continue on the second one, making sure the responses still come
+ in the right order.
+ """
+ self._make_app_with_lookahead()
+ self._send(
+ "POST / HTTP/1.1",
+ "Host: localhost:8080",
+ "Content-Length: 1",
+ "",
+ "x",
+ "POST / HTTP/1.1",
+ "Host: localhost:8080",
+ "Content-Length: 1",
+ "Expect: 100-continue",
+ "",
+ )
+ self.channel.handle_read()
+ self.assertEqual(len(self.channel.requests), 1)
+ self.channel.server.tasks[0].service()
+ data = self.sock.recv(256).decode("ascii")
+ self.assertTrue(data.endswith("HTTP/1.1 100 Continue\r\n\r\n"))
+
+ self.sock.send(b"x")
+ self.channel.handle_read()
+ self.assertEqual(len(self.channel.requests), 1)
+ self.channel.server.tasks[0].service()
+ self.channel._flush_some()
+ data = self.sock.recv(256).decode("ascii")
+ self.assertEqual(data.split("\r\n")[-1], "finished")
+ self.assertEqual(self.request_body, b"x")
+
+
class DummySock:
blocking = False
closed = False
@@ -731,6 +855,11 @@ class DummySock:
self.sent += data
return len(data)
+ def recv(self, buffer_size):
+ result = self.sent[:buffer_size]
+ self.sent = self.sent[buffer_size:]
+ return result
+
class DummyLock:
notified = False
@@ -796,11 +925,16 @@ class DummyAdjustments:
expose_tracebacks = True
ident = "waitress"
max_request_header_size = 10000
+ url_prefix = ""
+ channel_request_lookahead = 0
+ max_request_body_size = 1048576
class DummyServer:
trigger_pulled = False
adj = DummyAdjustments()
+ effective_port = 8080
+ server_name = ""
def __init__(self):
self.tasks = []
diff --git a/tests/test_task.py b/tests/test_task.py
index de800fb..ea71e02 100644
--- a/tests/test_task.py
+++ b/tests/test_task.py
@@ -802,6 +802,7 @@ class TestWSGITask(unittest.TestCase):
"SERVER_PORT",
"SERVER_PROTOCOL",
"SERVER_SOFTWARE",
+ "waitress.client_disconnected",
"wsgi.errors",
"wsgi.file_wrapper",
"wsgi.input",
@@ -958,6 +959,10 @@ class DummyChannel:
creation_time = 0
addr = ("127.0.0.1", 39830)
+ def check_client_disconnected(self):
+ # For now, until we have tests handling this feature
+ return False
+
def __init__(self, server=None):
if server is None:
server = DummyServer()