author     Kristján Valur Jónsson <sweskman@gmail.com>   2022-10-30 11:39:41 +0000
committer  GitHub <noreply@github.com>                   2022-10-30 13:39:41 +0200
commit     16270e4bb5fcaebe8a8992fc2b2ccbd9d57b8c3e (patch)
tree       e83a05a63e58a7d172181a41d4206f58fa505839
parent     842634e7fddeb32ba20aab0dacf557a958a4b00b (diff)
download   redis-py-16270e4bb5fcaebe8a8992fc2b2ccbd9d57b8c3e.tar.gz
Remove the superfluous SocketBuffer from asyncio PythonParser (#2418)
* Remove buffering from asyncio SocketBuffer and rely on the underlying StreamReader
* Skip the use of SocketBuffer in PythonParser
* Remove SocketBuffer altogether
* Code cleanup
* Fix unittest mocking when SocketBuffer is gone
-rw-r--r--  redis/asyncio/connection.py            170
-rw-r--r--  tests/test_asyncio/test_connection.py    9
2 files changed, 42 insertions, 137 deletions
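
The change drops the intermediate SocketBuffer (an io.BytesIO copy of everything read from the socket) and lets asyncio.StreamReader do the buffering itself: RESP framing maps directly onto readline() for line replies and readexactly() for bulk payloads. A minimal standalone sketch of that idea, assuming a Redis server on localhost:6379 (the helper name and wiring are illustrative, not redis-py's API, and error replies are not handled):

    import asyncio
    from typing import Optional

    async def get_key(name: bytes) -> Optional[bytes]:
        # Illustrative helper, not redis-py API; assumes a Redis server on
        # localhost:6379 and ignores error replies.
        reader, writer = await asyncio.open_connection("localhost", 6379)
        writer.write(b"*2\r\n$3\r\nGET\r\n$%d\r\n%s\r\n" % (len(name), name))
        await writer.drain()

        # StreamReader already buffers the socket, so RESP framing maps
        # directly onto readline()/readexactly() -- no extra BytesIO layer.
        header = await reader.readline()                 # e.g. b"$5\r\n" or b"$-1\r\n"
        if not header.endswith(b"\r\n"):
            raise ConnectionError("server closed the connection")
        length = int(header[1:-2])                       # strip b"$" and the CRLF
        if length == -1:
            value = None                                 # null bulk string reply
        else:
            data = await reader.readexactly(length + 2)  # payload plus trailing CRLF
            value = data[:-2]

        writer.close()
        await writer.wait_closed()
        return value

    # asyncio.run(get_key(b"mykey"))

The new _read() and _readline() helpers in the diff below are essentially these same two calls, wrapped with the parser's connection-error handling.
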
diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py
index bc0362e..b64bd12 100644
--- a/redis/asyncio/connection.py
+++ b/redis/asyncio/connection.py
@@ -2,7 +2,6 @@ import asyncio
import copy
import enum
import inspect
-import io
import os
import socket
import ssl
@@ -141,7 +140,7 @@ ExceptionMappingT = Mapping[str, Union[Type[Exception], Mapping[str, Type[Except
class BaseParser:
"""Plain Python parsing class"""
- __slots__ = "_stream", "_buffer", "_read_size"
+ __slots__ = "_stream", "_read_size"
EXCEPTION_CLASSES: ExceptionMappingT = {
"ERR": {
@@ -171,7 +170,6 @@ class BaseParser:
def __init__(self, socket_read_size: int):
self._stream: Optional[asyncio.StreamReader] = None
- self._buffer: Optional[SocketBuffer] = None
self._read_size = socket_read_size
def __del__(self):
@@ -206,127 +204,6 @@ class BaseParser:
raise NotImplementedError()
-class SocketBuffer:
- """Async-friendly re-impl of redis-py's SocketBuffer.
-
- TODO: We're currently passing through two buffers,
- the asyncio.StreamReader and this. I imagine we can reduce the layers here
- while maintaining compliance with prior art.
- """
-
- def __init__(
- self,
- stream_reader: asyncio.StreamReader,
- socket_read_size: int,
- ):
- self._stream: Optional[asyncio.StreamReader] = stream_reader
- self.socket_read_size = socket_read_size
- self._buffer: Optional[io.BytesIO] = io.BytesIO()
- # number of bytes written to the buffer from the socket
- self.bytes_written = 0
- # number of bytes read from the buffer
- self.bytes_read = 0
-
- @property
- def length(self):
- return self.bytes_written - self.bytes_read
-
- async def _read_from_socket(self, length: Optional[int] = None) -> bool:
- buf = self._buffer
- if buf is None or self._stream is None:
- raise RedisError("Buffer is closed.")
- buf.seek(self.bytes_written)
- marker = 0
-
- while True:
- data = await self._stream.read(self.socket_read_size)
- # an empty string indicates the server shutdown the socket
- if isinstance(data, bytes) and len(data) == 0:
- raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
- buf.write(data)
- data_length = len(data)
- self.bytes_written += data_length
- marker += data_length
-
- if length is not None and length > marker:
- continue
- return True
-
- async def can_read_destructive(self) -> bool:
- if self.length:
- return True
- try:
- async with async_timeout.timeout(0):
- return await self._read_from_socket()
- except asyncio.TimeoutError:
- return False
-
- async def read(self, length: int) -> bytes:
- length = length + 2 # make sure to read the \r\n terminator
- # make sure we've read enough data from the socket
- if length > self.length:
- await self._read_from_socket(length - self.length)
-
- if self._buffer is None:
- raise RedisError("Buffer is closed.")
-
- self._buffer.seek(self.bytes_read)
- data = self._buffer.read(length)
- self.bytes_read += len(data)
-
- # purge the buffer when we've consumed it all so it doesn't
- # grow forever
- if self.bytes_read == self.bytes_written:
- self.purge()
-
- return data[:-2]
-
- async def readline(self) -> bytes:
- buf = self._buffer
- if buf is None:
- raise RedisError("Buffer is closed.")
-
- buf.seek(self.bytes_read)
- data = buf.readline()
- while not data.endswith(SYM_CRLF):
- # there's more data in the socket that we need
- await self._read_from_socket()
- buf.seek(self.bytes_read)
- data = buf.readline()
-
- self.bytes_read += len(data)
-
- # purge the buffer when we've consumed it all so it doesn't
- # grow forever
- if self.bytes_read == self.bytes_written:
- self.purge()
-
- return data[:-2]
-
- def purge(self):
- if self._buffer is None:
- raise RedisError("Buffer is closed.")
-
- self._buffer.seek(0)
- self._buffer.truncate()
- self.bytes_written = 0
- self.bytes_read = 0
-
- def close(self):
- try:
- self.purge()
- self._buffer.close()
- except Exception:
- # issue #633 suggests the purge/close somehow raised a
- # BadFileDescriptor error. Perhaps the client ran out of
- # memory or something else? It's probably OK to ignore
- # any error being raised from purge/close since we're
- # removing the reference to the instance below.
- pass
- self._buffer = None
- self._stream = None
-
-
class PythonParser(BaseParser):
"""Plain Python parsing class"""
@@ -342,27 +219,29 @@ class PythonParser(BaseParser):
if self._stream is None:
raise RedisError("Buffer is closed.")
- self._buffer = SocketBuffer(self._stream, self._read_size)
self.encoder = connection.encoder
def on_disconnect(self):
"""Called when the stream disconnects"""
if self._stream is not None:
self._stream = None
- if self._buffer is not None:
- self._buffer.close()
- self._buffer = None
self.encoder = None
- async def can_read_destructive(self):
- return self._buffer and bool(await self._buffer.can_read_destructive())
+ async def can_read_destructive(self) -> bool:
+ if self._stream is None:
+ raise RedisError("Buffer is closed.")
+ try:
+ async with async_timeout.timeout(0):
+ return await self._stream.read(1)
+ except asyncio.TimeoutError:
+ return False
async def read_response(
self, disable_decoding: bool = False
) -> Union[EncodableT, ResponseError, None]:
- if not self._buffer or not self.encoder:
+ if not self._stream or not self.encoder:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
- raw = await self._buffer.readline()
+ raw = await self._readline()
if not raw:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
response: Any
@@ -395,7 +274,7 @@ class PythonParser(BaseParser):
length = int(response)
if length == -1:
return None
- response = await self._buffer.read(length)
+ response = await self._read(length)
# multi-bulk response
elif byte == b"*":
length = int(response)
@@ -408,6 +287,31 @@ class PythonParser(BaseParser):
response = self.encoder.decode(response)
return response
+ async def _read(self, length: int) -> bytes:
+ """
+ Read `length` bytes of data. These are assumed to be followed
+ by a '\r\n' terminator which is subsequently discarded.
+ """
+ if self._stream is None:
+ raise RedisError("Buffer is closed.")
+ try:
+ data = await self._stream.readexactly(length + 2)
+ except asyncio.IncompleteReadError as error:
+ raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error
+ return data[:-2]
+
+ async def _readline(self) -> bytes:
+ """
+ read an unknown number of bytes up to the next '\r\n'
+ line separator, which is discarded.
+ """
+ if self._stream is None:
+ raise RedisError("Buffer is closed.")
+ data = await self._stream.readline()
+ if not data.endswith(b"\r\n"):
+ raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
+ return data[:-2]
+
class HiredisParser(BaseParser):
"""Parser class for connections using Hiredis"""
diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py
index 674a1b9..6bf0034 100644
--- a/tests/test_asyncio/test_connection.py
+++ b/tests/test_asyncio/test_connection.py
@@ -13,22 +13,23 @@ from redis.asyncio.connection import (
from redis.asyncio.retry import Retry
from redis.backoff import NoBackoff
from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError
-from redis.utils import HIREDIS_AVAILABLE
from tests.conftest import skip_if_server_version_lt
from .compat import mock
@pytest.mark.onlynoncluster
-@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
async def test_invalid_response(create_redis):
r = await create_redis(single_connection_client=True)
raw = b"x"
- readline_mock = mock.AsyncMock(return_value=raw)
parser: "PythonParser" = r.connection._parser
- with mock.patch.object(parser._buffer, "readline", readline_mock):
+ if not isinstance(parser, PythonParser):
+ pytest.skip("PythonParser only")
+ stream_mock = mock.Mock(parser._stream)
+ stream_mock.readline.return_value = raw + b"\r\n"
+ with mock.patch.object(parser, "_stream", stream_mock):
with pytest.raises(InvalidResponse) as cm:
await parser.read_response()
assert str(cm.value) == f"Protocol Error: {raw!r}"
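
With SocketBuffer gone there is no parser._buffer to patch, so the test now specs a Mock on the real StreamReader and patches parser._stream instead. Because a spec'd mock turns coroutine methods into AsyncMocks (Python 3.8+), a plain return_value can be awaited like real stream data. A minimal standalone sketch of that pattern (names are illustrative):

    import asyncio
    from unittest import mock

    async def demo() -> bytes:
        real_stream = asyncio.StreamReader()
        stream_mock = mock.Mock(real_stream)          # spec'd on the real reader
        stream_mock.readline.return_value = b"x\r\n"  # canned protocol line
        return await stream_mock.readline()           # -> b"x\r\n"

    # asyncio.run(demo())
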