diff options
Diffstat (limited to 'redis/asyncio/connection.py')
-rw-r--r-- | redis/asyncio/connection.py | 69 |
1 files changed, 51 insertions, 18 deletions
diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index d031f41..2c75d4f 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -208,11 +208,18 @@ class BaseParser: class PythonParser(BaseParser): """Plain Python parsing class""" - __slots__ = BaseParser.__slots__ + ("encoder",) + __slots__ = BaseParser.__slots__ + ("encoder", "_buffer", "_pos", "_chunks") def __init__(self, socket_read_size: int): super().__init__(socket_read_size) self.encoder: Optional[Encoder] = None + self._buffer = b"" + self._chunks = [] + self._pos = 0 + + def _clear(self): + self._buffer = b"" + self._chunks.clear() def on_connect(self, connection: "Connection"): """Called when the stream connects""" @@ -227,8 +234,11 @@ class PythonParser(BaseParser): if self._stream is not None: self._stream = None self.encoder = None + self._clear() async def can_read_destructive(self) -> bool: + if self._buffer: + return True if self._stream is None: raise RedisError("Buffer is closed.") try: @@ -237,14 +247,23 @@ class PythonParser(BaseParser): except asyncio.TimeoutError: return False - async def read_response( + async def read_response(self, disable_decoding: bool = False): + if self._chunks: + # augment parsing buffer with previously read data + self._buffer += b"".join(self._chunks) + self._chunks.clear() + self._pos = 0 + response = await self._read_response(disable_decoding=disable_decoding) + # Successfully parsing a response allows us to clear our parsing buffer + self._clear() + return response + + async def _read_response( self, disable_decoding: bool = False ) -> Union[EncodableT, ResponseError, None]: if not self._stream or not self.encoder: raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) raw = await self._readline() - if not raw: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) response: Any byte, response = raw[:1], raw[1:] @@ -258,6 +277,7 @@ class PythonParser(BaseParser): # if the error is a ConnectionError, raise immediately so the user # is notified if isinstance(error, ConnectionError): + self._clear() # Successful parse raise error # otherwise, we're dealing with a ResponseError that might belong # inside a pipeline response. the connection's read_response() @@ -282,7 +302,7 @@ class PythonParser(BaseParser): if length == -1: return None response = [ - (await self.read_response(disable_decoding)) for _ in range(length) + (await self._read_response(disable_decoding)) for _ in range(length) ] if isinstance(response, bytes) and disable_decoding is False: response = self.encoder.decode(response) @@ -293,25 +313,38 @@ class PythonParser(BaseParser): Read `length` bytes of data. These are assumed to be followed by a '\r\n' terminator which is subsequently discarded. """ - if self._stream is None: - raise RedisError("Buffer is closed.") - try: - data = await self._stream.readexactly(length + 2) - except asyncio.IncompleteReadError as error: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error - return data[:-2] + want = length + 2 + end = self._pos + want + if len(self._buffer) >= end: + result = self._buffer[self._pos : end - 2] + else: + tail = self._buffer[self._pos :] + try: + data = await self._stream.readexactly(want - len(tail)) + except asyncio.IncompleteReadError as error: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error + result = (tail + data)[:-2] + self._chunks.append(data) + self._pos += want + return result async def _readline(self) -> bytes: """ read an unknown number of bytes up to the next '\r\n' line separator, which is discarded. """ - if self._stream is None: - raise RedisError("Buffer is closed.") - data = await self._stream.readline() - if not data.endswith(b"\r\n"): - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - return data[:-2] + found = self._buffer.find(b"\r\n", self._pos) + if found >= 0: + result = self._buffer[self._pos : found] + else: + tail = self._buffer[self._pos :] + data = await self._stream.readline() + if not data.endswith(b"\r\n"): + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + result = (tail + data)[:-2] + self._chunks.append(data) + self._pos += len(result) + 2 + return result class HiredisParser(BaseParser): |