diff options
-rw-r--r-- | src/waitress/channel.py | 46 | ||||
-rw-r--r-- | tests/test_channel.py | 95 |
2 files changed, 122 insertions, 19 deletions
diff --git a/src/waitress/channel.py b/src/waitress/channel.py index 7d1f385..948b498 100644 --- a/src/waitress/channel.py +++ b/src/waitress/channel.py @@ -78,6 +78,7 @@ class HTTPChannel(wasyncore.dispatcher): may occasionally check if the client has disconnected and interrupt execution. """ + return not self.connected def writable(self): @@ -116,23 +117,30 @@ class HTTPChannel(wasyncore.dispatcher): # right now. flush = None + self._flush_exception(flush) + + if self.close_when_flushed and not self.total_outbufs_len: + self.close_when_flushed = False + self.will_close = True + + if self.will_close: + self.handle_close() + + def _flush_exception(self, flush): if flush: try: - flush() + return (flush(), False) except OSError: if self.adj.log_socket_errors: self.logger.exception("Socket error") self.will_close = True + + return (False, True) except Exception: # pragma: nocover self.logger.exception("Unexpected exception when flushing") self.will_close = True - if self.close_when_flushed and not self.total_outbufs_len: - self.close_when_flushed = False - self.will_close = True - - if self.will_close: - self.handle_close() + return (False, True) def readable(self): # We might want to read more requests. We can only do this if: @@ -190,6 +198,7 @@ class HTTPChannel(wasyncore.dispatcher): Receives input asynchronously and assigns one or more requests to the channel. """ + if not data: return False @@ -201,6 +210,7 @@ class HTTPChannel(wasyncore.dispatcher): # if there are requests queued, we can not send the continue # header yet since the responses need to be kept in order + if ( self.request.expect_continue and self.request.headers_finished @@ -215,6 +225,7 @@ class HTTPChannel(wasyncore.dispatcher): if not self.request.empty: self.requests.append(self.request) + if len(self.requests) == 1: # self.requests was empty before so the main thread # is in charge of starting the task. Otherwise, @@ -363,7 +374,14 @@ class HTTPChannel(wasyncore.dispatcher): self.total_outbufs_len += num_bytes if self.total_outbufs_len >= self.adj.send_bytes: - self.server.pull_trigger() + (flushed, exception) = self._flush_exception(self._flush_some) + + if ( + exception + or not flushed + or self.total_outbufs_len >= self.adj.send_bytes + ): + self.server.pull_trigger() return num_bytes @@ -374,6 +392,17 @@ class HTTPChannel(wasyncore.dispatcher): if self.total_outbufs_len > self.adj.outbuf_high_watermark: with self.outbuf_lock: + (_, exception) = self._flush_exception(self._flush_some) + + if exception: + # An exception happened while flushing, wake up the main + # thread, then wait for it to decide what to do next + # (probably close the socket, and then just return) + self.server.pull_trigger() + self.outbuf_lock.wait() + + return + while ( self.connected and self.total_outbufs_len > self.adj.outbuf_high_watermark @@ -460,6 +489,7 @@ class HTTPChannel(wasyncore.dispatcher): # Add new task to process the next request with self.requests_lock: self.requests.pop(0) + if self.connected and self.requests: self.server.add_task(self) elif ( diff --git a/tests/test_channel.py b/tests/test_channel.py index d86dbbe..b1c317d 100644 --- a/tests/test_channel.py +++ b/tests/test_channel.py @@ -213,6 +213,13 @@ class TestHTTPChannel(unittest.TestCase): def test_write_soon_nonempty_byte(self): inst, sock, map = self._makeOneWithMap() + + # _flush_some will no longer flush + def send(_): + return 0 + + sock.send = send + wrote = inst.write_soon(b"a") self.assertEqual(wrote, 1) self.assertEqual(len(inst.outbufs[0]), 1) @@ -224,14 +231,19 @@ class TestHTTPChannel(unittest.TestCase): wrapper = ReadOnlyFileBasedBuffer(f, 8192) wrapper.prepare() inst, sock, map = self._makeOneWithMap() + + # _flush_some will no longer flush + def send(_): + return 0 + + sock.send = send + outbufs = inst.outbufs - orig_outbuf = outbufs[0] wrote = inst.write_soon(wrapper) self.assertEqual(wrote, 3) - self.assertEqual(len(outbufs), 3) - self.assertEqual(outbufs[0], orig_outbuf) - self.assertEqual(outbufs[1], wrapper) - self.assertEqual(outbufs[2].__class__.__name__, "OverflowableBuffer") + self.assertEqual(len(outbufs), 2) + self.assertEqual(outbufs[0], wrapper) + self.assertEqual(outbufs[1].__class__.__name__, "OverflowableBuffer") def test_write_soon_disconnected(self): from waitress.channel import ClientDisconnected @@ -253,16 +265,29 @@ class TestHTTPChannel(unittest.TestCase): def test_write_soon_rotates_outbuf_on_overflow(self): inst, sock, map = self._makeOneWithMap() + + # _flush_some will no longer flush + def send(_): + return 0 + + sock.send = send + inst.adj.outbuf_high_watermark = 3 inst.current_outbuf_count = 4 wrote = inst.write_soon(b"xyz") self.assertEqual(wrote, 3) - self.assertEqual(len(inst.outbufs), 2) - self.assertEqual(inst.outbufs[0].get(), b"") - self.assertEqual(inst.outbufs[1].get(), b"xyz") + self.assertEqual(len(inst.outbufs), 1) + self.assertEqual(inst.outbufs[0].get(), b"xyz") def test_write_soon_waits_on_backpressure(self): inst, sock, map = self._makeOneWithMap() + + # _flush_some will no longer flush + def send(_): + return 0 + + sock.send = send + inst.adj.outbuf_high_watermark = 3 inst.total_outbufs_len = 4 inst.current_outbuf_count = 4 @@ -275,11 +300,59 @@ class TestHTTPChannel(unittest.TestCase): inst.outbuf_lock = Lock() wrote = inst.write_soon(b"xyz") self.assertEqual(wrote, 3) - self.assertEqual(len(inst.outbufs), 2) - self.assertEqual(inst.outbufs[0].get(), b"") - self.assertEqual(inst.outbufs[1].get(), b"xyz") + self.assertEqual(len(inst.outbufs), 1) + self.assertEqual(inst.outbufs[0].get(), b"xyz") self.assertTrue(inst.outbuf_lock.waited) + def test_write_soon_attempts_flush_high_water_and_exception(self): + from waitress.channel import ClientDisconnected + + inst, sock, map = self._makeOneWithMap() + + # _flush_some will no longer flush, it will raise Exception, which + # disconnects the remote end + def send(_): + inst.connected = False + raise Exception() + + sock.send = send + + inst.adj.outbuf_high_watermark = 3 + inst.total_outbufs_len = 4 + inst.current_outbuf_count = 4 + + inst.outbufs[0].append(b"test") + + class Lock(DummyLock): + def wait(self): + inst.total_outbufs_len = 0 + super().wait() + + inst.outbuf_lock = Lock() + self.assertRaises(ClientDisconnected, lambda: inst.write_soon(b"xyz")) + + # Validate we woke up the main thread to deal with the exception of + # trying to send + self.assertTrue(inst.outbuf_lock.waited) + self.assertTrue(inst.server.trigger_pulled) + + def test_write_soon_flush_and_exception(self): + inst, sock, map = self._makeOneWithMap() + + # _flush_some will no longer flush, it will raise Exception, which + # disconnects the remote end + def send(_): + inst.connected = False + raise Exception() + + sock.send = send + + wrote = inst.write_soon(b"xyz") + self.assertEqual(wrote, 3) + # Validate we woke up the main thread to deal with the exception of + # trying to send + self.assertTrue(inst.server.trigger_pulled) + def test_handle_write_notify_after_flush(self): inst, sock, map = self._makeOneWithMap() inst.requests = [True] |