diff options
author | Michael Merickel <michael@merickel.org> | 2019-04-07 16:28:46 -0500 |
---|---|---|
committer | Michael Merickel <michael@merickel.org> | 2019-04-10 02:42:37 -0500 |
commit | 0b1bd775cb3b7def226b6e25a2ece1e3d8629129 (patch) | |
tree | e5001a97b1f8463e123ff28566fc1e078d9de26f | |
parent | ab4ff97f6aad5b93188cc73a7e01ce4ea6be8df9 (diff) | |
download | waitress-0b1bd775cb3b7def226b6e25a2ece1e3d8629129.tar.gz |
maintain our own byte buffer for data that didn't get written [seekable]
-rw-r--r-- | waitress/buffers.py | 271 | ||||
-rw-r--r-- | waitress/channel.py | 104 | ||||
-rw-r--r-- | waitress/receiver.py | 4 | ||||
-rw-r--r-- | waitress/task.py | 2 | ||||
-rw-r--r-- | waitress/tests/test_buffers.py | 489 | ||||
-rw-r--r-- | waitress/tests/test_channel.py | 76 | ||||
-rw-r--r-- | waitress/tests/test_functional.py | 1 |
7 files changed, 394 insertions, 553 deletions
diff --git a/waitress/buffers.py b/waitress/buffers.py index cacc094..42d5751 100644 --- a/waitress/buffers.py +++ b/waitress/buffers.py @@ -14,6 +14,7 @@ """Buffers """ from io import BytesIO +from tempfile import TemporaryFile # copy_bytes controls the size of temp. strings for shuffling data around. COPY_BYTES = 1 << 18 # 256K @@ -22,157 +23,110 @@ COPY_BYTES = 1 << 18 # 256K STRBUF_LIMIT = 8192 class FileBasedBuffer(object): + seekable = True + remaining = 0 # -1 would indicate an infinite stream - remain = 0 + def __bool__(self): + return self.remaining != 0 - def __init__(self, file, from_buffer=None): - self.file = file - if from_buffer is not None: - from_file = from_buffer.getfile() - read_pos = from_file.tell() - from_file.seek(0) - while True: - data = from_file.read(COPY_BYTES) - if not data: - break - file.write(data) - self.remain = int(file.tell() - read_pos) - from_file.seek(read_pos) - file.seek(read_pos) - - def __len__(self): - return self.remain - - def __nonzero__(self): - return True - - __bool__ = __nonzero__ # py3 + __nonzero = __bool__ # py2 def append(self, s): + assert self.seekable file = self.file read_pos = file.tell() file.seek(0, 2) file.write(s) file.seek(read_pos) - self.remain = self.remain + len(s) + self.remaining += len(s) - def get(self, numbytes=-1, skip=False): + def read(self, numbytes=-1): file = self.file - if not skip: - read_pos = file.tell() + remaining = self.remaining + if remaining != -1 and numbytes > remaining: + numbytes = remaining if numbytes < 0: # Read all res = file.read() else: res = file.read(numbytes) - if skip: - self.remain -= len(res) + numres = len(res) + if remaining == -1: + # keep remaining at -1 until EOF + if not numres and numbytes != 0: + self.remaining = 0 else: - file.seek(read_pos) + self.remaining -= numres return res - def skip(self, numbytes, allow_prune=0): - if self.remain < numbytes: - raise ValueError("Can't skip %d bytes in buffer of %d bytes" % ( - numbytes, self.remain) - ) - 
self.file.seek(numbytes, 1) - self.remain = self.remain - numbytes - - def newfile(self): - raise NotImplementedError() - - def prune(self): - file = self.file - if self.remain == 0: - read_pos = file.tell() - file.seek(0, 2) - sz = file.tell() - file.seek(read_pos) - if sz == 0: - # Nothing to prune. - return - nf = self.newfile() - while True: - data = file.read(COPY_BYTES) - if not data: - break - nf.write(data) - self.file = nf - - def getfile(self): - return self.file + def rollback(self, numbytes): + assert self.seekable + self.file.seek(-numbytes, 1) + self.remaining += numbytes def close(self): + self.remaining = 0 if hasattr(self.file, 'close'): self.file.close() - self.remain = 0 class TempfileBasedBuffer(FileBasedBuffer): def __init__(self, from_buffer=None): - FileBasedBuffer.__init__(self, self.newfile(), from_buffer) - - def newfile(self): - from tempfile import TemporaryFile - return TemporaryFile('w+b') + file = TemporaryFile('w+b') + if from_buffer is not None: + while True: + data = from_buffer.read(COPY_BYTES) + if not data: + break + file.write(data) + self.remaining += len(data) + file.seek(0) + self.file = file class BytesIOBasedBuffer(FileBasedBuffer): - def __init__(self, from_buffer=None): - if from_buffer is not None: - FileBasedBuffer.__init__(self, BytesIO(), from_buffer) - else: - # Shortcut. 
:-) - self.file = BytesIO() + def __init__(self, value=None): + self.file = BytesIO(value) + if value is not None: + self.remaining = len(value) - def newfile(self): - return BytesIO() +def _is_seekable(fp): + if hasattr(fp, 'seekable'): + return fp.seekable() + return hasattr(fp, 'seek') and hasattr(fp, 'tell') class ReadOnlyFileBasedBuffer(FileBasedBuffer): # used as wsgi.file_wrapper + remaining = -1 def __init__(self, file, block_size=32768): self.file = file self.block_size = block_size # for __iter__ + self.seekable = _is_seekable(file) def prepare(self, size=None): - if hasattr(self.file, 'seek') and hasattr(self.file, 'tell'): + if self.seekable: start_pos = self.file.tell() self.file.seek(0, 2) end_pos = self.file.tell() self.file.seek(start_pos) fsize = end_pos - start_pos if size is None: - self.remain = fsize + self.remaining = fsize else: - self.remain = min(fsize, size) - return self.remain - - def get(self, numbytes=-1, skip=False): - # never read more than self.remain (it can be user-specified) - if numbytes == -1 or numbytes > self.remain: - numbytes = self.remain - file = self.file - if not skip: - read_pos = file.tell() - res = file.read(numbytes) - if skip: - self.remain -= len(res) - else: - file.seek(read_pos) - return res + self.remaining = min(fsize, size) + return self.remaining def __iter__(self): # called by task if self.filelike has no seek/tell return self - def next(self): + def __next__(self): val = self.file.read(self.block_size) if not val: raise StopIteration return val - __next__ = next # py3 + next = __next__ # py2 def append(self, s): raise NotImplementedError @@ -187,112 +141,71 @@ class OverflowableBuffer(object): The first two stages are fastest for simple transfers. """ + seekable = True + remaining = 0 + overflowed = False buf = None strbuf = b'' # Bytes-based buffer. def __init__(self, overflow): - # overflow is the maximum to be stored in a StringIO buffer. + # overflow is the maximum to be stored in a BytesIO buffer. 
self.overflow = overflow - def __len__(self): - buf = self.buf - if buf is not None: - # use buf.__len__ rather than len(buf) FBO of not getting - # OverflowError on Python 2 - return buf.__len__() - else: - return self.strbuf.__len__() - - def __nonzero__(self): - # use self.__len__ rather than len(self) FBO of not getting - # OverflowError on Python 2 - return self.__len__() > 0 - - __bool__ = __nonzero__ # py3 + def __bool__(self): + return self.remaining != 0 - def _create_buffer(self): - strbuf = self.strbuf - if len(strbuf) >= self.overflow: - self._set_large_buffer() - else: - self._set_small_buffer() - buf = self.buf - if strbuf: - buf.append(self.strbuf) - self.strbuf = b'' - return buf - - def _set_small_buffer(self): - self.buf = BytesIOBasedBuffer(self.buf) - self.overflowed = False - - def _set_large_buffer(self): - self.buf = TempfileBasedBuffer(self.buf) - self.overflowed = True + __nonzero = __bool__ # py2 def append(self, s): buf = self.buf if buf is None: - strbuf = self.strbuf + strbuf = self.strbuf if self.remaining else b'' if len(strbuf) + len(s) < STRBUF_LIMIT: self.strbuf = strbuf + s + self.remaining += len(s) return - buf = self._create_buffer() - buf.append(s) - # use buf.__len__ rather than len(buf) FBO of not getting - # OverflowError on Python 2 - sz = buf.__len__() - if not self.overflowed: - if sz >= self.overflow: - self._set_large_buffer() - - def get(self, numbytes=-1, skip=False): - buf = self.buf - if buf is None: - strbuf = self.strbuf - if not skip: - return strbuf - buf = self._create_buffer() - return buf.get(numbytes, skip) - - def skip(self, numbytes, allow_prune=False): - buf = self.buf - if buf is None: - if allow_prune and numbytes == len(self.strbuf): - # We could slice instead of converting to - # a buffer, but that would eat up memory in - # large transfers. 
+ else: + buf = BytesIOBasedBuffer(self.strbuf + s) + self.buf = buf self.strbuf = b'' - return - buf = self._create_buffer() - buf.skip(numbytes, allow_prune) - - def prune(self): - """ - A potentially expensive operation that removes all data - already retrieved from the buffer. - """ + else: + buf.append(s) + remaining = buf.remaining + self.remaining = remaining + if not self.overflowed and remaining > self.overflow: + self.buf = TempfileBasedBuffer(buf) + self.overflowed = True + + def read(self, numbytes=-1): buf = self.buf if buf is None: - self.strbuf = b'' - return - buf.prune() - if self.overflowed: - # use buf.__len__ rather than len(buf) FBO of not getting - # OverflowError on Python 2 - sz = buf.__len__() - if sz < self.overflow: - # Revert to a faster buffer. - self._set_small_buffer() - - def getfile(self): + if self.remaining <= numbytes or numbytes == -1: + self.remaining = 0 + return self.strbuf + buf = self.buf = BytesIOBasedBuffer(self.strbuf) + data = buf.read(numbytes) + self.remaining = buf.remaining + return data + + def rollback(self, numbytes): buf = self.buf if buf is None: - buf = self._create_buffer() - return buf.getfile() + self.strbuf = self.strbuf[-numbytes:] + self.remaining = len(self.strbuf) + return + buf.rollback(numbytes) + self.remaining = buf.remaining def close(self): + self.remaining = 0 + self.strbuf = b'' buf = self.buf if buf is not None: buf.close() + + def getfile(self): + buf = self.buf + if buf is None: + buf = self.buf = BytesIOBasedBuffer(self.strbuf) + return buf.file diff --git a/waitress/channel.py b/waitress/channel.py index 5bacaa0..c53e4fe 100644 --- a/waitress/channel.py +++ b/waitress/channel.py @@ -18,6 +18,7 @@ import time import traceback from waitress.buffers import ( + BytesIOBasedBuffer, OverflowableBuffer, ReadOnlyFileBasedBuffer, ) @@ -54,7 +55,9 @@ class HTTPChannel(wasyncore.dispatcher, object): close_when_flushed = False # set to True to close the socket when flushed requests = () # currently 
pending requests sent_continue = False # used as a latch after sending 100 continue - total_outbufs_len = 0 # total bytes ready to send + known_outbufs_len = 0 # total known bytes ready to send + has_unseekable_outbufs = False # any unseekable data to send + has_outbuf_data = False # any data to write including unseekable current_outbuf_count = 0 # total bytes written to current outbuf # @@ -90,7 +93,7 @@ class HTTPChannel(wasyncore.dispatcher, object): # if there's data in the out buffer or we've been instructed to close # the channel (possibly by our server maintenance logic), run # handle_write - return self.total_outbufs_len or self.will_close + return self.has_outbuf_data or self.will_close def handle_write(self): # Precondition: there's data in the out buffer to be sent, or @@ -107,7 +110,10 @@ class HTTPChannel(wasyncore.dispatcher, object): # because it's either data left over from task output # or a 100 Continue line sent within "received". flush = self._flush_some - elif self.total_outbufs_len >= self.adj.send_bytes: + elif ( + self.known_outbufs_len >= self.adj.send_bytes + or self.has_unseekable_outbufs + ): # 1. There's a running task, so we need to try to lock # the outbuf before sending # 2. Only try to send if the data in the out buffer is larger @@ -129,7 +135,7 @@ class HTTPChannel(wasyncore.dispatcher, object): self.logger.exception('Unexpected exception when flushing') self.will_close = True - if self.close_when_flushed and not self.total_outbufs_len: + if self.close_when_flushed and not self.has_outbuf_data: self.close_when_flushed = False self.will_close = True @@ -142,7 +148,7 @@ class HTTPChannel(wasyncore.dispatcher, object): # 2. There's no already currently running task(s). # 3. There's no data in the output buffer that needs to be sent # before we potentially create a new task. 
- return not (self.will_close or self.requests or self.total_outbufs_len) + return not (self.will_close or self.requests or self.has_outbuf_data) def handle_read(self): try: @@ -181,7 +187,8 @@ class HTTPChannel(wasyncore.dispatcher, object): outbuf_payload = b'HTTP/1.1 100 Continue\r\n\r\n' self.outbufs[-1].append(outbuf_payload) self.current_outbuf_count += len(outbuf_payload) - self.total_outbufs_len += len(outbuf_payload) + self.known_outbufs_len += len(outbuf_payload) + self.has_outbuf_data = True self.sent_continue = True self._flush_some() request.completed = False @@ -210,7 +217,7 @@ class HTTPChannel(wasyncore.dispatcher, object): try: self._flush_some() - if self.total_outbufs_len < self.adj.outbuf_high_watermark: + if self.known_outbufs_len < self.adj.outbuf_high_watermark: self.outbuf_lock.notify() finally: self.outbuf_lock.release() @@ -220,28 +227,38 @@ class HTTPChannel(wasyncore.dispatcher, object): sent = 0 dobreak = False + outbufs = self.outbufs while True: - outbuf = self.outbufs[0] - # use outbuf.__len__ rather than len(outbuf) FBO of not getting - # OverflowError on 32-bit Python - outbuflen = outbuf.__len__() - while outbuflen > 0: - chunk = outbuf.get(self.sendbuf_len) + outbuf = outbufs[0] + # remaining might be -1 for an unseekable ROFBB + # so we perform a read and assume that the ROFBB will update + # remaining when it knows it's empty + while outbuf.remaining != 0: + chunk = outbuf.read(self.sendbuf_len) + num_tosend = len(chunk) num_sent = self.send(chunk) - if num_sent: - outbuf.skip(num_sent, True) - outbuflen -= num_sent - sent += num_sent - self.total_outbufs_len -= num_sent - else: + # handle_close may have been called by send() so be careful + # about mutating state below if num_sent is 0 + sent += num_sent + if num_sent < num_tosend and self.connected: + # failed to write all of the data, so either put the + # remaining amount into a new buffer to be used on the + # next write or rollback the pointer to only skip what was + # 
successfully written + if outbuf.seekable: + outbuf.rollback(num_tosend - num_sent) + else: + outbuf = BytesIOBasedBuffer(chunk[num_sent:]) + outbufs.appendleft(outbuf) + if not num_sent: # failed to write anything, break out entirely dobreak = True break else: # self.outbufs[-1] must always be a writable outbuf - if len(self.outbufs) > 1: - toclose = self.outbufs.popleft() + if len(outbufs) > 1: + toclose = outbufs.popleft() try: toclose.close() except Exception: @@ -254,25 +271,44 @@ class HTTPChannel(wasyncore.dispatcher, object): if dobreak: break + # refresh the outbuf statistics after a write + self._scan_outbufs() + if sent: self.last_activity = time.time() return True return False + def _scan_outbufs(self): + self.has_unseekable_outbufs = False + self.known_outbufs_len = 0 + for o in self.outbufs: + if o.seekable: + self.known_outbufs_len += o.remaining + else: + self.has_unseekable_outbufs = True + self.has_outbuf_data = ( + self.known_outbufs_len or self.has_unseekable_outbufs + ) + def handle_close(self): with self.outbuf_lock: - while self.outbufs: - outbuf = self.outbufs.popleft() + outbufs = self.outbufs + while outbufs: + toclose = outbufs.popleft() try: - outbuf.close() + toclose.close() except Exception: self.logger.exception( 'Unknown exception while trying to close outbuf') - self.total_outbufs_len = 0 + self.known_outbufs_len = 0 + self.has_outbuf_data = False + self.has_unseekable_outbufs = False + self.current_outbuf_count = 0 self.connected = False self.outbuf_lock.notify() - wasyncore.dispatcher.close(self) + self.close() def add_channel(self, map=None): """See wasyncore.dispatcher @@ -315,6 +351,11 @@ class HTTPChannel(wasyncore.dispatcher, object): nextbuf = OverflowableBuffer(self.adj.outbuf_overflow) self.outbufs.append(nextbuf) self.current_outbuf_count = 0 + num_bytes = data.remaining + if num_bytes == -1: + self.has_unseekable_outbufs = True + else: + self.known_outbufs_len += num_bytes else: if self.current_outbuf_count > 
self.adj.outbuf_high_watermark: # rotate to a new buffer if the current buffer has hit @@ -323,9 +364,10 @@ class HTTPChannel(wasyncore.dispatcher, object): self.outbufs.append(nextbuf) self.current_outbuf_count = 0 self.outbufs[-1].append(data) - num_bytes = len(data) - self.current_outbuf_count += num_bytes - self.total_outbufs_len += num_bytes + num_bytes = len(data) + self.current_outbuf_count += num_bytes + self.known_outbufs_len += num_bytes + self.has_outbuf_data = True # XXX We might eventually need to pull the trigger here (to # instruct select to stop blocking), but it slows things down so # much that I'll hold off for now; "server push" on otherwise @@ -335,11 +377,11 @@ class HTTPChannel(wasyncore.dispatcher, object): def _flush_outbufs_below_high_watermark(self): # check first to avoid locking if possible - if self.total_outbufs_len > self.adj.outbuf_high_watermark: + if self.known_outbufs_len > self.adj.outbuf_high_watermark: with self.outbuf_lock: while ( self.connected and - self.total_outbufs_len > self.adj.outbuf_high_watermark + self.known_outbufs_len > self.adj.outbuf_high_watermark ): self.outbuf_lock.wait() diff --git a/waitress/receiver.py b/waitress/receiver.py index 594ae97..3cc44fc 100644 --- a/waitress/receiver.py +++ b/waitress/receiver.py @@ -29,7 +29,7 @@ class FixedStreamReceiver(object): self.buf = buf def __len__(self): - return self.buf.__len__() + return self.buf.remaining def received(self, data): 'See IStreamConsumer' @@ -70,7 +70,7 @@ class ChunkedReceiver(object): self.buf = buf def __len__(self): - return self.buf.__len__() + return self.buf.remaining def received(self, s): # Returns the number of bytes consumed. 
diff --git a/waitress/task.py b/waitress/task.py index 3b7d332..c0d67ca 100644 --- a/waitress/task.py +++ b/waitress/task.py @@ -451,7 +451,7 @@ class WSGITask(Task): if app_iter.__class__ is ReadOnlyFileBasedBuffer: cl = self.content_length size = app_iter.prepare(cl) - if size: + if size > 0: if cl != size: if cl is not None: self.remove_content_length_header() diff --git a/waitress/tests/test_buffers.py b/waitress/tests/test_buffers.py index 46a215e..e614d22 100644 --- a/waitress/tests/test_buffers.py +++ b/waitress/tests/test_buffers.py @@ -1,221 +1,186 @@ import unittest import io -class TestFileBasedBuffer(unittest.TestCase): - - def _makeOne(self, file=None, from_buffer=None): - from waitress.buffers import FileBasedBuffer - return FileBasedBuffer(file, from_buffer=from_buffer) - - def test_ctor_from_buffer_None(self): - inst = self._makeOne('file') - self.assertEqual(inst.file, 'file') - - def test_ctor_from_buffer(self): - from_buffer = io.BytesIO(b'data') - from_buffer.getfile = lambda *x: from_buffer - f = io.BytesIO() - inst = self._makeOne(f, from_buffer) - self.assertEqual(inst.file, f) - del from_buffer.getfile - self.assertEqual(inst.remain, 4) - from_buffer.close() - - def test___len__(self): - inst = self._makeOne() - inst.remain = 10 - self.assertEqual(len(inst), 10) +class FileBasedBufferTests(object): - def test___nonzero__(self): - inst = self._makeOne() - inst.remain = 10 + def test_seekable(self): + inst = self._makeOneFromBytes() + self.assertTrue(inst.seekable) + self.assertEqual(inst.remaining, 0) + + def test___bool__(self): + inst = self._makeOneFromBytes() + inst.remaining = 10 self.assertEqual(bool(inst), True) - inst.remain = 0 + inst.remaining = 0 + self.assertEqual(bool(inst), False) + inst.remaining = -1 self.assertEqual(bool(inst), True) def test_append(self): - f = io.BytesIO(b'data') - inst = self._makeOne(f) + inst = self._makeOneFromBytes(b'data') inst.append(b'data2') - self.assertEqual(f.getvalue(), b'datadata2') - 
self.assertEqual(inst.remain, 5) - - def test_get_skip_true(self): - f = io.BytesIO(b'data') - inst = self._makeOne(f) - result = inst.get(100, skip=True) + self.assertEqual(inst.remaining, 9) + self.assertEqual(inst.read(), b'datadata2') + self.assertEqual(inst.remaining, 0) + + def test_read_zero(self): + inst = self._makeOneFromBytes(b'data') + result = inst.read(0) + self.assertEqual(result, b'') + self.assertEqual(inst.remaining, 4) + + def test_read_all(self): + inst = self._makeOneFromBytes(b'data') + result = inst.read() self.assertEqual(result, b'data') - self.assertEqual(inst.remain, -4) + self.assertEqual(inst.remaining, 0) - def test_get_skip_false(self): - f = io.BytesIO(b'data') - inst = self._makeOne(f) - result = inst.get(100, skip=False) - self.assertEqual(result, b'data') - self.assertEqual(inst.remain, 0) + def test_read_not_enough(self): + inst = self._makeOneFromBytes(b'data') + result = inst.read(3) + self.assertEqual(result, b'dat') + self.assertEqual(inst.remaining, 1) - def test_get_skip_bytes_less_than_zero(self): - f = io.BytesIO(b'data') - inst = self._makeOne(f) - result = inst.get(-1, skip=False) + def test_read_exact(self): + inst = self._makeOneFromBytes(b'data') + result = inst.read(4) self.assertEqual(result, b'data') - self.assertEqual(inst.remain, 0) - - def test_skip_remain_gt_bytes(self): - f = io.BytesIO(b'd') - inst = self._makeOne(f) - inst.remain = 1 - inst.skip(1) - self.assertEqual(inst.remain, 0) - - def test_skip_remain_lt_bytes(self): - f = io.BytesIO(b'd') - inst = self._makeOne(f) - inst.remain = 1 - self.assertRaises(ValueError, inst.skip, 2) + self.assertEqual(inst.remaining, 0) - def test_newfile(self): - inst = self._makeOne() - self.assertRaises(NotImplementedError, inst.newfile) - - def test_prune_remain_notzero(self): - f = io.BytesIO(b'd') - inst = self._makeOne(f) - inst.remain = 1 - nf = io.BytesIO() - inst.newfile = lambda *x: nf - inst.prune() - self.assertTrue(inst.file is not f) - 
self.assertEqual(nf.getvalue(), b'd') - - def test_prune_remain_zero_tell_notzero(self): - f = io.BytesIO(b'd') - inst = self._makeOne(f) - nf = io.BytesIO(b'd') - inst.newfile = lambda *x: nf - inst.remain = 0 - inst.prune() - self.assertTrue(inst.file is not f) - self.assertEqual(nf.getvalue(), b'd') - - def test_prune_remain_zero_tell_zero(self): - f = io.BytesIO() - inst = self._makeOne(f) - inst.remain = 0 - inst.prune() - self.assertTrue(inst.file is f) + def test_read_too_much(self): + inst = self._makeOneFromBytes(b'data') + result = inst.read(100) + self.assertEqual(result, b'data') + self.assertEqual(inst.remaining, 0) + + def test_rollback(self): + inst = self._makeOneFromBytes(b'data') + self.assertEqual(inst.remaining, 4) + result = inst.read(3) + self.assertEqual(inst.remaining, 1) + self.assertEqual(result, b'dat') + inst.rollback(len(result)) + self.assertEqual(inst.remaining, 4) + result = inst.read() + self.assertEqual(inst.remaining, 0) + self.assertEqual(result, b'data') def test_close(self): - f = io.BytesIO() - inst = self._makeOne(f) + inst = self._makeOneFromBytes() inst.close() - self.assertTrue(f.closed) + self.assertEqual(inst.remaining, 0) -class TestTempfileBasedBuffer(unittest.TestCase): +class TestTempfileBasedBuffer(FileBasedBufferTests, unittest.TestCase): def _makeOne(self, from_buffer=None): from waitress.buffers import TempfileBasedBuffer - return TempfileBasedBuffer(from_buffer=from_buffer) + buffer = TempfileBasedBuffer(from_buffer=from_buffer) + self.buffers.append(buffer) + return buffer - def test_newfile(self): - inst = self._makeOne() - r = inst.newfile() - self.assertTrue(hasattr(r, 'fileno')) # file + def _makeOneFromBytes(self, from_bytes=None): + return self._makeOne(from_buffer=io.BytesIO(from_bytes)) -class TestBytesIOBasedBuffer(unittest.TestCase): + def setUp(self): + self.buffers = [] - def _makeOne(self, from_buffer=None): - from waitress.buffers import BytesIOBasedBuffer - return 
BytesIOBasedBuffer(from_buffer=from_buffer) + def tearDown(self): + for b in self.buffers: + b.close() - def test_ctor_from_buffer_not_None(self): - f = io.BytesIO() - f.getfile = lambda *x: f - inst = self._makeOne(f) - self.assertTrue(hasattr(inst.file, 'read')) +class TestBytesIOBasedBuffer(FileBasedBufferTests, unittest.TestCase): - def test_ctor_from_buffer_None(self): - inst = self._makeOne() - self.assertTrue(hasattr(inst.file, 'read')) + def _makeOne(self, from_bytes=None): + from waitress.buffers import BytesIOBasedBuffer + return BytesIOBasedBuffer(from_bytes) - def test_newfile(self): - inst = self._makeOne() - r = inst.newfile() - self.assertTrue(hasattr(r, 'read')) + _makeOneFromBytes = _makeOne -class TestReadOnlyFileBasedBuffer(unittest.TestCase): +class TestReadOnlyFileBasedBuffer(FileBasedBufferTests, unittest.TestCase): - def _makeOne(self, file, block_size=8192): + def _makeOne(self, file, block_size=32768): from waitress.buffers import ReadOnlyFileBasedBuffer - return ReadOnlyFileBasedBuffer(file, block_size) + buffer = ReadOnlyFileBasedBuffer(file, block_size) + self.buffers.append(buffer) + return buffer - def test_prepare_not_seekable(self): - f = KindaFilelike(b'abc') - inst = self._makeOne(f) - result = inst.prepare() - self.assertEqual(result, False) - self.assertEqual(inst.remain, 0) + def _makeOneFromBytes(self, from_bytes=None): + buffer = self._makeOne(io.BytesIO(from_bytes)) + buffer.prepare() + return buffer + + def setUp(self): + self.buffers = [] + + def tearDown(self): + for b in self.buffers: + b.close() - def test_prepare_not_seekable_closeable(self): - f = KindaFilelike(b'abc', close=1) + def test_append(self): # overrides FileBasedBufferTests.test_append + inst = self._makeOneFromBytes() + self.assertRaises(NotImplementedError, inst.append, 'a') + + def test_prepare_unseekable(self): + f = KindaFilelike(b'abc') inst = self._makeOne(f) result = inst.prepare() - self.assertEqual(result, False) - self.assertEqual(inst.remain, 0) 
- self.assertTrue(hasattr(inst, 'close')) + self.assertEqual(result, -1) + self.assertFalse(inst.seekable) + self.assertEqual(inst.remaining, -1) - def test_prepare_seekable_closeable(self): - f = Filelike(b'abc', close=1, tellresults=[0, 10]) + def test_prepare_seekable(self): + f = Filelike(b'abc', tellresults=[0, 10]) inst = self._makeOne(f) result = inst.prepare() self.assertEqual(result, 10) - self.assertEqual(inst.remain, 10) + self.assertTrue(inst.seekable) + self.assertEqual(inst.remaining, 10) self.assertEqual(inst.file.seeked, 0) - self.assertTrue(hasattr(inst, 'close')) - def test_get_numbytes_neg_one(self): - f = io.BytesIO(b'abcdef') + def test_prepare_maxsize_lt_len(self): + f = Filelike(b'abc', tellresults=[0, 10]) inst = self._makeOne(f) - inst.remain = 2 - result = inst.get(-1) - self.assertEqual(result, b'ab') - self.assertEqual(inst.remain, 2) - self.assertEqual(f.tell(), 0) + result = inst.prepare(3) + self.assertEqual(result, 3) + self.assertEqual(inst.remaining, 3) + self.assertTrue(inst.seekable) - def test_get_numbytes_gt_remain(self): - f = io.BytesIO(b'abcdef') + def test_prepare_maxsize_gt_len(self): + f = Filelike(b'abc', tellresults=[3, 10]) inst = self._makeOne(f) - inst.remain = 2 - result = inst.get(3) - self.assertEqual(result, b'ab') - self.assertEqual(inst.remain, 2) - self.assertEqual(f.tell(), 0) + result = inst.prepare(15) + self.assertEqual(result, 7) + self.assertEqual(inst.remaining, 7) + self.assertTrue(inst.seekable) - def test_get_numbytes_lt_remain(self): + def test_read_numbytes_neg_one(self): f = io.BytesIO(b'abcdef') + f.seek(4) inst = self._makeOne(f) - inst.remain = 2 - result = inst.get(1) - self.assertEqual(result, b'a') - self.assertEqual(inst.remain, 2) - self.assertEqual(f.tell(), 0) + inst.prepare() + self.assertEqual(inst.remaining, 2) + result = inst.read(-1) + self.assertEqual(result, b'ef') + self.assertEqual(inst.remaining, 0) + self.assertEqual(f.tell(), 6) - def 
test_get_numbytes_gt_remain_withskip(self): + def test_get_numbytes_gt_remain(self): f = io.BytesIO(b'abcdef') inst = self._makeOne(f) - inst.remain = 2 - result = inst.get(3, skip=True) + inst.remaining = 2 + result = inst.read(3) self.assertEqual(result, b'ab') - self.assertEqual(inst.remain, 0) + self.assertEqual(inst.remaining, 0) self.assertEqual(f.tell(), 2) - def test_get_numbytes_lt_remain_withskip(self): + def test_get_numbytes_lt_remain(self): f = io.BytesIO(b'abcdef') inst = self._makeOne(f) - inst.remain = 2 - result = inst.get(1, skip=True) + inst.remaining = 2 + result = inst.read(1) self.assertEqual(result, b'a') - self.assertEqual(inst.remain, 1) + self.assertEqual(inst.remaining, 1) self.assertEqual(f.tell(), 1) def test___iter__(self): @@ -227,61 +192,37 @@ class TestReadOnlyFileBasedBuffer(unittest.TestCase): r += val self.assertEqual(r, data) - def test_append(self): - inst = self._makeOne(None) - self.assertRaises(NotImplementedError, inst.append, 'a') + def test_unseekable_updates_remaining_at_eof(self): + f = io.BytesIO(b'abcdef') + inst = self._makeOne(f) + inst.remaining = -1 + result1 = inst.read() + result2 = inst.read() + self.assertEqual(result1, b'abcdef') + self.assertEqual(result2, b'') + self.assertEqual(inst.remaining, 0) + -class TestOverflowableBuffer(unittest.TestCase): +class TestOverflowableBuffer(FileBasedBufferTests, unittest.TestCase): def _makeOne(self, overflow=10): from waitress.buffers import OverflowableBuffer - return OverflowableBuffer(overflow) + buffer = OverflowableBuffer(overflow) + self.buffers.append(buffer) + return buffer - def test___len__buf_is_None(self): - inst = self._makeOne() - self.assertEqual(len(inst), 0) + def _makeOneFromBytes(self, from_bytes=None): + buffer = self._makeOne() + if from_bytes: + buffer.append(from_bytes) + return buffer - def test___len__buf_is_not_None(self): - inst = self._makeOne() - inst.buf = b'abc' - self.assertEqual(len(inst), 3) + def setUp(self): + self.buffers = [] - def 
test___nonzero__(self): - inst = self._makeOne() - inst.buf = b'abc' - self.assertEqual(bool(inst), True) - inst.buf = b'' - self.assertEqual(bool(inst), False) - - def test___nonzero___on_int_overflow_buffer(self): - inst = self._makeOne() - - class int_overflow_buf(bytes): - def __len__(self): - # maxint + 1 - return 0x7fffffffffffffff + 1 - inst.buf = int_overflow_buf() - self.assertEqual(bool(inst), True) - inst.buf = b'' - self.assertEqual(bool(inst), False) - - def test__create_buffer_large(self): - from waitress.buffers import TempfileBasedBuffer - inst = self._makeOne() - inst.strbuf = b'x' * 11 - inst._create_buffer() - self.assertEqual(inst.buf.__class__, TempfileBasedBuffer) - self.assertEqual(inst.buf.get(100), b'x' * 11) - self.assertEqual(inst.strbuf, b'') - - def test__create_buffer_small(self): - from waitress.buffers import BytesIOBasedBuffer - inst = self._makeOne() - inst.strbuf = b'x' * 5 - inst._create_buffer() - self.assertEqual(inst.buf.__class__, BytesIOBasedBuffer) - self.assertEqual(inst.buf.get(100), b'x' * 5) - self.assertEqual(inst.strbuf, b'') + def tearDown(self): + for b in self.buffers: + b.close() def test_append_with_len_more_than_max_int(self): from waitress.compat import MAXINT @@ -289,123 +230,66 @@ class TestOverflowableBuffer(unittest.TestCase): inst.overflowed = True buf = DummyBuffer(length=MAXINT) inst.buf = buf + inst.remaining = MAXINT result = inst.append(b'x') # we don't want this to throw an OverflowError on Python 2 (see # https://github.com/Pylons/waitress/issues/47) self.assertEqual(result, None) - def test_append_buf_None_not_longer_than_srtbuf_limit(self): + def test_append_buf_None_not_longer_than_strbuf_limit(self): inst = self._makeOne() inst.strbuf = b'x' * 5 + inst.remaining = len(inst.strbuf) inst.append(b'hello') self.assertEqual(inst.strbuf, b'xxxxxhello') + self.assertEqual(inst.remaining, 10) def test_append_buf_None_longer_than_strbuf_limit(self): inst = self._makeOne(10000) inst.strbuf = b'x' * 8192 + 
inst.remaining = len(inst.strbuf) inst.append(b'hello') self.assertEqual(inst.strbuf, b'') - self.assertEqual(len(inst.buf), 8197) + self.assertEqual(inst.buf.remaining, 8197) def test_append_overflow(self): inst = self._makeOne(10) inst.strbuf = b'x' * 8192 + inst.remaining = len(inst.strbuf) inst.append(b'hello') self.assertEqual(inst.strbuf, b'') - self.assertEqual(len(inst.buf), 8197) + self.assertEqual(inst.buf.remaining, 8197) def test_append_sz_gt_overflow(self): from waitress.buffers import BytesIOBasedBuffer - f = io.BytesIO(b'data') - inst = self._makeOne(f) + inst = self._makeOne() buf = BytesIOBasedBuffer() inst.buf = buf inst.overflow = 2 inst.append(b'data2') - self.assertEqual(f.getvalue(), b'data') self.assertTrue(inst.overflowed) self.assertNotEqual(inst.buf, buf) - def test_get_buf_None_skip_False(self): - inst = self._makeOne() - inst.strbuf = b'x' * 5 - r = inst.get(5) - self.assertEqual(r, b'xxxxx') - - def test_get_buf_None_skip_True(self): - inst = self._makeOne() - inst.strbuf = b'x' * 5 - r = inst.get(5, skip=True) - self.assertFalse(inst.buf is None) - self.assertEqual(r, b'xxxxx') - - def test_skip_buf_None(self): - inst = self._makeOne() - inst.strbuf = b'data' - inst.skip(4) - self.assertEqual(inst.strbuf, b'') - self.assertNotEqual(inst.buf, None) - - def test_skip_buf_None_allow_prune_True(self): - inst = self._makeOne() - inst.strbuf = b'data' - inst.skip(4, True) - self.assertEqual(inst.strbuf, b'') - self.assertEqual(inst.buf, None) - - def test_prune_buf_None(self): - inst = self._makeOne() - inst.prune() - self.assertEqual(inst.strbuf, b'') - - def test_prune_with_buf(self): - inst = self._makeOne() - class Buf(object): - def prune(self): - self.pruned = True - inst.buf = Buf() - inst.prune() - self.assertEqual(inst.buf.pruned, True) - - def test_prune_with_buf_overflow(self): - inst = self._makeOne() - class DummyBuffer(io.BytesIO): - def getfile(self): - return self - def prune(self): - return True - def __len__(self): - return 
5 - buf = DummyBuffer(b'data') - inst.buf = buf - inst.overflowed = True - inst.overflow = 10 - inst.prune() - self.assertNotEqual(inst.buf, buf) - - def test_prune_with_buflen_more_than_max_int(self): - from waitress.compat import MAXINT - inst = self._makeOne() - inst.overflowed = True - buf = DummyBuffer(length=MAXINT+1) - inst.buf = buf - result = inst.prune() - # we don't want this to throw an OverflowError on Python 2 (see - # https://github.com/Pylons/waitress/issues/47) - self.assertEqual(result, None) - - def test_getfile_buf_None(self): - inst = self._makeOne() - f = inst.getfile() - self.assertTrue(hasattr(f, 'read')) + def test_read_strbuf(self): + inst = self._makeOne(10) + inst.strbuf = b'x' + inst.remaining = len(inst.strbuf) + result = inst.read() + self.assertEqual(result, b'x') + self.assertEqual(inst.remaining, 0) - def test_getfile_buf_not_None(self): - inst = self._makeOne() - buf = io.BytesIO() - buf.getfile = lambda *x: buf - inst.buf = buf - f = inst.getfile() - self.assertEqual(f, buf) + def test_rollback_strbuf(self): + inst = self._makeOne(10) + inst.strbuf = b'x' + inst.remaining = len(inst.strbuf) + result = inst.read() + self.assertEqual(result, b'x') + self.assertEqual(inst.remaining, 0) + inst.rollback(1) + self.assertEqual(inst.remaining, 1) + result = inst.read() + self.assertEqual(result, b'x') + self.assertEqual(inst.remaining, 0) def test_close_nobuf(self): inst = self._makeOne() @@ -428,7 +312,7 @@ class KindaFilelike(object): self.bytes = bytes self.tellresults = tellresults if close is not None: - self.close = close + self.close = lambda: close class Filelike(KindaFilelike): @@ -441,13 +325,10 @@ class Filelike(KindaFilelike): class DummyBuffer(object): def __init__(self, length=0): - self.length = length - - def __len__(self): - return self.length + self.remaining = length def append(self, s): - self.length = self.length + len(s) + self.remaining = self.remaining + len(s) - def prune(self): - pass + def close(self): + 
self.closed = True diff --git a/waitress/tests/test_channel.py b/waitress/tests/test_channel.py index e2c7c49..8840b43 100644 --- a/waitress/tests/test_channel.py +++ b/waitress/tests/test_channel.py @@ -29,8 +29,9 @@ class TestHTTPChannel(unittest.TestCase): inst = cls(server, sock, '127.0.0.1', adj, map=map) if outbuf is not None: inst.outbufs = deque([outbuf]) - inst.total_outbufs_len = outbuf.__len__() - inst.current_outbuf_count = outbuf.__len__() + inst._scan_outbufs() + if outbuf.seekable: + inst.current_outbuf_count = outbuf.remaining inst.outbuf_lock = DummyLock() return inst, sock, map @@ -40,7 +41,7 @@ class TestHTTPChannel(unittest.TestCase): self.assertEqual(inst.sendbuf_len, 2048) self.assertEqual(map[100], inst) - def test_total_outbufs_len_an_outbuf_size_gt_sys_maxint(self): + def test_known_outbufs_len_an_outbuf_size_gt_sys_maxint(self): from waitress.compat import MAXINT class DummyBuffer(object): chunks = [] @@ -51,15 +52,15 @@ class TestHTTPChannel(unittest.TestCase): return MAXINT inst, _, map = self._makeOne() inst.outbufs = deque([DummyBuffer()]) - inst.total_outbufs_len = 1 + inst.known_outbufs_len = 1 inst.write_soon(DummyData()) # we are testing that this method does not raise an OverflowError # (see https://github.com/Pylons/waitress/issues/47) - self.assertEqual(inst.total_outbufs_len, MAXINT+1) + self.assertEqual(inst.known_outbufs_len, MAXINT + 1) def test_writable_something_in_outbuf(self): inst, sock, map = self._makeOne() - inst.total_outbufs_len = 3 + inst.has_outbuf_data = True self.assertTrue(inst.writable()) def test_writable_nothing_in_outbuf(self): @@ -205,13 +206,13 @@ class TestHTTPChannel(unittest.TestCase): inst, sock, map = self._makeOne() wrote = inst.write_soon(b'') self.assertEqual(wrote, 0) - self.assertEqual(len(inst.outbufs[0]), 0) + self.assertEqual(inst.outbufs[0].remaining, 0) def test_write_soon_nonempty_byte(self): inst, sock, map = self._makeOne() wrote = inst.write_soon(b'a') self.assertEqual(wrote, 1) - 
self.assertEqual(len(inst.outbufs[0]), 1) + self.assertEqual(inst.outbufs[0].remaining, 1) def test_write_soon_filewrapper(self): from waitress.buffers import ReadOnlyFileBasedBuffer @@ -249,24 +250,24 @@ class TestHTTPChannel(unittest.TestCase): wrote = inst.write_soon(b'xyz') self.assertEqual(wrote, 3) self.assertEqual(len(inst.outbufs), 2) - self.assertEqual(inst.outbufs[0].get(), b'') - self.assertEqual(inst.outbufs[1].get(), b'xyz') + self.assertEqual(inst.outbufs[0].read(), b'') + self.assertEqual(inst.outbufs[1].read(), b'xyz') def test_write_soon_waits_on_backpressure(self): inst, sock, map = self._makeOne() inst.adj.outbuf_high_watermark = 3 - inst.total_outbufs_len = 4 + inst.known_outbufs_len = 4 inst.current_outbuf_count = 4 class Lock(DummyLock): def wait(self): - inst.total_outbufs_len = 0 + inst.known_outbufs_len = 0 super(Lock, self).wait() inst.outbuf_lock = Lock() wrote = inst.write_soon(b'xyz') self.assertEqual(wrote, 3) self.assertEqual(len(inst.outbufs), 2) - self.assertEqual(inst.outbufs[0].get(), b'') - self.assertEqual(inst.outbufs[1].get(), b'xyz') + self.assertEqual(inst.outbufs[0].read(), b'') + self.assertEqual(inst.outbufs[1].read(), b'xyz') self.assertTrue(inst.outbuf_lock.waited) def test_handle_write_notify_after_flush(self): @@ -306,7 +307,7 @@ class TestHTTPChannel(unittest.TestCase): def test__flush_some_full_outbuf_socket_returns_nonzero(self): inst, sock, map = self._makeOne() inst.outbufs[0].append(b'abc') - inst.total_outbufs_len += 3 + inst._scan_outbufs() result = inst._flush_some() self.assertEqual(result, True) @@ -314,7 +315,7 @@ class TestHTTPChannel(unittest.TestCase): inst, sock, map = self._makeOne() sock.send = lambda x: False inst.outbufs[0].append(b'abc') - inst.total_outbufs_len += 3 + inst._scan_outbufs() result = inst._flush_some() self.assertEqual(result, False) @@ -323,10 +324,10 @@ class TestHTTPChannel(unittest.TestCase): sock.send = lambda x: len(x) buffer = DummyBuffer(b'abc') inst.outbufs.append(buffer) - 
inst.total_outbufs_len += len(buffer) + inst._scan_outbufs() result = inst._flush_some() self.assertEqual(result, True) - self.assertEqual(buffer.skipped, 3) + self.assertEqual(buffer.total_read, 3) self.assertEqual(len(inst.outbufs), 1) self.assertEqual(inst.outbufs[0], buffer) @@ -335,14 +336,14 @@ class TestHTTPChannel(unittest.TestCase): sock.send = lambda x: len(x) buffer = DummyBuffer(b'abc') inst.outbufs.append(buffer) - inst.total_outbufs_len += len(buffer) + inst._scan_outbufs() inst.logger = DummyLogger() def doraise(): raise NotImplementedError inst.outbufs[0].close = doraise result = inst._flush_some() self.assertEqual(result, True) - self.assertEqual(buffer.skipped, 3) + self.assertEqual(buffer.total_read, 3) self.assertEqual(len(inst.outbufs), 1) self.assertEqual(inst.outbufs[0], buffer) self.assertEqual(len(inst.logger.exceptions), 1) @@ -350,14 +351,13 @@ class TestHTTPChannel(unittest.TestCase): def test__flush_some_outbuf_len_gt_sys_maxint(self): from waitress.compat import MAXINT class DummyHugeOutbuffer(object): + seekable = True def __init__(self): - self.length = MAXINT + 1 - def __len__(self): - return self.length - def get(self, numbytes): - self.length = 0 + self.remaining = MAXINT + 1 + def read(self, numbytes, *args): + self.remaining = 0 return b'123' - def skip(self, *args): pass + def rollback(self, *args): pass inst, sock, map = self._makeOne(outbuf=DummyHugeOutbuffer()) inst.send = lambda *arg: 0 result = inst._flush_some() @@ -481,7 +481,7 @@ class TestHTTPChannel(unittest.TestCase): inst.received(b'GET / HTTP/1.1\n\n') self.assertEqual(inst.request, preq) self.assertEqual(inst.server.tasks, []) - self.assertEqual(inst.outbufs[0].get(100), b'') + self.assertEqual(inst.outbufs[0].read(100), b'') def test_received_headers_finished_expect_continue_true(self): inst, sock, map = self._makeOne() @@ -749,25 +749,29 @@ class DummyLock(object): class DummyBuffer(object): closed = False + total_read = 0 + seekable = True def __init__(self, 
data, toraise=None): - self.data = data + self.buf = io.BytesIO(data) + self.remaining = len(data) self.toraise = toraise - def get(self, *arg): + def read(self, numbytes): if self.toraise: raise self.toraise - data = self.data - self.data = b'' + data = self.buf.read(numbytes) + self.remaining -= len(data) + self.total_read += len(data) return data - def skip(self, num, x): - self.skipped = num - - def __len__(self): - return len(self.data) + def rollback(self, numbytes): + self.buf.seek(-numbytes, 1) + self.remaining += numbytes + self.total_read -= numbytes def close(self): + self.remaining = 0 self.closed = True class DummyAdjustments(object): diff --git a/waitress/tests/test_functional.py b/waitress/tests/test_functional.py index 0571ce6..6160eeb 100644 --- a/waitress/tests/test_functional.py +++ b/waitress/tests/test_functional.py @@ -75,6 +75,7 @@ class SubprocessTests(object): self.sock.close() # This give us one FD back ... self.queue.close() + self.proc.join() def assertline(self, line, status, reason, version): v, s, r = (x.strip() for x in line.split(None, 2)) |