diff options
author | Kristján Valur Jónsson <sweskman@gmail.com> | 2022-10-30 11:39:41 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-10-30 13:39:41 +0200 |
commit | 16270e4bb5fcaebe8a8992fc2b2ccbd9d57b8c3e (patch) | |
tree | e83a05a63e58a7d172181a41d4206f58fa505839 /redis/asyncio/connection.py | |
parent | 842634e7fddeb32ba20aab0dacf557a958a4b00b (diff) | |
download | redis-py-16270e4bb5fcaebe8a8992fc2b2ccbd9d57b8c3e.tar.gz |
Remove the superflous SocketBuffer from asyncio PythonParser (#2418)
* Remove buffering from asyncio SocketBuffer and rely on on the underlying StreamReader
* Skip the use of SocketBuffer in PythonParser
* Remove SocketBuffer altogether
* Code cleanup
* Fix unittest mocking when SocketBuffer is gone
Diffstat (limited to 'redis/asyncio/connection.py')
-rw-r--r-- | redis/asyncio/connection.py | 170 |
1 files changed, 37 insertions, 133 deletions
diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index bc0362e..b64bd12 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -2,7 +2,6 @@ import asyncio import copy import enum import inspect -import io import os import socket import ssl @@ -141,7 +140,7 @@ ExceptionMappingT = Mapping[str, Union[Type[Exception], Mapping[str, Type[Except class BaseParser: """Plain Python parsing class""" - __slots__ = "_stream", "_buffer", "_read_size" + __slots__ = "_stream", "_read_size" EXCEPTION_CLASSES: ExceptionMappingT = { "ERR": { @@ -171,7 +170,6 @@ class BaseParser: def __init__(self, socket_read_size: int): self._stream: Optional[asyncio.StreamReader] = None - self._buffer: Optional[SocketBuffer] = None self._read_size = socket_read_size def __del__(self): @@ -206,127 +204,6 @@ class BaseParser: raise NotImplementedError() -class SocketBuffer: - """Async-friendly re-impl of redis-py's SocketBuffer. - - TODO: We're currently passing through two buffers, - the asyncio.StreamReader and this. I imagine we can reduce the layers here - while maintaining compliance with prior art. - """ - - def __init__( - self, - stream_reader: asyncio.StreamReader, - socket_read_size: int, - ): - self._stream: Optional[asyncio.StreamReader] = stream_reader - self.socket_read_size = socket_read_size - self._buffer: Optional[io.BytesIO] = io.BytesIO() - # number of bytes written to the buffer from the socket - self.bytes_written = 0 - # number of bytes read from the buffer - self.bytes_read = 0 - - @property - def length(self): - return self.bytes_written - self.bytes_read - - async def _read_from_socket(self, length: Optional[int] = None) -> bool: - buf = self._buffer - if buf is None or self._stream is None: - raise RedisError("Buffer is closed.") - buf.seek(self.bytes_written) - marker = 0 - - while True: - data = await self._stream.read(self.socket_read_size) - # an empty string indicates the server shutdown the socket - if isinstance(data, bytes) and len(data) == 0: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - buf.write(data) - data_length = len(data) - self.bytes_written += data_length - marker += data_length - - if length is not None and length > marker: - continue - return True - - async def can_read_destructive(self) -> bool: - if self.length: - return True - try: - async with async_timeout.timeout(0): - return await self._read_from_socket() - except asyncio.TimeoutError: - return False - - async def read(self, length: int) -> bytes: - length = length + 2 # make sure to read the \r\n terminator - # make sure we've read enough data from the socket - if length > self.length: - await self._read_from_socket(length - self.length) - - if self._buffer is None: - raise RedisError("Buffer is closed.") - - self._buffer.seek(self.bytes_read) - data = self._buffer.read(length) - self.bytes_read += len(data) - - # purge the buffer when we've consumed it all so it doesn't - # grow forever - if self.bytes_read == self.bytes_written: - self.purge() - - return data[:-2] - - async def readline(self) -> bytes: - buf = self._buffer - if buf is None: - raise RedisError("Buffer is closed.") - - buf.seek(self.bytes_read) - data = buf.readline() - while not data.endswith(SYM_CRLF): - # there's more data in the socket that we need - await self._read_from_socket() - buf.seek(self.bytes_read) - data = buf.readline() - - self.bytes_read += len(data) - - # purge the buffer when we've consumed it all so it doesn't - # grow forever - if self.bytes_read == self.bytes_written: - self.purge() - - return data[:-2] - - def purge(self): - if self._buffer is None: - raise RedisError("Buffer is closed.") - - self._buffer.seek(0) - self._buffer.truncate() - self.bytes_written = 0 - self.bytes_read = 0 - - def close(self): - try: - self.purge() - self._buffer.close() - except Exception: - # issue #633 suggests the purge/close somehow raised a - # BadFileDescriptor error. Perhaps the client ran out of - # memory or something else? It's probably OK to ignore - # any error being raised from purge/close since we're - # removing the reference to the instance below. - pass - self._buffer = None - self._stream = None - - class PythonParser(BaseParser): """Plain Python parsing class""" @@ -342,27 +219,29 @@ class PythonParser(BaseParser): if self._stream is None: raise RedisError("Buffer is closed.") - self._buffer = SocketBuffer(self._stream, self._read_size) self.encoder = connection.encoder def on_disconnect(self): """Called when the stream disconnects""" if self._stream is not None: self._stream = None - if self._buffer is not None: - self._buffer.close() - self._buffer = None self.encoder = None - async def can_read_destructive(self): - return self._buffer and bool(await self._buffer.can_read_destructive()) + async def can_read_destructive(self) -> bool: + if self._stream is None: + raise RedisError("Buffer is closed.") + try: + async with async_timeout.timeout(0): + return await self._stream.read(1) + except asyncio.TimeoutError: + return False async def read_response( self, disable_decoding: bool = False ) -> Union[EncodableT, ResponseError, None]: - if not self._buffer or not self.encoder: + if not self._stream or not self.encoder: raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - raw = await self._buffer.readline() + raw = await self._readline() if not raw: raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) response: Any @@ -395,7 +274,7 @@ class PythonParser(BaseParser): length = int(response) if length == -1: return None - response = await self._buffer.read(length) + response = await self._read(length) # multi-bulk response elif byte == b"*": length = int(response) @@ -408,6 +287,31 @@ class PythonParser(BaseParser): response = self.encoder.decode(response) return response + async def _read(self, length: int) -> bytes: + """ + Read `length` bytes of data. These are assumed to be followed + by a '\r\n' terminator which is subsequently discarded. + """ + if self._stream is None: + raise RedisError("Buffer is closed.") + try: + data = await self._stream.readexactly(length + 2) + except asyncio.IncompleteReadError as error: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error + return data[:-2] + + async def _readline(self) -> bytes: + """ + read an unknown number of bytes up to the next '\r\n' + line separator, which is discarded. + """ + if self._stream is None: + raise RedisError("Buffer is closed.") + data = await self._stream.readline() + if not data.endswith(b"\r\n"): + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + return data[:-2] + class HiredisParser(BaseParser): """Parser class for connections using Hiredis""" |