summaryrefslogtreecommitdiff
path: root/redis/asyncio/connection.py
diff options
context:
space:
mode:
authorKristján Valur Jónsson <sweskman@gmail.com>2023-01-05 12:42:37 +0000
committerGitHub <noreply@github.com>2023-01-05 14:42:37 +0200
commita9ef0c5d0080bd14e2f189d7f31d83e758346a8d (patch)
treebf22a1be3fccf5dbc5fcd03bf7ee2345d9b07487 /redis/asyncio/connection.py
parenta94772848db87bfc2c3cee20d8ca8b257fc37466 (diff)
downloadredis-py-a9ef0c5d0080bd14e2f189d7f31d83e758346a8d.tar.gz
Make PythonParser resumable (#2510)
* PythonParser is now resumable if _stream IO is interrupted * Add test for parse resumability * Clear PythonParser state when connection or parsing errors occur. * disable test for cluster mode. * Perform "closed" check in a single place. * Update tests * Simplify code. * Remove reduntant test, EOF is detected inside _readline() * Make syncronous PythonParser restartable on error, same as HiredisParser Fix sync PythonParser * Add CHANGES * isort * Move MockStream and MockSocket into their own files
Diffstat (limited to 'redis/asyncio/connection.py')
-rw-r--r--redis/asyncio/connection.py69
1 files changed, 51 insertions, 18 deletions
diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py
index d031f41..2c75d4f 100644
--- a/redis/asyncio/connection.py
+++ b/redis/asyncio/connection.py
@@ -208,11 +208,18 @@ class BaseParser:
class PythonParser(BaseParser):
"""Plain Python parsing class"""
- __slots__ = BaseParser.__slots__ + ("encoder",)
+ __slots__ = BaseParser.__slots__ + ("encoder", "_buffer", "_pos", "_chunks")
def __init__(self, socket_read_size: int):
super().__init__(socket_read_size)
self.encoder: Optional[Encoder] = None
+ self._buffer = b""
+ self._chunks = []
+ self._pos = 0
+
+ def _clear(self):
+ self._buffer = b""
+ self._chunks.clear()
def on_connect(self, connection: "Connection"):
"""Called when the stream connects"""
@@ -227,8 +234,11 @@ class PythonParser(BaseParser):
if self._stream is not None:
self._stream = None
self.encoder = None
+ self._clear()
async def can_read_destructive(self) -> bool:
+ if self._buffer:
+ return True
if self._stream is None:
raise RedisError("Buffer is closed.")
try:
@@ -237,14 +247,23 @@ class PythonParser(BaseParser):
except asyncio.TimeoutError:
return False
- async def read_response(
+ async def read_response(self, disable_decoding: bool = False):
+ if self._chunks:
+ # augment parsing buffer with previously read data
+ self._buffer += b"".join(self._chunks)
+ self._chunks.clear()
+ self._pos = 0
+ response = await self._read_response(disable_decoding=disable_decoding)
+ # Successfully parsing a response allows us to clear our parsing buffer
+ self._clear()
+ return response
+
+ async def _read_response(
self, disable_decoding: bool = False
) -> Union[EncodableT, ResponseError, None]:
if not self._stream or not self.encoder:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
raw = await self._readline()
- if not raw:
- raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
response: Any
byte, response = raw[:1], raw[1:]
@@ -258,6 +277,7 @@ class PythonParser(BaseParser):
# if the error is a ConnectionError, raise immediately so the user
# is notified
if isinstance(error, ConnectionError):
+ self._clear() # Successful parse
raise error
# otherwise, we're dealing with a ResponseError that might belong
# inside a pipeline response. the connection's read_response()
@@ -282,7 +302,7 @@ class PythonParser(BaseParser):
if length == -1:
return None
response = [
- (await self.read_response(disable_decoding)) for _ in range(length)
+ (await self._read_response(disable_decoding)) for _ in range(length)
]
if isinstance(response, bytes) and disable_decoding is False:
response = self.encoder.decode(response)
@@ -293,25 +313,38 @@ class PythonParser(BaseParser):
Read `length` bytes of data. These are assumed to be followed
by a '\r\n' terminator which is subsequently discarded.
"""
- if self._stream is None:
- raise RedisError("Buffer is closed.")
- try:
- data = await self._stream.readexactly(length + 2)
- except asyncio.IncompleteReadError as error:
- raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error
- return data[:-2]
+ want = length + 2
+ end = self._pos + want
+ if len(self._buffer) >= end:
+ result = self._buffer[self._pos : end - 2]
+ else:
+ tail = self._buffer[self._pos :]
+ try:
+ data = await self._stream.readexactly(want - len(tail))
+ except asyncio.IncompleteReadError as error:
+ raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error
+ result = (tail + data)[:-2]
+ self._chunks.append(data)
+ self._pos += want
+ return result
async def _readline(self) -> bytes:
"""
read an unknown number of bytes up to the next '\r\n'
line separator, which is discarded.
"""
- if self._stream is None:
- raise RedisError("Buffer is closed.")
- data = await self._stream.readline()
- if not data.endswith(b"\r\n"):
- raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
- return data[:-2]
+ found = self._buffer.find(b"\r\n", self._pos)
+ if found >= 0:
+ result = self._buffer[self._pos : found]
+ else:
+ tail = self._buffer[self._pos :]
+ data = await self._stream.readline()
+ if not data.endswith(b"\r\n"):
+ raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
+ result = (tail + data)[:-2]
+ self._chunks.append(data)
+ self._pos += len(result) + 2
+ return result
class HiredisParser(BaseParser):