diff options
author | Kristján Valur Jónsson <sweskman@gmail.com> | 2023-01-05 12:42:37 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-01-05 14:42:37 +0200 |
commit | a9ef0c5d0080bd14e2f189d7f31d83e758346a8d (patch) | |
tree | bf22a1be3fccf5dbc5fcd03bf7ee2345d9b07487 /redis/asyncio/connection.py | |
parent | a94772848db87bfc2c3cee20d8ca8b257fc37466 (diff) | |
download | redis-py-a9ef0c5d0080bd14e2f189d7f31d83e758346a8d.tar.gz |
Make PythonParser resumable (#2510)
* PythonParser is now resumable if _stream IO is interrupted
* Add test for parse resumability
* Clear PythonParser state when connection or parsing errors occur.
* disable test for cluster mode.
* Perform "closed" check in a single place.
* Update tests
* Simplify code.
* Remove redundant test, EOF is detected inside _readline()
* Make synchronous PythonParser restartable on error, same as HiredisParser
Fix sync PythonParser
* Add CHANGES
* isort
* Move MockStream and MockSocket into their own files
Diffstat (limited to 'redis/asyncio/connection.py')
-rw-r--r-- | redis/asyncio/connection.py | 69 |
1 files changed, 51 insertions, 18 deletions
diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index d031f41..2c75d4f 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -208,11 +208,18 @@ class BaseParser: class PythonParser(BaseParser): """Plain Python parsing class""" - __slots__ = BaseParser.__slots__ + ("encoder",) + __slots__ = BaseParser.__slots__ + ("encoder", "_buffer", "_pos", "_chunks") def __init__(self, socket_read_size: int): super().__init__(socket_read_size) self.encoder: Optional[Encoder] = None + self._buffer = b"" + self._chunks = [] + self._pos = 0 + + def _clear(self): + self._buffer = b"" + self._chunks.clear() def on_connect(self, connection: "Connection"): """Called when the stream connects""" @@ -227,8 +234,11 @@ class PythonParser(BaseParser): if self._stream is not None: self._stream = None self.encoder = None + self._clear() async def can_read_destructive(self) -> bool: + if self._buffer: + return True if self._stream is None: raise RedisError("Buffer is closed.") try: @@ -237,14 +247,23 @@ class PythonParser(BaseParser): except asyncio.TimeoutError: return False - async def read_response( + async def read_response(self, disable_decoding: bool = False): + if self._chunks: + # augment parsing buffer with previously read data + self._buffer += b"".join(self._chunks) + self._chunks.clear() + self._pos = 0 + response = await self._read_response(disable_decoding=disable_decoding) + # Successfully parsing a response allows us to clear our parsing buffer + self._clear() + return response + + async def _read_response( self, disable_decoding: bool = False ) -> Union[EncodableT, ResponseError, None]: if not self._stream or not self.encoder: raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) raw = await self._readline() - if not raw: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) response: Any byte, response = raw[:1], raw[1:] @@ -258,6 +277,7 @@ class PythonParser(BaseParser): # if the error is a ConnectionError, raise immediately so 
the user # is notified if isinstance(error, ConnectionError): + self._clear() # Successful parse raise error # otherwise, we're dealing with a ResponseError that might belong # inside a pipeline response. the connection's read_response() @@ -282,7 +302,7 @@ class PythonParser(BaseParser): if length == -1: return None response = [ - (await self.read_response(disable_decoding)) for _ in range(length) + (await self._read_response(disable_decoding)) for _ in range(length) ] if isinstance(response, bytes) and disable_decoding is False: response = self.encoder.decode(response) @@ -293,25 +313,38 @@ class PythonParser(BaseParser): Read `length` bytes of data. These are assumed to be followed by a '\r\n' terminator which is subsequently discarded. """ - if self._stream is None: - raise RedisError("Buffer is closed.") - try: - data = await self._stream.readexactly(length + 2) - except asyncio.IncompleteReadError as error: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error - return data[:-2] + want = length + 2 + end = self._pos + want + if len(self._buffer) >= end: + result = self._buffer[self._pos : end - 2] + else: + tail = self._buffer[self._pos :] + try: + data = await self._stream.readexactly(want - len(tail)) + except asyncio.IncompleteReadError as error: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error + result = (tail + data)[:-2] + self._chunks.append(data) + self._pos += want + return result async def _readline(self) -> bytes: """ read an unknown number of bytes up to the next '\r\n' line separator, which is discarded. 
""" - if self._stream is None: - raise RedisError("Buffer is closed.") - data = await self._stream.readline() - if not data.endswith(b"\r\n"): - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - return data[:-2] + found = self._buffer.find(b"\r\n", self._pos) + if found >= 0: + result = self._buffer[self._pos : found] + else: + tail = self._buffer[self._pos :] + data = await self._stream.readline() + if not data.endswith(b"\r\n"): + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + result = (tail + data)[:-2] + self._chunks.append(data) + self._pos += len(result) + 2 + return result class HiredisParser(BaseParser): |