diff options
author | Michael Merickel <michael@merickel.org> | 2019-03-26 00:29:14 -0500 |
---|---|---|
committer | Michael Merickel <michael@merickel.org> | 2019-03-26 00:40:54 -0500 |
commit | 29fb97ddc15632f53c8201820f33a1793d63928d (patch) | |
tree | b80583e5c213a36a17df3937ab7b8ae5f068a1d5 | |
parent | 8111dd66188ffa225fd13ece31f6f9fcb34495ba (diff) | |
download | waitress-29fb97ddc15632f53c8201820f33a1793d63928d.tar.gz |
interrupt the app_iter if it tries to write to a closed socket
-rw-r--r-- | CHANGES.txt | 11 | ||||
-rw-r--r-- | waitress/channel.py | 13 | ||||
-rw-r--r-- | waitress/tests/test_channel.py | 31 |
3 files changed, 55 insertions, 0 deletions
diff --git a/CHANGES.txt b/CHANGES.txt index 4815499..3521864 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -1,3 +1,14 @@ +unreleased +---------- + +Bugfixes +~~~~~~~~ + +- Stop early and close the ``app_iter`` when attempting to write to a closed + socket due to a client disconnect. This should notify a long-lived streaming + response when a client hangs up. + See https://github.com/Pylons/waitress/pull/238 + 1.2.1 (2019-01-25) ------------------ diff --git a/waitress/channel.py b/waitress/channel.py index 7ed3461..c625518 100644 --- a/waitress/channel.py +++ b/waitress/channel.py @@ -32,6 +32,9 @@ from waitress.utilities import InternalServerError from . import wasyncore +class ClientDisconnected(Exception): + """ Raised when attempting to write to a closed socket.""" + class HTTPChannel(wasyncore.dispatcher, object): """ Setting self.requests = [somerequest] prevents more requests from being @@ -305,6 +308,12 @@ class HTTPChannel(wasyncore.dispatcher, object): # def write_soon(self, data): + if not self.connected: + # if the socket is closed then interrupt the task so that it + # can cleanup possibly before the app_iter is exhausted + # XXX a simple boolean check should not require a lock since it's + # unidirectional (True -> False) + raise ClientDisconnected if data: # the async mainloop might be popping data off outbuf; we can # block here waiting for it because we're in a task thread @@ -334,6 +343,10 @@ class HTTPChannel(wasyncore.dispatcher, object): task = self.task_class(self, request) try: task.service() + except ClientDisconnected: + self.logger.warn('Client disconnected when serving %s' % + task.request.path) + task.close_on_finish = True except: self.logger.exception('Exception when serving %s' % task.request.path) diff --git a/waitress/tests/test_channel.py b/waitress/tests/test_channel.py index afe6e51..d550f87 100644 --- a/waitress/tests/test_channel.py +++ b/waitress/tests/test_channel.py @@ -225,6 +225,12 @@ class TestHTTPChannel(unittest.TestCase): self.assertEqual(outbufs[1], wrapper) self.assertEqual(outbufs[2].__class__.__name__, 'OverflowableBuffer') + def test_write_soon_disconnected(self): + from waitress.channel import ClientDisconnected + inst, sock, map = self._makeOneWithMap() + inst.connected = False + self.assertRaises(ClientDisconnected, lambda: inst.write_soon(b'stuff')) + def test__flush_some_empty_outbuf(self): inst, sock, map = self._makeOneWithMap() result = inst._flush_some() @@ -558,6 +564,27 @@ class TestHTTPChannel(unittest.TestCase): self.assertTrue(inst.close_when_flushed) self.assertTrue(request.closed) + def test_service_with_request_raises_disconnect(self): + from waitress.channel import ClientDisconnected + + inst, sock, map = self._makeOneWithMap() + inst.adj.expose_tracebacks = False + inst.server = DummyServer() + request = DummyRequest() + inst.requests = [request] + inst.task_class = DummyTaskClass(ClientDisconnected) + inst.error_task_class = DummyTaskClass() + inst.logger = DummyLogger() + inst.service() + self.assertTrue(request.serviced) + self.assertEqual(inst.requests, []) + self.assertEqual(len(inst.logger.warnings), 1) + self.assertTrue(inst.force_flush) + self.assertTrue(inst.last_activity) + self.assertFalse(inst.will_close) + self.assertEqual(inst.error_task_class.serviced, False) + self.assertTrue(request.closed) + def test_cancel_no_requests(self): inst, sock, map = self._makeOneWithMap() inst.requests = () @@ -699,6 +726,10 @@ class DummyLogger(object): def __init__(self): self.exceptions = [] + self.warnings = [] + + def warn(self, msg): + self.warnings.append(msg) def exception(self, msg): self.exceptions.append(msg) |