diff options
-rw-r--r-- | CHANGES | 1 | ||||
-rw-r--r-- | redis/asyncio/connection.py | 69 | ||||
-rwxr-xr-x | redis/connection.py | 58 | ||||
-rw-r--r-- | tests/mocks.py | 41 | ||||
-rw-r--r-- | tests/test_asyncio/mocks.py | 51 | ||||
-rw-r--r-- | tests/test_asyncio/test_connection.py | 48 | ||||
-rw-r--r-- | tests/test_connection.py | 43 |
7 files changed, 269 insertions, 42 deletions
@@ -1,3 +1,4 @@ + * Make PythonParser resumable in case of error (#2510) * Add `timeout=None` in `SentinelConnectionManager.read_response` * Documentation fix: password protected socket connection (#2374) * Allow `timeout=None` in `PubSub.get_message()` to wait forever diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index d031f41..2c75d4f 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -208,11 +208,18 @@ class BaseParser: class PythonParser(BaseParser): """Plain Python parsing class""" - __slots__ = BaseParser.__slots__ + ("encoder",) + __slots__ = BaseParser.__slots__ + ("encoder", "_buffer", "_pos", "_chunks") def __init__(self, socket_read_size: int): super().__init__(socket_read_size) self.encoder: Optional[Encoder] = None + self._buffer = b"" + self._chunks = [] + self._pos = 0 + + def _clear(self): + self._buffer = b"" + self._chunks.clear() def on_connect(self, connection: "Connection"): """Called when the stream connects""" @@ -227,8 +234,11 @@ class PythonParser(BaseParser): if self._stream is not None: self._stream = None self.encoder = None + self._clear() async def can_read_destructive(self) -> bool: + if self._buffer: + return True if self._stream is None: raise RedisError("Buffer is closed.") try: @@ -237,14 +247,23 @@ class PythonParser(BaseParser): except asyncio.TimeoutError: return False - async def read_response( + async def read_response(self, disable_decoding: bool = False): + if self._chunks: + # augment parsing buffer with previously read data + self._buffer += b"".join(self._chunks) + self._chunks.clear() + self._pos = 0 + response = await self._read_response(disable_decoding=disable_decoding) + # Successfully parsing a response allows us to clear our parsing buffer + self._clear() + return response + + async def _read_response( self, disable_decoding: bool = False ) -> Union[EncodableT, ResponseError, None]: if not self._stream or not self.encoder: raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) raw = await self._readline() - if not raw: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) response: Any byte, response = raw[:1], raw[1:] @@ -258,6 +277,7 @@ class PythonParser(BaseParser): # if the error is a ConnectionError, raise immediately so the user # is notified if isinstance(error, ConnectionError): + self._clear() # Successful parse raise error # otherwise, we're dealing with a ResponseError that might belong # inside a pipeline response. the connection's read_response() @@ -282,7 +302,7 @@ class PythonParser(BaseParser): if length == -1: return None response = [ - (await self.read_response(disable_decoding)) for _ in range(length) + (await self._read_response(disable_decoding)) for _ in range(length) ] if isinstance(response, bytes) and disable_decoding is False: response = self.encoder.decode(response) @@ -293,25 +313,38 @@ class PythonParser(BaseParser): Read `length` bytes of data. These are assumed to be followed by a '\r\n' terminator which is subsequently discarded. """ - if self._stream is None: - raise RedisError("Buffer is closed.") - try: - data = await self._stream.readexactly(length + 2) - except asyncio.IncompleteReadError as error: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error - return data[:-2] + want = length + 2 + end = self._pos + want + if len(self._buffer) >= end: + result = self._buffer[self._pos : end - 2] + else: + tail = self._buffer[self._pos :] + try: + data = await self._stream.readexactly(want - len(tail)) + except asyncio.IncompleteReadError as error: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error + result = (tail + data)[:-2] + self._chunks.append(data) + self._pos += want + return result async def _readline(self) -> bytes: """ read an unknown number of bytes up to the next '\r\n' line separator, which is discarded. """ - if self._stream is None: - raise RedisError("Buffer is closed.") - data = await self._stream.readline() - if not data.endswith(b"\r\n"): - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - return data[:-2] + found = self._buffer.find(b"\r\n", self._pos) + if found >= 0: + result = self._buffer[self._pos : found] + else: + tail = self._buffer[self._pos :] + data = await self._stream.readline() + if not data.endswith(b"\r\n"): + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + result = (tail + data)[:-2] + self._chunks.append(data) + self._pos += len(result) + 2 + return result class HiredisParser(BaseParser): diff --git a/redis/connection.py b/redis/connection.py index b810fc5..126ea5d 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -232,12 +232,6 @@ class SocketBuffer: self._buffer.seek(self.bytes_read) data = self._buffer.read(length) self.bytes_read += len(data) - - # purge the buffer when we've consumed it all so it doesn't - # grow forever - if self.bytes_read == self.bytes_written: - self.purge() - return data[:-2] def readline(self): @@ -251,23 +245,44 @@ class SocketBuffer: data = buf.readline() self.bytes_read += len(data) + return data[:-2] - # purge the buffer when we've consumed it all so it doesn't - # grow forever - if self.bytes_read == self.bytes_written: - self.purge() + def get_pos(self): + """ + Get current read position + """ + return self.bytes_read - return data[:-2] + def rewind(self, pos): + """ + Rewind the buffer to a specific position, to re-start reading + """ + self.bytes_read = pos def purge(self): - self._buffer.seek(0) - self._buffer.truncate() - self.bytes_written = 0 + """ + After a successful read, purge the read part of buffer + """ + unread = self.bytes_written - self.bytes_read + + # Only if we have read all of the buffer do we truncate, to + # reduce the amount of memory thrashing. This heuristic + # can be changed or removed later. + if unread > 0: + return + + if unread > 0: + # move unread data to the front + view = self._buffer.getbuffer() + view[:unread] = view[-unread:] + self._buffer.truncate(unread) + self.bytes_written = unread self.bytes_read = 0 + self._buffer.seek(0) def close(self): try: - self.purge() + self.bytes_written = self.bytes_read = 0 self._buffer.close() except Exception: # issue #633 suggests the purge/close somehow raised a @@ -315,6 +330,17 @@ class PythonParser(BaseParser): return self._buffer and self._buffer.can_read(timeout) def read_response(self, disable_decoding=False): + pos = self._buffer.get_pos() + try: + result = self._read_response(disable_decoding=disable_decoding) + except BaseException: + self._buffer.rewind(pos) + raise + else: + self._buffer.purge() + return result + + def _read_response(self, disable_decoding=False): raw = self._buffer.readline() if not raw: raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) @@ -355,7 +381,7 @@ class PythonParser(BaseParser): if length == -1: return None response = [ - self.read_response(disable_decoding=disable_decoding) + self._read_response(disable_decoding=disable_decoding) for i in range(length) ] if isinstance(response, bytes) and disable_decoding is False: diff --git a/tests/mocks.py b/tests/mocks.py new file mode 100644 index 0000000..d7d450e --- /dev/null +++ b/tests/mocks.py @@ -0,0 +1,41 @@ +# Various mocks for testing + + +class MockSocket: + """ + A class simulating an readable socket, optionally raising a + special exception every other read. + """ + + class TestError(BaseException): + pass + + def __init__(self, data, interrupt_every=0): + self.data = data + self.counter = 0 + self.pos = 0 + self.interrupt_every = interrupt_every + + def tick(self): + self.counter += 1 + if not self.interrupt_every: + return + if (self.counter % self.interrupt_every) == 0: + raise self.TestError() + + def recv(self, bufsize): + self.tick() + bufsize = min(5, bufsize) # truncate the read size + result = self.data[self.pos : self.pos + bufsize] + self.pos += len(result) + return result + + def recv_into(self, buffer, nbytes=0, flags=0): + self.tick() + if nbytes == 0: + nbytes = len(buffer) + nbytes = min(5, nbytes) # truncate the read size + result = self.data[self.pos : self.pos + nbytes] + self.pos += len(result) + buffer[: len(result)] = result + return len(result) diff --git a/tests/test_asyncio/mocks.py b/tests/test_asyncio/mocks.py new file mode 100644 index 0000000..89bd9c0 --- /dev/null +++ b/tests/test_asyncio/mocks.py @@ -0,0 +1,51 @@ +import asyncio + +# Helper Mocking classes for the tests. + + +class MockStream: + """ + A class simulating an asyncio input buffer, optionally raising a + special exception every other read. + """ + + class TestError(BaseException): + pass + + def __init__(self, data, interrupt_every=0): + self.data = data + self.counter = 0 + self.pos = 0 + self.interrupt_every = interrupt_every + + def tick(self): + self.counter += 1 + if not self.interrupt_every: + return + if (self.counter % self.interrupt_every) == 0: + raise self.TestError() + + async def read(self, want): + self.tick() + want = 5 + result = self.data[self.pos : self.pos + want] + self.pos += len(result) + return result + + async def readline(self): + self.tick() + find = self.data.find(b"\n", self.pos) + if find >= 0: + result = self.data[self.pos : find + 1] + else: + result = self.data[self.pos :] + self.pos += len(result) + return result + + async def readexactly(self, length): + self.tick() + result = self.data[self.pos : self.pos + length] + if len(result) < length: + raise asyncio.IncompleteReadError(result, None) + self.pos += len(result) + return result diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index 6bf0034..bf59dbe 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -5,7 +5,9 @@ from unittest.mock import patch import pytest +import redis from redis.asyncio.connection import ( + BaseParser, Connection, PythonParser, UnixDomainSocketConnection, @@ -16,6 +18,7 @@ from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError from tests.conftest import skip_if_server_version_lt from .compat import mock +from .mocks import MockStream @pytest.mark.onlynoncluster @@ -23,16 +26,19 @@ async def test_invalid_response(create_redis): r = await create_redis(single_connection_client=True) raw = b"x" + fake_stream = MockStream(raw + b"\r\n") - parser: "PythonParser" = r.connection._parser - if not isinstance(parser, PythonParser): - pytest.skip("PythonParser only") - stream_mock = mock.Mock(parser._stream) - stream_mock.readline.return_value = raw + b"\r\n" - with mock.patch.object(parser, "_stream", stream_mock): + parser: BaseParser = r.connection._parser + with mock.patch.object(parser, "_stream", fake_stream): with pytest.raises(InvalidResponse) as cm: await parser.read_response() - assert str(cm.value) == f"Protocol Error: {raw!r}" + if isinstance(parser, PythonParser): + assert str(cm.value) == f"Protocol Error: {raw!r}" + else: + assert ( + str(cm.value) == f'Protocol error, got "{raw.decode()}" as reply type byte' + ) + await r.connection.disconnect() @skip_if_server_version_lt("4.0.0") @@ -112,3 +118,31 @@ async def test_connect_timeout_error_without_retry(): await conn.connect() assert conn._connect.call_count == 1 assert str(e.value) == "Timeout connecting to server" + + +@pytest.mark.onlynoncluster +async def test_connection_parse_response_resume(r: redis.Redis): + """ + This test verifies that the Connection parser, + be that PythonParser or HiredisParser, + can be interrupted at IO time and then resume parsing. + """ + conn = Connection(**r.connection_pool.connection_kwargs) + await conn.connect() + message = ( + b"*3\r\n$7\r\nmessage\r\n$8\r\nchannel1\r\n" + b"$25\r\nhi\r\nthere\r\n+how\r\nare\r\nyou\r\n" + ) + + conn._parser._stream = MockStream(message, interrupt_every=2) + for i in range(100): + try: + response = await conn.read_response() + break + except MockStream.TestError: + pass + + else: + pytest.fail("didn't receive a response") + assert response + assert i > 0 diff --git a/tests/test_connection.py b/tests/test_connection.py index d9251c3..e0b53cd 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -5,13 +5,15 @@ from unittest.mock import patch import pytest +import redis from redis.backoff import NoBackoff -from redis.connection import Connection +from redis.connection import Connection, HiredisParser, PythonParser from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError from redis.retry import Retry from redis.utils import HIREDIS_AVAILABLE from .conftest import skip_if_server_version_lt +from .mocks import MockSocket @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @@ -122,3 +124,42 @@ class TestConnection: assert conn._connect.call_count == 1 assert str(e.value) == "Timeout connecting to server" self.clear(conn) + + +@pytest.mark.onlynoncluster +@pytest.mark.parametrize( + "parser_class", [PythonParser, HiredisParser], ids=["PythonParser", "HiredisParser"] +) +def test_connection_parse_response_resume(r: redis.Redis, parser_class): + """ + This test verifies that the Connection parser, + be that PythonParser or HiredisParser, + can be interrupted at IO time and then resume parsing. + """ + if parser_class is HiredisParser and not HIREDIS_AVAILABLE: + pytest.skip("Hiredis not available)") + args = dict(r.connection_pool.connection_kwargs) + args["parser_class"] = parser_class + conn = Connection(**args) + conn.connect() + message = ( + b"*3\r\n$7\r\nmessage\r\n$8\r\nchannel1\r\n" + b"$25\r\nhi\r\nthere\r\n+how\r\nare\r\nyou\r\n" + ) + mock_socket = MockSocket(message, interrupt_every=2) + + if isinstance(conn._parser, PythonParser): + conn._parser._buffer._sock = mock_socket + else: + conn._parser._sock = mock_socket + for i in range(100): + try: + response = conn.read_response() + break + except MockSocket.TestError: + pass + + else: + pytest.fail("didn't receive a response") + assert response + assert i > 0 |