-rw-r--r--   CHANGES.txt                  |  14
-rw-r--r--   src/waitress/adjustments.py  |   8
-rw-r--r--   src/waitress/channel.py      | 277
-rw-r--r--   src/waitress/runner.py       |   6
-rw-r--r--   src/waitress/task.py         |   5
-rw-r--r--   tests/test_channel.py        | 158
-rw-r--r--   tests/test_task.py           |   5
7 files changed, 340 insertions, 133 deletions
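
For orientation before reading the diff: the feature is opt-in. Below is a minimal sketch (not part of the patch) of how a deployment might enable it once this change lands; the `myapp.wsgi` module is a hypothetical placeholder, and the only waitress-specific piece confirmed by this patch is the new `channel_request_lookahead` adjustment.

# Sketch only: enable the new lookahead so waitress can notice client
# disconnects while a request is still being processed.
from waitress import serve

from myapp.wsgi import app  # hypothetical WSGI application

# A value larger than 0 keeps the channel readable while a request runs,
# buffering up to that many additional requests; 0 (the default) leaves the
# old behaviour in place and disables disconnect detection.
serve(app, listen="127.0.0.1:8080", channel_request_lookahead=5)

The equivalent command-line switch added in src/waitress/runner.py below is --channel-request-lookahead=INT.
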
diff --git a/CHANGES.txt b/CHANGES.txt
index f4d1acc..894ff94 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,5 +1,19 @@
 2.0.0 (unreleased)
 ------------------
+- Allow tasks to notice if the client disconnected.
+
+  This inserts a callable `waitress.client_disconnected` into the environment
+  that allows the task to check if the client disconnected while waiting for
+  the response at strategic points in the execution and to cancel the
+  operation.
+
+  It requires setting the new adjustment `channel_request_lookahead` to a value
+  larger than 0, which continues to read requests from a channel even if a
+  request is already being processed on that channel, up to the given count,
+  since a client disconnect is detected by reading from a readable socket and
+  receiving an empty result.
+
+  See https://github.com/Pylons/waitress/pull/310
 - Drop Python 2.7 support
diff --git a/src/waitress/adjustments.py b/src/waitress/adjustments.py
index 45ac41b..42d2bc0 100644
--- a/src/waitress/adjustments.py
+++ b/src/waitress/adjustments.py
@@ -135,6 +135,7 @@ class Adjustments:
         ("unix_socket", str),
         ("unix_socket_perms", asoctal),
         ("sockets", as_socket_list),
+        ("channel_request_lookahead", int),
     )
     _param_map = dict(_params)
@@ -280,6 +281,13 @@ class Adjustments:
     # be used for e.g. socket activation
     sockets = []
+    # By setting this to a value larger than zero, each channel stays readable
+    # and continues to read requests from the client even if a request is still
+    # running, until the number of buffered requests exceeds this value.
+    # This allows detecting if a client closed the connection while its request
+    # is being processed.
+    channel_request_lookahead = 0
+
     def __init__(self, **kw):
         if "listen" in kw and ("host" in kw or "port" in kw):
diff --git a/src/waitress/channel.py b/src/waitress/channel.py
index 65bc87f..296a16a 100644
--- a/src/waitress/channel.py
+++ b/src/waitress/channel.py
@@ -40,11 +40,11 @@ class HTTPChannel(wasyncore.dispatcher):
     error_task_class = ErrorTask
     parser_class = HTTPRequestParser
-    request = None  # A request parser instance
+    # A request that has not been received yet completely is stored here
+    request = None
     last_activity = 0  # Time of last activity
     will_close = False  # set to True to close the socket.
     close_when_flushed = False  # set to True to close the socket when flushed
-    requests = ()  # currently pending requests
     sent_continue = False  # used as a latch after sending 100 continue
     total_outbufs_len = 0  # total bytes ready to send
     current_outbuf_count = 0  # total bytes written to current outbuf
@@ -60,8 +60,9 @@ class HTTPChannel(wasyncore.dispatcher):
         self.creation_time = self.last_activity = time.time()
         self.sendbuf_len = sock.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF)
-        # task_lock used to push/pop requests
-        self.task_lock = threading.Lock()
+        # requests_lock used to push/pop requests and modify the request that is
+        # currently being created
+        self.requests_lock = threading.Lock()
         # outbuf_lock used to access any outbuf (expected to use an RLock)
         self.outbuf_lock = threading.Condition()
@@ -69,6 +70,15 @@ class HTTPChannel(wasyncore.dispatcher):
         # Don't let wasyncore.dispatcher throttle self.addr on us.
         self.addr = addr
+        self.requests = []
+
+    def check_client_disconnected(self):
+        """
+        This method is inserted into the environment of any created task so it
+        may occasionally check if the client has disconnected and interrupt
+        execution.
+        """
+        return not self.connected
     def writable(self):
         # if there's data in the out buffer or we've been instructed to close
@@ -125,18 +135,18 @@ class HTTPChannel(wasyncore.dispatcher):
             self.handle_close()
     def readable(self):
-        # We might want to create a new task. We can only do this if:
+        # We might want to read more requests. We can only do this if:
         # 1. We're not already about to close the connection.
         # 2. We're not waiting to flush remaining data before closing the
         #    connection
-        # 3. There's no already currently running task(s).
+        # 3. There are not too many tasks already queued
         # 4. There's no data in the output buffer that needs to be sent
         #    before we potentially create a new task.
         return not (
             self.will_close
             or self.close_when_flushed
-            or self.requests
+            or len(self.requests) > self.adj.channel_request_lookahead
             or self.total_outbufs_len
         )
@@ -153,57 +163,69 @@ class HTTPChannel(wasyncore.dispatcher):
         if data:
             self.last_activity = time.time()
             self.received(data)
+        else:
+            # Client disconnected.
+            self.connected = False
+
+    def send_continue(self):
+        """
+        Send a 100-Continue header to the client. This is either called from
+        receive (if no requests are running and the client expects it) or at
+        the end of service (if no more requests are queued and a request has
+        been read partially that expects it).
+        """
+        self.request.expect_continue = False
+        outbuf_payload = b"HTTP/1.1 100 Continue\r\n\r\n"
+        num_bytes = len(outbuf_payload)
+        with self.outbuf_lock:
+            self.outbufs[-1].append(outbuf_payload)
+            self.current_outbuf_count += num_bytes
+            self.total_outbufs_len += num_bytes
+            self.sent_continue = True
+            self._flush_some()
+        self.request.completed = False
     def received(self, data):
         """
         Receives input asynchronously and assigns one or more requests to the
         channel.
         """
-        # Preconditions: there's no task(s) already running
-        request = self.request
-        requests = []
-
         if not data:
             return False
-        while data:
-            if request is None:
-                request = self.parser_class(self.adj)
-            n = request.received(data)
-
-            if request.expect_continue and request.headers_finished:
-                # guaranteed by parser to be a 1.1 request
-                request.expect_continue = False
-
-                if not self.sent_continue:
-                    # there's no current task, so we don't need to try to
-                    # lock the outbuf to append to it.
-                    outbuf_payload = b"HTTP/1.1 100 Continue\r\n\r\n"
-                    num_bytes = len(outbuf_payload)
-                    self.outbufs[-1].append(outbuf_payload)
-                    self.current_outbuf_count += num_bytes
-                    self.total_outbufs_len += num_bytes
-                    self.sent_continue = True
-                    self._flush_some()
-                    request.completed = False
-
-            if request.completed:
-                # The request (with the body) is ready to use.
-                self.request = None
-
-                if not request.empty:
-                    requests.append(request)
-                request = None
-            else:
-                self.request = request
-
-            if n >= len(data):
-                break
-            data = data[n:]
-
-        if requests:
-            self.requests = requests
-            self.server.add_task(self)
+        with self.requests_lock:
+            while data:
+                if self.request is None:
+                    self.request = self.parser_class(self.adj)
+                n = self.request.received(data)
+
+                # if there are requests queued, we can not send the continue
+                # header yet since the responses need to be kept in order
+                if (
+                    self.request.expect_continue
+                    and self.request.headers_finished
+                    and not self.requests
+                    and not self.sent_continue
+                ):
+                    self.send_continue()
+
+                if self.request.completed:
+                    # The request (with the body) is ready to use.
+                    self.sent_continue = False
+
+                    if not self.request.empty:
+                        self.requests.append(self.request)
+                        if len(self.requests) == 1:
+                            # self.requests was empty before so the main thread
+                            # is in charge of starting the task. Otherwise,
+                            # service() will add a new task after each request
+                            # has been processed
+                            self.server.add_task(self)
+                    self.request = None
+
+                if n >= len(data):
+                    break
+                data = data[n:]
         return True
@@ -360,88 +382,101 @@ class HTTPChannel(wasyncore.dispatcher):
             self.outbuf_lock.wait()
     def service(self):
-        """Execute all pending requests """
-        with self.task_lock:
-            while self.requests:
-                request = self.requests[0]
+        """Execute one request. If there are more, we add another task to the
+        server at the end."""
+
+        request = self.requests[0]
+
+        if request.error:
+            task = self.error_task_class(self, request)
+        else:
+            task = self.task_class(self, request)
-                if request.error:
-                    task = self.error_task_class(self, request)
+        try:
+            if self.connected:
+                task.service()
+            else:
+                task.close_on_finish = True
+        except ClientDisconnected:
+            self.logger.info("Client disconnected while serving %s" % task.request.path)
+            task.close_on_finish = True
+        except Exception:
+            self.logger.exception("Exception while serving %s" % task.request.path)
+
+            if not task.wrote_header:
+                if self.adj.expose_tracebacks:
+                    body = traceback.format_exc()
                 else:
-                    task = self.task_class(self, request)
+                    body = "The server encountered an unexpected internal server error"
+                req_version = request.version
+                req_headers = request.headers
+                err_request = self.parser_class(self.adj)
+                err_request.error = InternalServerError(body)
+                # copy some original request attributes to fulfill
+                # HTTP 1.1 requirements
+                err_request.version = req_version
                 try:
-                    task.service()
+                    err_request.headers["CONNECTION"] = req_headers["CONNECTION"]
+                except KeyError:
+                    pass
+                task = self.error_task_class(self, err_request)
+                try:
+                    task.service()  # must not fail
                 except ClientDisconnected:
-                    self.logger.info(
-                        "Client disconnected while serving %s" % task.request.path
-                    )
                     task.close_on_finish = True
-                except Exception:
-                    self.logger.exception(
-                        "Exception while serving %s" % task.request.path
-                    )
+            else:
+                task.close_on_finish = True
-                    if not task.wrote_header:
-                        if self.adj.expose_tracebacks:
-                            body = traceback.format_exc()
-                        else:
-                            body = (
-                                "The server encountered an unexpected "
-                                "internal server error"
-                            )
-                        req_version = request.version
-                        req_headers = request.headers
-                        request = self.parser_class(self.adj)
-                        request.error = InternalServerError(body)
-                        # copy some original request attributes to fulfill
-                        # HTTP 1.1 requirements
-                        request.version = req_version
-                        try:
-                            request.headers["CONNECTION"] = req_headers["CONNECTION"]
-                        except KeyError:
-                            pass
-                        task = self.error_task_class(self, request)
-                        try:
-                            task.service()  # must not fail
-                        except ClientDisconnected:
-                            task.close_on_finish = True
-                    else:
-                        task.close_on_finish = True
-                # we cannot allow self.requests to drop to empty til
-                # here; otherwise the mainloop gets confused
-
-                if task.close_on_finish:
-                    self.close_when_flushed = True
-
-                    for request in self.requests:
-                        request.close()
-                    self.requests = []
-                else:
-                    # before processing a new request, ensure there is not too
-                    # much data in the outbufs waiting to be flushed
-                    # NB: currently readable() returns False while we are
-                    # flushing data so we know no new requests will come in
-                    # that we need to account for, otherwise it'd be better
-                    # to do this check at the start of the request instead of
-                    # at the end to account for consecutive service() calls
-
-                    if len(self.requests) > 1:
-                        self._flush_outbufs_below_high_watermark()
-
-                    # this is a little hacky but basically it's forcing the
-                    # next request to create a new outbuf to avoid sharing
-                    # outbufs across requests which can cause outbufs to
-                    # not be deallocated regularly when a connection is open
-                    # for a long time
-
-                    if self.current_outbuf_count > 0:
-                        self.current_outbuf_count = self.adj.outbuf_high_watermark
-
-                    request = self.requests.pop(0)
+        if task.close_on_finish:
+            with self.requests_lock:
+                self.close_when_flushed = True
+
+                for request in self.requests:
                     request.close()
+                self.requests = []
+        else:
+            # before processing a new request, ensure there is not too
+            # much data in the outbufs waiting to be flushed
+            # NB: currently readable() returns False while we are
+            # flushing data so we know no new requests will come in
+            # that we need to account for, otherwise it'd be better
+            # to do this check at the start of the request instead of
+            # at the end to account for consecutive service() calls
+
+            if len(self.requests) > 1:
+                self._flush_outbufs_below_high_watermark()
+
+            # this is a little hacky but basically it's forcing the
+            # next request to create a new outbuf to avoid sharing
+            # outbufs across requests which can cause outbufs to
+            # not be deallocated regularly when a connection is open
+            # for a long time
+
+            if self.current_outbuf_count > 0:
+                self.current_outbuf_count = self.adj.outbuf_high_watermark
+
+            request.close()
+
+            # Add new task to process the next request
+            with self.requests_lock:
+                self.requests.pop(0)
+                if self.connected and self.requests:
+                    self.server.add_task(self)
+                elif (
+                    self.connected
+                    and self.request is not None
+                    and self.request.expect_continue
+                    and self.request.headers_finished
+                    and not self.sent_continue
+                ):
+                    # A request waits for a signal to continue, but we could
+                    # not send it until now because requests were being
+                    # processed and the output needs to be kept in order
+                    self.send_continue()
         if self.connected:
             self.server.pull_trigger()
+        self.last_activity = time.time()
     def cancel(self):
diff --git a/src/waitress/runner.py b/src/waitress/runner.py
index 4fb3e6b..c23ca0e 100644
--- a/src/waitress/runner.py
+++ b/src/waitress/runner.py
@@ -169,6 +169,12 @@ Tuning options:
         The use_poll argument passed to ``asyncore.loop()``. Helps overcome
         open file descriptors limit. Default is False.
+    --channel-request-lookahead=INT
+        Allows channels to stay readable and buffer more requests up to the
+        given maximum even if a request is already being processed. This allows
+        detecting if a client closed the connection while its request is being
+        processed. Default is 0.
+
 """
 RUNNER_PATTERN = re.compile(
diff --git a/src/waitress/task.py b/src/waitress/task.py
index 3a7cf17..2ac8f4c 100644
--- a/src/waitress/task.py
+++ b/src/waitress/task.py
@@ -560,6 +560,11 @@ class WSGITask(Task):
             if mykey not in environ:
                 environ[mykey] = value
+        # Insert a callable into the environment that allows the application to
+        # check if the client disconnected. Only works with
+        # channel_request_lookahead larger than 0.
+        environ["waitress.client_disconnected"] = self.channel.check_client_disconnected
+
         # cache the environ for this request
         self.environ = environ
         return environ
diff --git a/tests/test_channel.py b/tests/test_channel.py
index df3d450..d86dbbe 100644
--- a/tests/test_channel.py
+++ b/tests/test_channel.py
@@ -1,6 +1,8 @@
 import io
 import unittest
+import pytest
+
 class TestHTTPChannel(unittest.TestCase):
     def _makeOne(self, sock, addr, adj, map=None):
@@ -173,7 +175,7 @@ class TestHTTPChannel(unittest.TestCase):
     def test_readable_with_requests(self):
         inst, sock, map = self._makeOneWithMap()
-        inst.requests = True
+        inst.requests = [True]
         self.assertEqual(inst.readable(), False)
     def test_handle_read_no_error(self):
@@ -189,8 +191,6 @@ class TestHTTPChannel(unittest.TestCase):
         self.assertEqual(L, [b"abc"])
     def test_handle_read_error(self):
-        import socket
-
         inst, sock, map = self._makeOneWithMap()
         inst.will_close = False
@@ -439,7 +439,7 @@ class TestHTTPChannel(unittest.TestCase):
         preq.completed = False
         preq.empty = True
         inst.received(b"GET / HTTP/1.1\r\n\r\n")
-        self.assertEqual(inst.requests, ())
+        self.assertEqual(inst.requests, [])
         self.assertEqual(inst.server.tasks, [])
     def test_received_preq_completed_empty(self):
@@ -525,14 +525,6 @@ class TestHTTPChannel(unittest.TestCase):
         self.assertEqual(inst.sent_continue, True)
         self.assertEqual(preq.completed, False)
-    def test_service_no_requests(self):
-        inst, sock, map = self._makeOneWithMap()
-        inst.requests = []
-        inst.service()
-        self.assertEqual(inst.requests, [])
-        self.assertTrue(inst.server.trigger_pulled)
-        self.assertTrue(inst.last_activity)
-
     def test_service_with_one_request(self):
         inst, sock, map = self._makeOneWithMap()
         request = DummyRequest()
@@ -561,6 +553,7 @@ class TestHTTPChannel(unittest.TestCase):
         inst.task_class = DummyTaskClass()
         inst.requests = [request1, request2]
         inst.service()
+        inst.service()
         self.assertEqual(inst.requests, [])
         self.assertTrue(request1.serviced)
         self.assertTrue(request2.serviced)
@@ -705,6 +698,137 @@ class TestHTTPChannel(unittest.TestCase):
         self.assertEqual(inst.requests, [])
+class TestHTTPChannelLookahead(TestHTTPChannel):
+    def app_check_disconnect(self, environ, start_response):
+        """
+        Application that checks for client disconnection every
+        second for up to two seconds.
+        """
+        import time
+
+        if hasattr(self, "app_started"):
+            self.app_started.set()
+
+        try:
+            request_body_size = int(environ.get("CONTENT_LENGTH", 0))
+        except ValueError:
+            request_body_size = 0
+        self.request_body = environ["wsgi.input"].read(request_body_size)
+
+        self.disconnect_detected = False
+        check = environ["waitress.client_disconnected"]
+        if environ["PATH_INFO"] == "/sleep":
+            for i in range(3):
+                if i != 0:
+                    time.sleep(1)
+                if check():
+                    self.disconnect_detected = True
+                    break
+
+        body = b"finished"
+        cl = str(len(body))
+        start_response(
+            "200 OK", [("Content-Length", cl), ("Content-Type", "text/plain")]
+        )
+        return [body]
+
+    def _make_app_with_lookahead(self):
+        """
+        Setup a channel with lookahead and store it and the socket in self
+        """
+        adj = DummyAdjustments()
+        adj.channel_request_lookahead = 5
+        channel, sock, map = self._makeOneWithMap(adj=adj)
+        channel.server.application = self.app_check_disconnect
+
+        self.channel = channel
+        self.sock = sock
+
+    def _send(self, *lines):
+        """
+        Send lines through the socket with correct line endings
+        """
+        self.sock.send("".join(line + "\r\n" for line in lines).encode("ascii"))
+
+    def test_client_disconnect(self, close_before_start=False):
+        """Disconnect the socket after starting the task."""
+        import threading
+
+        self._make_app_with_lookahead()
+        self._send(
+            "GET /sleep HTTP/1.1",
+            "Host: localhost:8080",
+            "",
+        )
+        self.assertTrue(self.channel.readable())
+        self.channel.handle_read()
+        self.assertEqual(len(self.channel.server.tasks), 1)
+        self.app_started = threading.Event()
+        self.disconnect_detected = False
+        thread = threading.Thread(target=self.channel.server.tasks[0].service)
+
+        if not close_before_start:
+            thread.start()
+            self.assertTrue(self.app_started.wait(timeout=5))
+
+        # Close the socket, check that the channel is still readable due to the
+        # lookahead and read it, which marks the channel as closed.
+        self.sock.close()
+        self.assertTrue(self.channel.readable())
+        self.channel.handle_read()
+
+        if close_before_start:
+            thread.start()
+
+        thread.join()
+
+        if close_before_start:
+            self.assertFalse(self.app_started.is_set())
+        else:
+            self.assertTrue(self.disconnect_detected)
+
+    def test_client_disconnect_immediate(self):
+        """
+        The same test, but this time we close the socket even before processing
+        started. The app should not be executed.
+        """
+        self.test_client_disconnect(close_before_start=True)
+
+    def test_lookahead_continue(self):
+        """
+        Send two requests to a channel with lookahead and use an
+        expect-continue on the second one, making sure the responses still come
+        in the right order.
+        """
+        self._make_app_with_lookahead()
+        self._send(
+            "POST / HTTP/1.1",
+            "Host: localhost:8080",
+            "Content-Length: 1",
+            "",
+            "x",
+            "POST / HTTP/1.1",
+            "Host: localhost:8080",
+            "Content-Length: 1",
+            "Expect: 100-continue",
+            "",
+        )
+        self.channel.handle_read()
+        self.assertEqual(len(self.channel.requests), 1)
+        self.channel.server.tasks[0].service()
+        data = self.sock.recv(256).decode("ascii")
+        self.assertTrue(data.endswith("HTTP/1.1 100 Continue\r\n\r\n"))
+
+        self.sock.send(b"x")
+        self.channel.handle_read()
+        self.assertEqual(len(self.channel.requests), 1)
+        self.channel.server.tasks[0].service()
+        self.channel._flush_some()
+        data = self.sock.recv(256).decode("ascii")
+        self.assertEqual(data.split("\r\n")[-1], "finished")
+        self.assertEqual(self.request_body, b"x")
+
+
 class DummySock:
     blocking = False
     closed = False
@@ -731,6 +855,11 @@ class DummySock:
         self.sent += data
         return len(data)
+    def recv(self, buffer_size):
+        result = self.sent[:buffer_size]
+        self.sent = self.sent[buffer_size:]
+        return result
+
 class DummyLock:
     notified = False
@@ -796,11 +925,16 @@ class DummyAdjustments:
     expose_tracebacks = True
     ident = "waitress"
     max_request_header_size = 10000
+    url_prefix = ""
+    channel_request_lookahead = 0
+    max_request_body_size = 1048576
 class DummyServer:
     trigger_pulled = False
     adj = DummyAdjustments()
+    effective_port = 8080
+    server_name = ""
     def __init__(self):
         self.tasks = []
diff --git a/tests/test_task.py b/tests/test_task.py
index de800fb..ea71e02 100644
--- a/tests/test_task.py
+++ b/tests/test_task.py
@@ -802,6 +802,7 @@ class TestWSGITask(unittest.TestCase):
                 "SERVER_PORT",
                 "SERVER_PROTOCOL",
                 "SERVER_SOFTWARE",
+                "waitress.client_disconnected",
                 "wsgi.errors",
                 "wsgi.file_wrapper",
                 "wsgi.input",
@@ -958,6 +959,10 @@ class DummyChannel:
     creation_time = 0
     addr = ("127.0.0.1", 39830)
+    def check_client_disconnected(self):
+        # For now, until we have tests handling this feature
+        return False
+
     def __init__(self, server=None):
         if server is None:
             server = DummyServer()