diff options
author | Chris McDonough <chrism@plope.com> | 2012-01-02 04:30:50 -0500 |
---|---|---|
committer | Chris McDonough <chrism@plope.com> | 2012-01-02 04:30:50 -0500 |
commit | 0475bf347945223636c512c59dc370637be2d51c (patch) | |
tree | 97aec045e0d31eec0475dda9aa490c35ed357406 | |
parent | c4c93020af966f282449d75c332e67c148e41dc5 (diff) | |
download | waitress-0475bf347945223636c512c59dc370637be2d51c.tar.gz |
Features
~~~~~~~~
- Dont hang a thread up trying to send data to slow clients.
- Use self.logger to log socket errors instead of self.log_info (normalize).
- Remove pointless handle_error method from channel.
- Queue requests instead of tasks in a channel.
Bug Fixes
~~~~~~~~~
- Expect: 100-continue responses were broken.
-rw-r--r-- | CHANGES.txt | 20 | ||||
-rw-r--r-- | TODO.txt | 6 | ||||
-rw-r--r-- | waitress/__init__.py | 2 | ||||
-rw-r--r-- | waitress/adjustments.py | 4 | ||||
-rw-r--r-- | waitress/buffers.py | 4 | ||||
-rw-r--r-- | waitress/channel.py | 338 | ||||
-rw-r--r-- | waitress/server.py | 2 | ||||
-rw-r--r-- | waitress/task.py | 11 | ||||
-rw-r--r-- | waitress/tests/test_channel.py | 404 | ||||
-rw-r--r-- | waitress/tests/test_server.py | 2 | ||||
-rw-r--r-- | waitress/tests/test_task.py | 2 |
11 files changed, 425 insertions, 370 deletions
diff --git a/CHANGES.txt b/CHANGES.txt index 30b58ad..30762da 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -1,3 +1,23 @@ +Next release +----------- + +Features +~~~~~~~~ + +- Dont hang a thread up trying to send data to slow clients. + +- Use self.logger to log socket errors instead of self.log_info (normalize). + +- Remove pointless handle_error method from channel. + +- Queue requests instead of tasks in a channel. + +Bug Fixes +~~~~~~~~~ + +- Expect: 100-continue responses were broken. + + 0.2 (2011-12-31) ---------------- @@ -2,10 +2,6 @@ - Speed tweaking. -- Waitress version number in header. - -- Test paste server runner on windows. - - Anticipate multivalue and single-value-only headers in request headers in parser.py. @@ -13,7 +9,5 @@ - Complex pipelining functests (with intermediate connection: close). -- Get rid of overflowable inbuf. - - Killthreads support. diff --git a/waitress/__init__.py b/waitress/__init__.py index c501329..9dff87e 100644 --- a/waitress/__init__.py +++ b/waitress/__init__.py @@ -3,7 +3,7 @@ import logging def serve(app, **kw): _server = kw.pop('_server', WSGIServer) # test shim - _quiet = kw.pop('_quiet', False) + _quiet = kw.pop('_quiet', False) # test shim if not _quiet: # pragma: no cover # idempotent if logging has already been set up logging.basicConfig() diff --git a/waitress/adjustments.py b/waitress/adjustments.py index 808a0e9..53560b0 100644 --- a/waitress/adjustments.py +++ b/waitress/adjustments.py @@ -78,8 +78,8 @@ class Adjustments(object): expose_tracebacks = False # The socket options to set on receiving a connection. It is a list of - # (level, optname, value) tuples. TCP_NODELAY is probably good for Zope, - # since Zope buffers data itself. + # (level, optname, value) tuples. TCP_NODELAY disables the Nagle + # algorithm for writes (Waitress already buffers its writes). socket_options = [ (socket.SOL_TCP, socket.TCP_NODELAY, 1), ] diff --git a/waitress/buffers.py b/waitress/buffers.py index 6e8eab2..92891f5 100644 --- a/waitress/buffers.py +++ b/waitress/buffers.py @@ -13,9 +13,10 @@ ############################################################################## """Buffers """ - from io import BytesIO +from waitress.compat import thread + # copy_bytes controls the size of temp. strings for shuffling data around. COPY_BYTES = 1 << 18 # 256K @@ -149,6 +150,7 @@ class OverflowableBuffer(object): def __init__(self, overflow): # overflow is the maximum to be stored in a StringIO buffer. self.overflow = overflow + self.lock = thread.allocate_lock() # API def __len__(self): buf = self.buf diff --git a/waitress/channel.py b/waitress/channel.py index 2a629b5..3bdfa17 100644 --- a/waitress/channel.py +++ b/waitress/channel.py @@ -15,14 +15,14 @@ """ import asyncore import socket -import sys import time import traceback -from waitress.compat import thread from waitress.buffers import OverflowableBuffer from waitress.parser import HTTPRequestParser +from waitress.compat import thread + from waitress.task import ( ErrorTask, WSGITask, @@ -36,22 +36,23 @@ from waitress.utilities import ( class HTTPChannel(logging_dispatcher, object): """Channel that switches between asynchronous and synchronous mode. - Set self.task = sometask before using a channel in a thread other than - the thread handling the main loop. + Set self.requests = [somerequest] before using a channel in a thread other + than the thread handling the main loop. - Set self.task = None to give the channel back to the thread handling + Set self.requests = [] to give the channel back to the thread handling the main loop. """ task_class = WSGITask error_task_class = ErrorTask parser_class = HTTPRequestParser - task_lock = thread.allocate_lock() # syncs access to task-related attrs - - request = None # A request parser instance - last_activity = 0 # Time of last activity - will_close = False # will_close is set to True to close the socket. - task = None # currently running task + request = None # A request parser instance + last_activity = 0 # Time of last activity + will_close = False # set to True to close the socket. + requests = () # currently pending requests + sent_continue = False # used as a latch after sending 100 continue + task_lock = thread.allocate_lock() # lock used to push/pop requests + force_flush = False # indicates a need to flush the outbuf # # ASYNCHRONOUS METHODS (including __init__) @@ -69,43 +70,147 @@ class HTTPChannel(logging_dispatcher, object): self.addr = addr self.adj = adj self.outbuf = OverflowableBuffer(adj.outbuf_overflow) - self.inbuf = OverflowableBuffer(adj.inbuf_overflow) self.creation_time = self.last_activity = time.time() asyncore.dispatcher.__init__(self, sock, map=map) def writable(self): - if self.task is not None: - return False - return self.will_close or self.outbuf + return bool(self.outbuf) def handle_write(self): - if self.task is not None: + # Precondition: there's data in the out buffer to be sent + if not self.connected: return - if self.outbuf: + + if not self.requests: + # 1. There are no running tasks, so we don't need to try to lock + # the outbuf before sending + # 2. The data in the out buffer should be sent as soon as possible + # because it's either data left over from task output + # or a 100 Continue line sent within "received". + flush = self._flush_some + elif self.force_flush: + # 1. There's a running task, so we need to try to lock + # the outbuf before sending + # 2. This is the last chunk sent by the Nth of M tasks in a + # sequence on this channel, so flush it regardless of whether + # it's >= self.adj.send_bytes. We need to do this now, or it + # won't get done. + flush = self._flush_some_if_lockable + self.force_flush = False + elif (len(self.outbuf) >= self.adj.send_bytes): + # 1. There's a running task, so we need to try to lock + # the outbuf before sending + # 2. Only try to send if the data in the out buffer is larger + # than self.adj_bytes to avoid TCP fragmentation + flush = self._flush_some_if_lockable + self.force_flush = False + else: + # 1. There's not enough data in the out buffer to bother to send + # right now. + flush = None + + if flush: try: - self._flush_some() + flush() except socket.error: - self.handle_comm_error() - elif self.will_close: + if self.adj.log_socket_errors: + self.logger.exception('Socket error') + self.will_close = True + if self.will_close: self.handle_close() - self.last_activity = time.time() def readable(self): - if self.task is not None or self.will_close: - return False - if self.inbuf: - self.received() - return not self.will_close + # We might want to create a new task. We can only do this if: + # 1. We're not already about to close the connection. + # 2. There's no already currently running task(s). + # 3. There's no data in the output buffer that needs to be sent + # before we potentially create a new task. + return not (self.will_close or self.requests or self.outbuf) def handle_read(self): try: data = self.recv(self.adj.recv_bytes) except socket.error: - self.handle_comm_error() + if self.adj.log_socket_errors: + self.logger.exception('Socket error') + self.handle_close() return - self.last_activity = time.time() if data: - self.inbuf.append(data) + self.last_activity = time.time() + self.received(data) + + def received(self, data): + """ + Receives input asynchronously and assigns a task to the channel. + """ + # Preconditions: there's no task(s) already running + request = self.request + requests = [] + + if not data: + return False + + while data: + if request is None: + request = self.parser_class(self.adj) + n = request.received(data) + if request.expect_continue and request.headers_finished: + # guaranteed by parser to be a 1.1 request + request.expect_continue = False + if not self.sent_continue: + # there's no current task, so we don't need to try to + # lock the outbuf to append to it. + self.outbuf.append(b'HTTP/1.1 100 Continue\r\n\r\n') + self.sent_expect_continue = True + self._flush_some() + request.completed = False + if request.completed: + # The request (with the body) is ready to use. + self.request = None + if not request.empty: + requests.append(request) + request = None + else: + self.request = request + if n >= len(data): + break + data = data[n:] + + if requests: + self.requests = requests + self.server.add_task(self) + + return True + + def _flush_some_if_lockable(self): + # Since our task may be appending to the outbuf, we try to acquire + # the lock, but we don't block if we can't. + outbuf = self.outbuf + locked = outbuf.lock.acquire(0) + if locked: + try: + self._flush_some() + finally: + outbuf.lock.release() + + def _flush_some(self): + # Send as much data as possible to our client + outbuf = self.outbuf + outbuflen = len(outbuf) + while outbuflen > 0: + chunk = outbuf.get(self.adj.send_bytes) + num_sent = self.send(chunk) + if num_sent: + outbuf.skip(num_sent, True) + outbuflen -= num_sent + else: + return False + self.last_activity = time.time() + return True + + def handle_close(self): + self.connected = False + asyncore.dispatcher.close(self) def add_channel(self, map=None): """See asyncore.dispatcher @@ -126,149 +231,62 @@ class HTTPChannel(logging_dispatcher, object): if fd in ac: del ac[fd] - def received(self): - """ - Receives input asynchronously and assigns a task to the channel. - """ - if self.task is not None: - return False - chunk = self.inbuf.get(self.adj.recv_bytes) - if not chunk: - return False - if self.request is None: - self.request = self.parser_class(self.adj) - request = self.request - n = request.received(chunk) - if n: - self.inbuf.skip(n, True) - if request.expect_continue and request.headers_finished: - # guaranteed by parser to be a 1.1 request - self.write(b'HTTP/1.1 100 Continue\r\n\r\n') - request.expect_continue = False - if request.completed: - # The request (with the body) is ready to use. - if request.connection_close and self.inbuf: - self.inbuf = OverflowableBuffer(self.adj.inbuf_overflow) - self.request = None - if not request.empty: - if request.error: - self.inbuf = OverflowableBuffer(self.adj.inbuf_overflow) - task = self.error_task_class(self, request) - else: - task = self.task_class(self, request) - self.task = task - self.server.add_task(self) - return - if self.inbuf: - self.server.pull_trigger() - - def handle_error(self, exc_info=None): # exc_info for tests - """See async.dispatcher - - Handles program errors (not communication errors) - """ - if exc_info is None: # pragma: no cover - t, v = sys.exc_info()[:2] - else: - t, v = exc_info[:2] - if t is SystemExit or t is KeyboardInterrupt: - raise t(v) - asyncore.dispatcher.handle_error(self) - - def handle_comm_error(self): - """ - Handles communication errors (not program errors) - """ - if self.adj.log_socket_errors: - # handle_error calls close - self.handle_error() - else: - # Ignore socket errors. - self.handle_close() - # - # METHODS USED IN BOTH MODES + # SYNCHRONOUS METHODS # - def handle_close(self): - # Always close in asynchronous mode. If the connection is - # closed in a thread, the main loop can end up with a bad file - # descriptor. - # XXX this method is probably called in a thread by virtue of - # "handle_comm_error" - if self.task is not None: - self.will_close = True - return - self.connected = False - asyncore.dispatcher.close(self) - - def write(self, data): - wrote = 0 + def write_soon(self, data): if data: - self.outbuf.append(data) - wrote = len(data) - - while len(self.outbuf) >= self.adj.send_bytes: - # Send what we can without blocking. - # We propagate errors to the application on purpose - # (to stop the application if the connection closes). - if not self._flush_some(): # pragma: no cover (coverage bug?) - break - - return wrote - - def _flush_some(self): - """Flushes data. - - Returns True if some data was sent.""" - outbuf = self.outbuf - if outbuf and self.connected: - chunk = outbuf.get(self.adj.send_bytes) - num_sent = self.send(chunk) - if num_sent: - outbuf.skip(num_sent, True) - return True - return False - - # - # ITask implementation. Delegates to the queued tasks. - # + # the async mainloop might be popping data off outbuf; we can + # block here waiting for it because we're in a thread + with self.outbuf.lock: + self.outbuf.append(data) + return len(data) + return 0 def service(self): """Execute a pending task""" - if self.task is None: - return - task = self.task - try: - task.service() - except: - self.logger.exception('Exception when serving %s' % - task.request.uri) - if not task.wrote_header: - if self.adj.expose_tracebacks: - body = traceback.format_exc() + with self.task_lock: + while self.requests: + request = self.requests[0] + if request.error: + task = self.error_task_class(self, request) else: - body = ('The server encountered an unexpected internal ' - 'server error') - request = self.parser_class(self.adj) - request.error = InternalServerError(body) - task = self.error_task_class(self, request) - task.service() # must not fail - else: - task.close_on_finish = True - while self._flush_some(): - pass - self.task = None - if task.close_on_finish: - self.will_close = True + task = self.task_class(self, request) + try: + task.service() + except: + self.logger.exception('Exception when serving %s' % + task.request.uri) + if not task.wrote_header: + if self.adj.expose_tracebacks: + body = traceback.format_exc() + else: + body = ('The server encountered an unexpected ' + 'internal server error') + request = self.parser_class(self.adj) + request.error = InternalServerError(body) + task = self.error_task_class(self, request) + task.service() # must not fail + else: + task.close_on_finish = True + # we cannot allow self.requests to drop to empty til + # here; otherwise the mainloop gets confused + if task.close_on_finish: + self.will_close = True + self.requests = [] + else: + self.requests.pop(0) + + self.force_flush = True self.server.pull_trigger() self.last_activity = time.time() def cancel(self): - """Cancels all pending tasks""" - if self.task is not None: - self.task.cancel() - self.task = None + """ Cancels all pending requests """ + self.force_flush = True + self.last_activity = time.time() + self.requests = [] def defer(self): pass diff --git a/waitress/server.py b/waitress/server.py index 7b9ec95..80a0459 100644 --- a/waitress/server.py +++ b/waitress/server.py @@ -141,6 +141,6 @@ class WSGIServer(logging_dispatcher, object): """ cutoff = now - self.adj.channel_timeout for channel in self.active_channels.values(): - if (channel.task is None) and channel.last_activity < cutoff: + if (not channel.requests) and channel.last_activity < cutoff: channel.will_close = True diff --git a/waitress/task.py b/waitress/task.py index 0dad088..1fa57ea 100644 --- a/waitress/task.py +++ b/waitress/task.py @@ -266,7 +266,7 @@ class Task(object): self.write(b'') if self.chunked_response: # not self.write, it will chunk it! - self.channel.write(b'0\r\n\r\n') + self.channel.write_soon(b'0\r\n\r\n') def write(self, data): if not self.complete: @@ -275,7 +275,7 @@ class Task(object): channel = self.channel if not self.wrote_header: rh = self.build_response_header() - channel.write(rh) + channel.write_soon(rh) self.wrote_header = True if data: towrite = data @@ -293,7 +293,7 @@ class Task(object): 'bytes specified by Content-Length header (%s)' % cl) self.logged_write_excess = True if towrite: - channel.write(towrite) + self.channel.write_soon(towrite) class ErrorTask(Task): """ An error task produces an error response """ @@ -392,9 +392,8 @@ class WSGITask(Task): self.content_length = first_chunk_len # transmit headers only after first iteration of the iterable # that returns a non-empty bytestring (PEP 3333) - if not chunk: - continue - self.write(chunk) + if chunk: + self.write(chunk) cl = self.content_length if cl != -1: diff --git a/waitress/tests/test_channel.py b/waitress/tests/test_channel.py index e664302..26b209c 100644 --- a/waitress/tests/test_channel.py +++ b/waitress/tests/test_channel.py @@ -19,36 +19,33 @@ class TestHTTPChannel(unittest.TestCase): self.assertEqual(inst.addr, '127.0.0.1') self.assertEqual(map[100], inst) - def test_writable_no_task_will_close(self): + def test_writable_something_in_outbuf(self): inst, sock, map = self._makeOneWithMap() - inst.task = None - inst.will_close = True - inst.outbuf = '' + inst.outbuf = 'abc' self.assertTrue(inst.writable()) - def test_writable_no_task_outbuf(self): + def test_writable_nothing_in_outbuf(self): inst, sock, map = self._makeOneWithMap() - inst.task = None - inst.will_close = False - inst.outbuf ='a' - self.assertTrue(inst.writable()) + inst.outbuf = '' + self.assertFalse(inst.writable()) - def test_writable_with_task(self): + def test_handle_write_not_connected(self): inst, sock, map = self._makeOneWithMap() - inst.task = True - self.assertFalse(inst.writable()) + inst.outbuf = '' + inst.connected = False + self.assertFalse(inst.handle_write()) - def test_handle_write_with_task(self): + def test_handle_write_with_requests(self): inst, sock, map = self._makeOneWithMap() - inst.task = True + inst.requests = True inst.last_activity = 0 result = inst.handle_write() self.assertEqual(result, None) self.assertEqual(inst.last_activity, 0) - def test_handle_write_no_task_with_outbuf(self): + def test_handle_write_no_request_with_outbuf(self): inst, sock, map = self._makeOneWithMap() - inst.task = None + inst.requests = [] inst.outbuf = DummyBuffer(b'abc') inst.last_activity = 0 result = inst.handle_write() @@ -56,24 +53,23 @@ class TestHTTPChannel(unittest.TestCase): self.assertNotEqual(inst.last_activity, 0) self.assertEqual(sock.sent, b'abc') - def test_handle_write_no_task_with_outbuf_raises_socketerror(self): + def test_handle_write_outbuf_raises_socketerror(self): import socket inst, sock, map = self._makeOneWithMap() - inst.task = None - L = [] - inst.log_info = lambda *x: L.append(x) + inst.requests = [] inst.outbuf = DummyBuffer(b'abc', socket.error) inst.last_activity = 0 + inst.logger = DummyLogger() result = inst.handle_write() self.assertEqual(result, None) - self.assertNotEqual(inst.last_activity, 0) + self.assertEqual(inst.last_activity, 0) self.assertEqual(sock.sent, b'') - self.assertEqual(len(L), 1) + self.assertEqual(len(inst.logger.exceptions), 1) - def test_handle_write_no_task_no_outbuf_will_close(self): + def test_handle_write_no_requests_no_outbuf_will_close(self): inst, sock, map = self._makeOneWithMap() - inst.task = None - inst.outbuf = None + inst.requests = [] + inst.outbuf = DummyBuffer(b'') inst.will_close = True inst.last_activity = 0 result = inst.handle_write() @@ -82,41 +78,61 @@ class TestHTTPChannel(unittest.TestCase): self.assertEqual(sock.closed, True) self.assertNotEqual(inst.last_activity, 0) - def test_readable_no_task_not_will_close(self): + def test_handle_write_no_requests_force_flush(self): + inst, sock, map = self._makeOneWithMap() + inst.requests = [True] + inst.outbuf = DummyBuffer(b'abc') + inst.will_close = False + inst.force_flush = True + inst.last_activity = 0 + result = inst.handle_write() + self.assertEqual(result, None) + self.assertEqual(inst.will_close, False) + self.assertTrue(inst.outbuf.lock.acquired) + self.assertEqual(inst.force_flush, False) + self.assertEqual(sock.sent, b'abc') + + def test_handle_write_no_requests_outbuf_gt_send_bytes(self): + inst, sock, map = self._makeOneWithMap() + inst.requests = [True] + inst.outbuf = DummyBuffer(b'abc') + inst.adj.send_bytes = 2 + inst.will_close = False + inst.last_activity = 0 + result = inst.handle_write() + self.assertEqual(result, None) + self.assertEqual(inst.will_close, False) + self.assertTrue(inst.outbuf.lock.acquired) + self.assertEqual(sock.sent, b'abc') + + def test_readable_no_requests_not_will_close(self): inst, sock, map = self._makeOneWithMap() - inst.task = None + inst.requests = [] inst.will_close = False self.assertEqual(inst.readable(), True) - def test_readable_no_task_will_close(self): + def test_readable_no_requests_will_close(self): inst, sock, map = self._makeOneWithMap() - inst.task = None + inst.requests = [] inst.will_close = True self.assertEqual(inst.readable(), False) - def test_readable_with_task(self): + def test_readable_with_requests(self): inst, sock, map = self._makeOneWithMap() - inst.task = True + inst.requests = True self.assertEqual(inst.readable(), False) - def test_readable_with_nonempty_inbuf_calls_received(self): - inst, sock, map = self._makeOneWithMap() - L = [] - inst.received = lambda: L.append(True) - inst.task = None - inst.inbuf = True - self.assertEqual(inst.readable(), True) - self.assertEqual(len(L), 1) - def test_handle_read_no_error(self): inst, sock, map = self._makeOneWithMap() inst.will_close = False inst.recv = lambda *arg: b'abc' inst.last_activity = 0 + L = [] + inst.received = lambda x: L.append(x) result = inst.handle_read() self.assertEqual(result, None) self.assertNotEqual(inst.last_activity, 0) - self.assertEqual(inst.inbuf.get(100), b'abc') + self.assertEqual(L, [b'abc']) def test_handle_read_error(self): import socket @@ -124,59 +140,32 @@ class TestHTTPChannel(unittest.TestCase): inst.will_close = False def recv(b): raise socket.error inst.recv = recv - L = [] - inst.log_info = lambda *x: L.append(x) inst.last_activity = 0 + inst.logger = DummyLogger() result = inst.handle_read() self.assertEqual(result, None) self.assertEqual(inst.last_activity, 0) - self.assertEqual(len(L), 1) - self.assertEqual(inst.inbuf.get(100), b'') + self.assertEqual(len(inst.logger.exceptions), 1) - def test_write_empty_byte(self): + def test_write_soon_empty_byte(self): inst, sock, map = self._makeOneWithMap() - wrote = inst.write(b'') + wrote = inst.write_soon(b'') self.assertEqual(wrote, 0) + self.assertEqual(len(inst.outbuf), 0) - def test_write_nonempty_byte(self): + def test_write_soon_nonempty_byte(self): inst, sock, map = self._makeOneWithMap() - wrote = inst.write(b'a') + wrote = inst.write_soon(b'a') self.assertEqual(wrote, 1) - - def test_write_outbuf_gt_send_bytes_has_data(self): - from waitress.adjustments import Adjustments - class DummyAdj(Adjustments): - send_bytes = 10 - inst, sock, map = self._makeOneWithMap(adj=DummyAdj) - wrote = inst.write(b'x' * 1024) - self.assertEqual(wrote, 1024) - - def test_write_outbuf_gt_send_bytes_no_data(self): - from waitress.adjustments import Adjustments - class DummyAdj(Adjustments): - send_bytes = 10 - inst, sock, map = self._makeOneWithMap(adj=DummyAdj) - inst.outbuf.append(b'x' * 20) - self.connected = False - wrote = inst.write(b'') - self.assertEqual(wrote, 0) - - def test__flush_some_notconnected(self): - inst, sock, map = self._makeOneWithMap() - inst.outbuf = b'123' - inst.connected = False - result = inst._flush_some() - self.assertEqual(result, False) + self.assertEqual(len(inst.outbuf), 1) def test__flush_some_empty_outbuf(self): inst, sock, map = self._makeOneWithMap() - inst.connected = True result = inst._flush_some() - self.assertEqual(result, False) + self.assertEqual(result, True) def test__flush_some_full_outbuf_socket_returns_nonzero(self): inst, sock, map = self._makeOneWithMap() - inst.connected = True inst.outbuf.append(b'abc') result = inst._flush_some() self.assertEqual(result, True) @@ -184,24 +173,16 @@ class TestHTTPChannel(unittest.TestCase): def test__flush_some_full_outbuf_socket_returns_zero(self): inst, sock, map = self._makeOneWithMap() sock.send = lambda x: False - inst.connected = True inst.outbuf.append(b'abc') result = inst._flush_some() self.assertEqual(result, False) - def test_handle_close_no_task(self): + def test_handle_close(self): inst, sock, map = self._makeOneWithMap() - inst.task = None inst.handle_close() self.assertEqual(inst.connected, False) self.assertEqual(sock.closed, True) - def test_handle_close_with_task(self): - inst, sock, map = self._makeOneWithMap() - inst.task = True - inst.handle_close() - self.assertEqual(inst.will_close, True) - def test_add_channel(self): inst, sock, map = self._makeOneWithMap() fileno = inst._fileno @@ -220,19 +201,13 @@ class TestHTTPChannel(unittest.TestCase): def test_received(self): inst, sock, map = self._makeOneWithMap() inst.server = DummyServer() - inst.inbuf.append(b'GET / HTTP/1.1\n\n') - inst.received() + inst.received(b'GET / HTTP/1.1\n\n') self.assertEqual(inst.server.tasks, [inst]) - self.assertTrue(inst.task) - - def test_received_with_task(self): - inst, sock, map = self._makeOneWithMap() - inst.task = True - self.assertEqual(inst.received(), False) + self.assertTrue(inst.requests) def test_received_no_chunk(self): inst, sock, map = self._makeOneWithMap() - self.assertEqual(inst.received(), False) + self.assertEqual(inst.received(b''), False) def test_received_preq_not_completed(self): inst, sock, map = self._makeOneWithMap() @@ -241,9 +216,8 @@ class TestHTTPChannel(unittest.TestCase): inst.request = preq preq.completed = False preq.empty = True - inst.inbuf.append(b'GET / HTTP/1.1\n\n') - inst.received() - self.assertEqual(inst.task, None) + inst.received(b'GET / HTTP/1.1\n\n') + self.assertEqual(inst.requests, ()) self.assertEqual(inst.server.tasks, []) def test_received_preq_completed_empty(self): @@ -253,8 +227,7 @@ class TestHTTPChannel(unittest.TestCase): inst.request = preq preq.completed = True preq.empty = True - inst.inbuf.append(b'GET / HTTP/1.1\n\n') - inst.received() + inst.received(b'GET / HTTP/1.1\n\n') self.assertEqual(inst.request, None) self.assertEqual(inst.server.tasks, []) @@ -265,11 +238,10 @@ class TestHTTPChannel(unittest.TestCase): inst.request = preq preq.completed = True preq.error = True - inst.inbuf.append(b'GET / HTTP/1.1\n\n') - inst.received() + inst.received(b'GET / HTTP/1.1\n\n') self.assertEqual(inst.request, None) self.assertEqual(len(inst.server.tasks), 1) - self.assertTrue(inst.task) + self.assertTrue(inst.requests) def test_received_preq_completed_connection_close(self): inst, sock, map = self._makeOneWithMap() @@ -279,12 +251,9 @@ class TestHTTPChannel(unittest.TestCase): preq.completed = True preq.empty = True preq.connection_close = True - inbuf = inst.inbuf - inbuf.append(b'GET / HTTP/1.1\n\n' + b'a' * 50000) - inst.received() + inst.received(b'GET / HTTP/1.1\n\n' + b'a' * 50000) self.assertEqual(inst.request, None) self.assertEqual(inst.server.tasks, []) - self.assertNotEqual(inst.inbuf, inbuf) def test_received_preq_completed_n_lt_data(self): inst, sock, map = self._makeOneWithMap() @@ -292,13 +261,13 @@ class TestHTTPChannel(unittest.TestCase): preq = DummyParser() inst.request = preq preq.completed = True - preq.empty = True - preq.retval = 1 - inst.inbuf.append(b'GET / HTTP/1.1\n\n') - inst.received() + preq.empty = False + line = b'GET / HTTP/1.1\n\n' + preq.retval = len(line) + inst.received(line + line) self.assertEqual(inst.request, None) - self.assertEqual(inst.task, None) - self.assertEqual(inst.server.tasks, []) + self.assertEqual(len(inst.requests), 2) + self.assertEqual(len(inst.server.tasks), 1) def test_received_headers_finished_not_expect_continue(self): inst, sock, map = self._makeOneWithMap() @@ -310,8 +279,7 @@ class TestHTTPChannel(unittest.TestCase): preq.completed = False preq.empty = False preq.retval = 1 - inst.inbuf.append(b'GET / HTTP/1.1\n\n') - inst.received() + inst.received(b'GET / HTTP/1.1\n\n') self.assertEqual(inst.request, preq) self.assertEqual(inst.server.tasks, []) self.assertEqual(inst.outbuf.get(100), b'') @@ -325,111 +293,136 @@ class TestHTTPChannel(unittest.TestCase): preq.headers_finished = True preq.completed = False preq.empty = False - preq.retval = 1 - inst.inbuf.append(b'GET / HTTP/1.1\n\n') - inst.received() + inst.received(b'GET / HTTP/1.1\n\n') self.assertEqual(inst.request, preq) self.assertEqual(inst.server.tasks, []) - self.assertEqual(inst.outbuf.get(100), b'HTTP/1.1 100 Continue\r\n\r\n') - - def test_handle_error_reraises_SystemExit(self): - inst, sock, map = self._makeOneWithMap() - self.assertRaises(SystemExit, - inst.handle_error, (SystemExit, None, None)) + self.assertEqual(sock.sent, b'HTTP/1.1 100 Continue\r\n\r\n') + self.assertEqual(inst.sent_expect_continue, True) + self.assertEqual(preq.completed, False) - def test_handle_error_reraises_KeyboardInterrupt(self): + def test_service_no_requests(self): inst, sock, map = self._makeOneWithMap() - self.assertRaises(KeyboardInterrupt, - inst.handle_error, (KeyboardInterrupt, None, None)) - - def test_handle_error_noreraise(self): - inst, sock, map = self._makeOneWithMap() - # compact_traceback throws an AssertionError without a traceback - self.assertRaises(AssertionError, inst.handle_error, - (ValueError, ValueError('a'), None)) + inst.requests = [] + inst.service() + self.assertEqual(inst.requests, []) + self.assertTrue(inst.force_flush) + self.assertTrue(inst.last_activity) - def test_handle_comm_error_log(self): + def test_service_with_one_request(self): inst, sock, map = self._makeOneWithMap() - inst.adj.log_socket_errors = True - # compact_traceback throws an AssertionError without a traceback - self.assertRaises(AssertionError, inst.handle_comm_error) + request = DummyRequest() + inst.task_class = DummyTaskClass() + inst.requests = [request] + inst.service() + self.assertEqual(inst.requests, []) + self.assertTrue(request.serviced) - def test_handle_comm_error_no(self): + def test_service_with_one_error_request(self): inst, sock, map = self._makeOneWithMap() - inst.adj.log_socket_errors = False - inst.handle_comm_error() - self.assertEqual(inst.connected, False) - self.assertEqual(sock.closed, True) + request = DummyRequest() + request.error = DummyError() + inst.error_task_class = DummyTaskClass() + inst.requests = [request] + inst.service() + self.assertEqual(inst.requests, []) + self.assertTrue(request.serviced) - def test_service_no_task(self): + def test_service_with_multiple_requests(self): inst, sock, map = self._makeOneWithMap() - inst.task = None + request1 = DummyRequest() + request2 = DummyRequest() + inst.task_class = DummyTaskClass() + inst.requests = [request1, request2] inst.service() - self.assertEqual(inst.task, None) + self.assertEqual(inst.requests, []) + self.assertTrue(request1.serviced) + self.assertTrue(request2.serviced) - def test_service_with_task(self): + def test_service_with_request_raises(self): inst, sock, map = self._makeOneWithMap() - task = DummyTask() - inst.task = task + inst.adj.expose_tracebacks = False + inst.server = DummyServer() + request = DummyRequest() + inst.requests = [request] + inst.task_class = DummyTaskClass(ValueError) + inst.task_class.wrote_header = False + inst.error_task_class = DummyTaskClass() + inst.logger = DummyLogger() inst.service() - self.assertEqual(inst.task, None) - self.assertTrue(task.serviced) + self.assertTrue(request.serviced) + self.assertEqual(inst.requests, []) + self.assertEqual(len(inst.logger.exceptions), 1) + self.assertTrue(inst.force_flush) + self.assertTrue(inst.last_activity) + self.assertFalse(inst.will_close) + self.assertEqual(inst.error_task_class.serviced, True) - def test_service_with_task_raises(self): + def test_service_with_requests_raises_already_wrote_header(self): inst, sock, map = self._makeOneWithMap() inst.adj.expose_tracebacks = False inst.server = DummyServer() - task = DummyTask(ValueError) - task.wrote_header = False - inst.task = task + request = DummyRequest() + inst.requests = [request] + inst.task_class = DummyTaskClass(ValueError) + inst.error_task_class = DummyTaskClass() inst.logger = DummyLogger() inst.service() - self.assertTrue(task.serviced) - self.assertEqual(inst.task, None) + self.assertTrue(request.serviced) + self.assertEqual(inst.requests, []) self.assertEqual(len(inst.logger.exceptions), 1) - self.assertTrue(sock.sent) + self.assertTrue(inst.force_flush) + self.assertTrue(inst.last_activity) + self.assertTrue(inst.will_close) + self.assertEqual(inst.error_task_class.serviced, False) - def test_service_with_task_raises_with_expose_tbs(self): + def test_service_with_requests_raises_didnt_write_header_expose_tbs(self): inst, sock, map = self._makeOneWithMap() inst.adj.expose_tracebacks = True inst.server = DummyServer() - task = DummyTask(ValueError) - task.wrote_header = False - inst.task = task + request = DummyRequest() + inst.requests = [request] + inst.task_class = DummyTaskClass(ValueError) + inst.task_class.wrote_header = False + inst.error_task_class = DummyTaskClass() inst.logger = DummyLogger() inst.service() - self.assertTrue(task.serviced) - self.assertEqual(inst.task, None) + self.assertTrue(request.serviced) + self.assertFalse(inst.will_close) + self.assertEqual(inst.requests, []) self.assertEqual(len(inst.logger.exceptions), 1) - self.assertTrue(sock.sent) + self.assertTrue(inst.force_flush) + self.assertTrue(inst.last_activity) + self.assertEqual(inst.error_task_class.serviced, True) - def test_service_with_task_raises_already_wrote_header(self): + def test_service_with_requests_raises_didnt_write_header(self): inst, sock, map = self._makeOneWithMap() inst.adj.expose_tracebacks = False inst.server = DummyServer() - task = DummyTask(ValueError) - task.wrote_header = True - inst.task = task + request = DummyRequest() + inst.requests = [request] + inst.task_class = DummyTaskClass(ValueError) + inst.task_class.wrote_header = False inst.logger = DummyLogger() inst.service() - self.assertTrue(task.serviced) - self.assertEqual(inst.task, None) + self.assertTrue(request.serviced) + self.assertTrue(inst.will_close) + self.assertEqual(inst.requests, []) self.assertEqual(len(inst.logger.exceptions), 1) - self.assertFalse(sock.sent) + self.assertTrue(inst.force_flush) + self.assertTrue(inst.last_activity) + self.assertTrue(inst.will_close) - def test_cancel_no_task(self): + def test_cancel_no_requests(self): inst, sock, map = self._makeOneWithMap() - inst.task = None + inst.requests = () inst.cancel() - self.assertEqual(inst.task, None) + self.assertEqual(inst.requests, []) - def test_cancel_with_task(self): + def test_cancel_with_requests(self): inst, sock, map = self._makeOneWithMap() - task = DummyTask() - inst.task = task + inst.requests = [None] inst.cancel() - self.assertEqual(inst.task, None) - self.assertEqual(task.cancelled, True) + self.assertEqual(inst.requests, []) def test_defer(self): inst, sock, map = self._makeOneWithMap() @@ -452,10 +445,21 @@ class DummySock(object): self.sent += data return len(data) +class DummyLock(object): + def __init__(self, acquirable=True): + self.acquirable = acquirable + def acquire(self, val): + self.val = val + self.acquired = True + return self.acquirable + def release(self): + self.released = True + class DummyBuffer(object): def __init__(self, data, toraise=None): self.data = data self.toraise = toraise + self.lock = DummyLock() def get(self, *arg): if self.toraise: @@ -467,6 +471,9 @@ class DummyBuffer(object): def skip(self, num, x): self.skipped = num + def __len__(self): + return len(self.data) + class DummyAdjustments(object): outbuf_overflow = 1048576 inbuf_overflow = 512000 @@ -478,6 +485,7 @@ class DummyAdjustments(object): recv_bytes = 8192 expose_tracebacks = True ident = 'waitress' + max_request_header_size = 10000 class DummyServer(object): trigger_pulled = False @@ -507,25 +515,39 @@ class DummyParser(object): return len(data) class DummyRequest(object): + error = None uri = 'http://example.com' + version = '1.0' + def __init__(self): + self.headers = {} -class DummyTask(object): - serviced = False - cancelled = False +class DummyLogger(object): + def __init__(self): + self.exceptions = [] + def exception(self, msg): + self.exceptions.append(msg) + +class DummyError(object): + code = '431' + reason = 'Bleh' + body = 'My body' + +class DummyTaskClass(object): + wrote_header = True close_on_finish = False - request = DummyRequest() - wrote_header = False + serviced = False + def __init__(self, toraise=None): self.toraise = toraise + + def __call__(self, channel, request): + self.request = request + return self + def service(self): self.serviced = True + self.request.serviced = True if self.toraise: raise self.toraise - def cancel(self): - self.cancelled = True -class DummyLogger(object): - def __init__(self): - self.exceptions = [] - def exception(self, msg): - self.exceptions.append(msg) + diff --git a/waitress/tests/test_server.py b/waitress/tests/test_server.py index f5081cf..8676740 100644 --- a/waitress/tests/test_server.py +++ b/waitress/tests/test_server.py @@ -171,7 +171,7 @@ class TestWSGIServer(unittest.TestCase): def test_maintenance(self): inst = self._makeOneWithMap() class DummyChannel(object): - task = None + requests = [] zombie = DummyChannel() zombie.last_activity = 0 zombie.running_tasks = False diff --git a/waitress/tests/test_task.py b/waitress/tests/test_task.py index f5fb95c..6b1d63d 100644 --- a/waitress/tests/test_task.py +++ b/waitress/tests/test_task.py @@ -601,7 +601,7 @@ class DummyChannel(object): server = DummyServer() self.server = server self.written = b'' - def write(self, data): + def write_soon(self, data): self.written += data return len(data) |