diff options
author | Kristján Valur Jónsson <sweskman@gmail.com> | 2022-09-29 13:37:55 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-09-29 16:37:55 +0300 |
commit | b0883b791f95a595fae70bcedf3ad0f73c00e258 (patch) | |
tree | 6bfeb975aa6b9ec6325e90246230d01172beb1a6 /redis/asyncio/connection.py | |
parent | cdbc662adcd303d2525f3ace70531aa37a755652 (diff) | |
download | redis-py-b0883b791f95a595fae70bcedf3ad0f73c00e258.tar.gz |
Simplify async timeouts and allowing `timeout=None` in `PubSub.get_message()` to wait forever (#2295)v4.4.0rc2
* Avoid an extra "can_read" call and use timeout directly.
* Remove low-level read timeouts from the Parser, now handled in the Connection
* Allow pubsub.get_message(time=None) to block.
* update Changes
* increase test timeout for robustness
* expand with statement to avoid invoking null context managers.
remove nullcontext
* Remove unused import
Diffstat (limited to 'redis/asyncio/connection.py')
-rw-r--r-- | redis/asyncio/connection.py | 138 |
1 files changed, 47 insertions, 91 deletions
diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 53b41af..c8834c9 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -1,7 +1,6 @@ import asyncio import copy import enum -import errno import inspect import io import os @@ -55,16 +54,6 @@ hiredis = None if HIREDIS_AVAILABLE: import hiredis -NONBLOCKING_EXCEPTION_ERROR_NUMBERS = { - BlockingIOError: errno.EWOULDBLOCK, - ssl.SSLWantReadError: 2, - ssl.SSLWantWriteError: 2, - ssl.SSLError: 2, -} - -NONBLOCKING_EXCEPTIONS = tuple(NONBLOCKING_EXCEPTION_ERROR_NUMBERS.keys()) - - SYM_STAR = b"*" SYM_DOLLAR = b"$" SYM_CRLF = b"\r\n" @@ -229,11 +218,9 @@ class SocketBuffer: self, stream_reader: asyncio.StreamReader, socket_read_size: int, - socket_timeout: Optional[float], ): self._stream: Optional[asyncio.StreamReader] = stream_reader self.socket_read_size = socket_read_size - self.socket_timeout = socket_timeout self._buffer: Optional[io.BytesIO] = io.BytesIO() # number of bytes written to the buffer from the socket self.bytes_written = 0 @@ -244,52 +231,35 @@ class SocketBuffer: def length(self): return self.bytes_written - self.bytes_read - async def _read_from_socket( - self, - length: Optional[int] = None, - timeout: Union[float, None, _Sentinel] = SENTINEL, - raise_on_timeout: bool = True, - ) -> bool: + async def _read_from_socket(self, length: Optional[int] = None) -> bool: buf = self._buffer if buf is None or self._stream is None: raise RedisError("Buffer is closed.") buf.seek(self.bytes_written) marker = 0 - timeout = timeout if timeout is not SENTINEL else self.socket_timeout - try: - while True: - async with async_timeout.timeout(timeout): - data = await self._stream.read(self.socket_read_size) - # an empty string indicates the server shutdown the socket - if isinstance(data, bytes) and len(data) == 0: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - buf.write(data) - data_length = len(data) - self.bytes_written += data_length - marker += data_length - - if length is not None and length > marker: - continue - return True - except (socket.timeout, asyncio.TimeoutError): - if raise_on_timeout: - raise TimeoutError("Timeout reading from socket") - return False - except NONBLOCKING_EXCEPTIONS as ex: - # if we're in nonblocking mode and the recv raises a - # blocking error, simply return False indicating that - # there's no data to be read. otherwise raise the - # original exception. - allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1) - if not raise_on_timeout and ex.errno == allowed: - return False - raise ConnectionError(f"Error while reading from socket: {ex.args}") + while True: + data = await self._stream.read(self.socket_read_size) + # an empty string indicates the server shutdown the socket + if isinstance(data, bytes) and len(data) == 0: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + buf.write(data) + data_length = len(data) + self.bytes_written += data_length + marker += data_length + + if length is not None and length > marker: + continue + return True async def can_read_destructive(self) -> bool: - return bool(self.length) or await self._read_from_socket( - timeout=0, raise_on_timeout=False - ) + if self.length: + return True + try: + async with async_timeout.timeout(0): + return await self._read_from_socket() + except asyncio.TimeoutError: + return False async def read(self, length: int) -> bytes: length = length + 2 # make sure to read the \r\n terminator @@ -372,9 +342,7 @@ class PythonParser(BaseParser): if self._stream is None: raise RedisError("Buffer is closed.") - self._buffer = SocketBuffer( - self._stream, self._read_size, connection.socket_timeout - ) + self._buffer = SocketBuffer(self._stream, self._read_size) self.encoder = connection.encoder def on_disconnect(self): @@ -444,14 +412,13 @@ class PythonParser(BaseParser): class HiredisParser(BaseParser): """Parser class for connections using Hiredis""" - __slots__ = BaseParser.__slots__ + ("_reader", "_socket_timeout") + __slots__ = BaseParser.__slots__ + ("_reader",) def __init__(self, socket_read_size: int): if not HIREDIS_AVAILABLE: raise RedisError("Hiredis is not available.") super().__init__(socket_read_size=socket_read_size) self._reader: Optional[hiredis.Reader] = None - self._socket_timeout: Optional[float] = None def on_connect(self, connection: "Connection"): self._stream = connection._reader @@ -464,7 +431,6 @@ class HiredisParser(BaseParser): kwargs["errors"] = connection.encoder.encoding_errors self._reader = hiredis.Reader(**kwargs) - self._socket_timeout = connection.socket_timeout def on_disconnect(self): self._stream = None @@ -475,39 +441,20 @@ class HiredisParser(BaseParser): raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) if self._reader.gets(): return True - return await self.read_from_socket(timeout=0, raise_on_timeout=False) - - async def read_from_socket( - self, - timeout: Union[float, None, _Sentinel] = SENTINEL, - raise_on_timeout: bool = True, - ): - timeout = self._socket_timeout if timeout is SENTINEL else timeout try: - if timeout is None: - buffer = await self._stream.read(self._read_size) - else: - async with async_timeout.timeout(timeout): - buffer = await self._stream.read(self._read_size) - if not buffer or not isinstance(buffer, bytes): - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None - self._reader.feed(buffer) - # data was read from the socket and added to the buffer. - # return True to indicate that data was read. - return True - except (socket.timeout, asyncio.TimeoutError): - if raise_on_timeout: - raise TimeoutError("Timeout reading from socket") from None + async with async_timeout.timeout(0): + return await self.read_from_socket() + except asyncio.TimeoutError: return False - except NONBLOCKING_EXCEPTIONS as ex: - # if we're in nonblocking mode and the recv raises a - # blocking error, simply return False indicating that - # there's no data to be read. otherwise raise the - # original exception. - allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1) - if not raise_on_timeout and ex.errno == allowed: - return False - raise ConnectionError(f"Error while reading from socket: {ex.args}") + + async def read_from_socket(self): + buffer = await self._stream.read(self._read_size) + if not buffer or not isinstance(buffer, bytes): + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None + self._reader.feed(buffer) + # data was read from the socket and added to the buffer. + # return True to indicate that data was read. + return True async def read_response( self, disable_decoding: bool = False @@ -922,11 +869,16 @@ class Connection: f"Error while reading from {self.host}:{self.port}: {e.args}" ) - async def read_response(self, disable_decoding: bool = False): + async def read_response( + self, + disable_decoding: bool = False, + timeout: Optional[float] = None, + ): """Read the response from a previously sent command""" + read_timeout = timeout if timeout is not None else self.socket_timeout try: - if self.socket_timeout: - async with async_timeout.timeout(self.socket_timeout): + if read_timeout is not None: + async with async_timeout.timeout(read_timeout): response = await self._parser.read_response( disable_decoding=disable_decoding ) @@ -935,6 +887,10 @@ class Connection: disable_decoding=disable_decoding ) except asyncio.TimeoutError: + if timeout is not None: + # user requested timeout, return None + return None + # it was a self.socket_timeout error. await self.disconnect(nowait=True) raise TimeoutError(f"Timeout reading from {self.host}:{self.port}") except OSError as e: |